Skip to content
Snippets Groups Projects

Refactor OperatorImpl for backend/export

Merged Olivier BICHLER requested to merge backend_export into dev
32 files
+ 445
359
Compare changes
  • Side-by-side
  • Inline
Files
32
@@ -27,30 +27,33 @@
@@ -27,30 +27,33 @@
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge {
namespace Aidge {
 
// Operator implementation entry point for the backend
class AddImpl_cuda : public OperatorImpl {
class AddImpl_cuda : public OperatorImpl {
private:
public:
public:
AddImpl_cuda(const Add_Op &op) : OperatorImpl(op, "cuda") {}
AddImpl_cuda(const Add_Op& op) : OperatorImpl(op, "cuda") {}
static std::unique_ptr<AddImpl_cuda> create(const Add_Op &op) {
static std::unique_ptr<AddImpl_cuda> create(const Add_Op& op) {
return std::make_unique<AddImpl_cuda>(op);
return std::make_unique<AddImpl_cuda>(op);
}
}
public:
virtual std::set<ImplSpec> getAvailableImplSpecs() const override {
void forward();
return {
void backward();
{DataType::Float64},
// ~AddImpl_cuda();
{DataType::Float32},
 
{DataType::Float16},
 
};
 
}
 
 
void forward() override;
 
void backward() override;
 
private:
private:
template <class T> void forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides);
template <class T> void forward_(const std::vector<Tensor>& inputs, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides);
template <class T> void backward_(const Tensor& outGrad, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides);
template <class T> void backward_(const Tensor& outGrad, const std::vector<std::vector<int>>& inputsDims, const std::vector<std::vector<int>>& inputsStrides);
};
};
namespace {
// Implementation entry point registration to Operator
// add cuda backend to Add_Op implementation registry
REGISTRAR(Add_Op, "cuda", Aidge::AddImpl_cuda::create);
static Registrar<Add_Op> registrarAddImpl_cuda("cuda", Aidge::AddImpl_cuda::create);
} // namespace
} // namespace Aidge
} // namespace Aidge
#endif /* AIDGE_BACKEND_CUDA_OPERATOR_ADDIMPL_H_ */
#endif /* AIDGE_BACKEND_CUDA_OPERATOR_ADDIMPL_H_ */
Loading