diff --git a/include/aidge/backend/cuda/operator/ConvImpl.hpp b/include/aidge/backend/cuda/operator/ConvImpl.hpp
index 939199862b7c745dafe27226c45e08005539a301..31b80adbae0602211fa5c11873875a1a10eb40db 100644
--- a/include/aidge/backend/cuda/operator/ConvImpl.hpp
+++ b/include/aidge/backend/cuda/operator/ConvImpl.hpp
@@ -27,15 +27,9 @@
 #include "aidge/backend/cuda/utils/CudaUtils.hpp"
 
 namespace Aidge {
-// class Conv_Op;
-
 template <DimIdx_t DIM>
 class ConvImpl_cuda : public OperatorImpl {
 private:
-    const Conv_Op<DIM> &mOp;
-    std::array<NbElts_t, 3> mNbConsumedData = {0, 0, 0};
-    std::array<NbElts_t, 1> mNbProducedData = {0};
-
     // CuDNN specific variables
     cudnnConvolutionDescriptor_t mConvDesc = nullptr;
     cudnnFilterDescriptor_t mFilterDesc = nullptr;
@@ -44,24 +38,14 @@ private:
     void* mWorkspace = nullptr;
 
 public:
-    ConvImpl_cuda(const Conv_Op<DIM> &op) : mOp(op) {}
+    ConvImpl_cuda(const Conv_Op<DIM> &op) : OperatorImpl(op) {}
 
     static std::unique_ptr<ConvImpl_cuda> create(const Conv_Op<2> &op) {
         return std::make_unique<ConvImpl_cuda>(op);
     }
 
 public:
-    NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final;
-    NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
-    NbElts_t getRequiredMemory(const IOIndex_t /*outputIdx*/, const std::vector<DimSize_t> &/*inputsSize*/) const override final;
-    NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override final;
-    NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
-    void updateConsummerProducer() override final;
-
     void forward();
-
-    void backward();
-
     ~ConvImpl_cuda();
 
 private:
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index e21bc92c2b30300eb8ec64461d55ce0c599f52f4..515f5f19d7702ea5bc037b672e182e97800a703b 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -22,56 +22,6 @@
 #include "aidge/backend/cuda/operator/ConvImpl.hpp"
 #include "aidge/backend/cuda/utils/CudaContext.hpp"
 
-template <Aidge::DimIdx_t DIM>
-Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const {
-    assert(mOp.getInput(inputIdx) && "requires valid input");
-
-    // Requires the whole tensors
-    const auto &inputDims = std::static_pointer_cast<Tensor>(mOp.getInput(inputIdx))->dims();
-
-    return std::accumulate(inputDims.begin(), inputDims.end(), Aidge::NbElts_t(1), std::multiplies<NbElts_t>());
-}
-
-template <Aidge::DimIdx_t DIM>
-Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const {
-    // for the direct convolution algorithm, convolutions can be in-place, if
-    // there is no padding!
-    return 0;
-}
-
-template <Aidge::DimIdx_t DIM>
-Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getRequiredMemory(const Aidge::IOIndex_t outputIdx,
-                                                             const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const {
-    // Requires the whole tensors, regardless of available data on inputs
-    assert(outputIdx == 0 && "operator has only one output");
-    (void) outputIdx;
-
-    const auto &outputDims = std::static_pointer_cast<Tensor>(mOp.getOutput(0))->dims();
-    return std::accumulate(outputDims.begin(), outputDims.end(), NbElts_t(1), std::multiplies<NbElts_t>());
-}
-
-template <Aidge::DimIdx_t DIM>
-Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbConsumedData(Aidge::IOIndex_t inputIdx) const {
-    assert(static_cast<std::size_t>(inputIdx) < mNbConsumedData.size());
-    return mNbConsumedData[static_cast<std::size_t>(inputIdx)];
-}
-
-template <Aidge::DimIdx_t DIM>
-Aidge::NbElts_t Aidge::ConvImpl_cuda<DIM>::getNbProducedData(Aidge::IOIndex_t outputIdx) const {
-    assert((outputIdx == 0) && (static_cast<std::size_t>(outputIdx) < mNbProducedData.size()));
-    return mNbProducedData[static_cast<std::size_t>(outputIdx)];
-}
-
-template <Aidge::DimIdx_t DIM>
-void Aidge::ConvImpl_cuda<DIM>::updateConsummerProducer(){
-    // Update producer-consumer data
-    for (std::size_t inputIdx = 0; inputIdx < mNbConsumedData.size(); ++inputIdx)
-        mNbConsumedData[inputIdx] += getNbRequiredData(static_cast<IOIndex_t>(inputIdx)); // each input is consumed by the minimum
-                                                                                          // amount for a forward pass
-
-    mNbProducedData[0] += getRequiredMemory(0, {});
-}
-
 template <Aidge::DimIdx_t DIM>
 void Aidge::ConvImpl_cuda<DIM>::forward() {
     // FIXME: uncomment the following code once memory handling will work
@@ -80,9 +30,10 @@ void Aidge::ConvImpl_cuda<DIM>::forward() {
 
     // Lazy-initialize CuDNN convolution descriptor
     if (mConvDesc == nullptr) {
-        const std::vector<int> strides(mOp.template getAttr<ConvAttr::StrideDims>().begin(), mOp.template getAttr<ConvAttr::StrideDims>().end());
+        const Conv_Op<DIM>& convOp = static_cast<const Conv_Op<DIM>&>(mOp);
+        const std::vector<int> strides(convOp.template getAttr<ConvAttr::StrideDims>().begin(), convOp.template getAttr<ConvAttr::StrideDims>().end());
         const std::vector<int> paddings(DIM, 0);
-        const std::vector<int> upscales(mOp.template getAttr<ConvAttr::DilationDims>().begin(), mOp.template getAttr<ConvAttr::DilationDims>().end());
+        const std::vector<int> upscales(convOp.template getAttr<ConvAttr::DilationDims>().begin(), convOp.template getAttr<ConvAttr::DilationDims>().end());
 
         CHECK_CUDNN_STATUS(cudnnCreateConvolutionDescriptor(&mConvDesc));
         CHECK_CUDNN_STATUS(
@@ -215,9 +166,6 @@ Aidge::ConvImpl_cuda<DIM>::~ConvImpl_cuda() {
     }
 }
 
-template <Aidge::DimIdx_t DIM>
-void Aidge::ConvImpl_cuda<DIM>::backward() { printf("Not implemented yet.\n"); }
-
 // Template declarations
 template class Aidge::ConvImpl_cuda<2>;
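
The patch moves ownership of the operator reference into the base OperatorImpl (via the OperatorImpl(op) constructor call) and recovers the concrete Conv_Op<DIM> inside forward() with a static_cast, which is safe by construction because the ConvImpl_cuda constructor only ever accepts a Conv_Op<DIM>. Below is a minimal standalone sketch of that pattern; the types and members used here (Operator, ConvOp, OperatorImpl, strideDims) are simplified stand-ins for illustration, not Aidge's real API:

// sketch.cpp -- compile with: g++ -std=c++14 sketch.cpp
#include <array>
#include <cstdio>

struct Operator {                       // stand-in for the generic operator base
    virtual ~Operator() = default;
};

template <int DIM>
struct ConvOp : Operator {              // stand-in for Conv_Op<DIM>
    std::array<int, DIM> strideDims{};  // stand-in for the StrideDims attribute
};

struct OperatorImpl {                   // stand-in for the OperatorImpl base
    explicit OperatorImpl(const Operator& op) : mOp(op) {}
    virtual ~OperatorImpl() = default;
protected:
    const Operator& mOp;                // generic reference, now owned by the base
};

template <int DIM>
struct ConvImplCuda : OperatorImpl {
    // Only a ConvOp<DIM> can reach the base reference through this constructor,
    // which is what makes the downcast in forward() safe.
    explicit ConvImplCuda(const ConvOp<DIM>& op) : OperatorImpl(op) {}

    void forward() {
        // Mirror of the static_cast introduced in ConvImpl_cuda<DIM>::forward():
        // recover the concrete type only where operator attributes are needed.
        const auto& convOp = static_cast<const ConvOp<DIM>&>(mOp);
        std::printf("stride[0] = %d\n", convOp.strideDims[0]);
    }
};

int main() {
    ConvOp<2> op;
    op.strideDims = {1, 1};
    ConvImplCuda<2> impl(op);
    impl.forward();                     // prints "stride[0] = 1"
    return 0;
}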