Skip to content
Snippets Groups Projects

Update with default operator impl

Merged Olivier BICHLER requested to merge simpl_op_impl into main
2 files
+ 4
72
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -27,15 +27,9 @@
@@ -27,15 +27,9 @@
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
#include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge {
namespace Aidge {
// class Conv_Op;
template <DimIdx_t DIM>
template <DimIdx_t DIM>
class ConvImpl_cuda : public OperatorImpl {
class ConvImpl_cuda : public OperatorImpl {
private:
private:
const Conv_Op<DIM> &mOp;
std::array<NbElts_t, 3> mNbConsumedData = {0, 0, 0};
std::array<NbElts_t, 1> mNbProducedData = {0};
// CuDNN specific variables
// CuDNN specific variables
cudnnConvolutionDescriptor_t mConvDesc = nullptr;
cudnnConvolutionDescriptor_t mConvDesc = nullptr;
cudnnFilterDescriptor_t mFilterDesc = nullptr;
cudnnFilterDescriptor_t mFilterDesc = nullptr;
@@ -44,24 +38,14 @@ private:
@@ -44,24 +38,14 @@ private:
void* mWorkspace = nullptr;
void* mWorkspace = nullptr;
public:
public:
ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op) {}
ConvImpl_cuda(const Conv_Op<DIM> &op) : OperatorImpl(op) {}
static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<2> &op) {
static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<2> &op) {
return std::make_unique<ConvImpl_cuda>(op);
return std::make_unique<ConvImpl_cuda>(op);
}
}
public:
public:
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
NbElts_t getRequiredMemory(const IOIndex_t /*outputIdx*/, const std::vector<DimSize_t> &/*inputsSize*/) const override final;
NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override final;
NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
void updateConsummerProducer() override final;
void forward();
void forward();
void backward();
~ConvImpl_cuda();
~ConvImpl_cuda();
private:
private:
Loading