Skip to content
Snippets Groups Projects
Commit 74886e26 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Updated op impl with new default impl

parent 6d5e20df
No related branches found
No related tags found
1 merge request!3Update with default operator impl
Pipeline #33161 failed
...@@ -27,15 +27,9 @@ ...@@ -27,15 +27,9 @@
#include "aidge/backend/cuda/utils/CudaUtils.hpp" #include "aidge/backend/cuda/utils/CudaUtils.hpp"
namespace Aidge { namespace Aidge {
// class Conv_Op;
template <DimIdx_t DIM> template <DimIdx_t DIM>
class ConvImpl_cuda : public OperatorImpl { class ConvImpl_cuda : public OperatorImpl {
private: private:
const Conv_Op<DIM> &mOp;
std::array<NbElts_t, 3> mNbConsumedData = {0, 0, 0};
std::array<NbElts_t, 1> mNbProducedData = {0};
// CuDNN specific variables // CuDNN specific variables
cudnnConvolutionDescriptor_t mConvDesc = nullptr; cudnnConvolutionDescriptor_t mConvDesc = nullptr;
cudnnFilterDescriptor_t mFilterDesc = nullptr; cudnnFilterDescriptor_t mFilterDesc = nullptr;
...@@ -44,24 +38,14 @@ private: ...@@ -44,24 +38,14 @@ private:
void* mWorkspace = nullptr; void* mWorkspace = nullptr;
public: public:
ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op) {} ConvImpl_cuda(const Conv_Op<DIM> &op) : OperatorImpl(op) {}
static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<2> &op) { static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<2> &op) {
return std::make_unique<ConvImpl_cuda>(op); return std::make_unique<ConvImpl_cuda>(op);
} }
public: public:
NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
NbElts_t getRequiredMemory(const IOIndex_t /*outputIdx*/, const std::vector<DimSize_t> &/*inputsSize*/) const override final;
NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override final;
NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
void updateConsummerProducer() override final;
void forward(); void forward();
void backward();
~ConvImpl_cuda(); ~ConvImpl_cuda();
private: private:
......
...@@ -22,56 +22,6 @@ ...@@ -22,56 +22,6 @@
#include "aidge/backend/cuda/operator/ConvImpl.hpp" #include "aidge/backend/cuda/operator/ConvImpl.hpp"
#include "aidge/backend/cuda/utils/CudaContext.hpp" #include "aidge/backend/cuda/utils/CudaContext.hpp"
template <Aidge::DimIdx_t DIM>
Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
    assert(mOp.getInput(inputIdx) && "requires valid input");
    // The whole input tensor must be available before the operator can run:
    // report its total element count (product of all dimensions).
    const auto input = std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx));
    NbElts_t elementCount = 1;
    for (const auto dimSize : input->dims()) {
        elementCount *= dimSize;
    }
    return elementCount;
}
template <Aidge::DimIdx_t DIM>
Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
    // With the direct convolution algorithm the operation can be performed
    // in-place when there is no padding, so no input elements need to be
    // protected from being overwritten.
    return NbElts_t(0);
}
template <Aidge::DimIdx_t DIM>
Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
                                                             const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
    // The full output tensor must fit in memory, regardless of how much
    // data is currently available on the inputs.
    assert(outputIdx == 0 && "operator has only one output");
    (void) outputIdx;
    NbElts_t totalElements = 1;
    for (const auto dimSize : std::static_pointer_cast<Tensor>(mOp.getOutput(0))->dims()) {
        totalElements *= dimSize;
    }
    return totalElements;
}
template <Aidge::DimIdx_t DIM>
Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
    // Bookkeeping accessor: how much data has been consumed so far on the
    // given input.
    const auto idx = static_cast<std::size_t>(inputIdx);
    assert(idx < mNbConsumedData.size());
    return mNbConsumedData[idx];
}
template <Aidge::DimIdx_t DIM>
Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
    // Bookkeeping accessor: how much data has been produced so far on the
    // (single) output.
    const auto idx = static_cast<std::size_t>(outputIdx);
    assert((outputIdx == 0) && (idx < mNbProducedData.size()));
    return mNbProducedData[idx];
}
template <Aidge::DimIdx_t DIM>
void Aidge::ConvImpl_cuda<DIM>::updateConsummerProducer(){
    // Advance the producer-consumer counters: each input is consumed by the
    // minimum amount needed for one forward pass, and one full output tensor
    // is recorded as produced.
    std::size_t inputIdx = 0;
    while (inputIdx < mNbConsumedData.size()) {
        mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx));
        ++inputIdx;
    }
    mNbProducedData[0] += getRequiredMemory(0, {});
}
template <Aidge::DimIdx_t DIM> template <Aidge::DimIdx_t DIM>
void Aidge::ConvImpl_cuda<DIM>::forward() { void Aidge::ConvImpl_cuda<DIM>::forward() {
// FIXME: uncomment the following code once memory handling will work // FIXME: uncomment the following code once memory handling will work
...@@ -215,9 +165,6 @@ Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() { ...@@ -215,9 +165,6 @@ Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() {
} }
} }
template <Aidge::DimIdx_t DIM>
void Aidge::ConvImpl_cuda<DIM>::backward() {
    // Backward pass is a stub for this CUDA backend: only report that it is
    // not available yet.
    printf("Not implemented yet.\n");
}
// Template declarations // Template declarations
template class Aidge::ConvImpl_cuda<2>; template class Aidge::ConvImpl_cuda<2>;
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment