diff --git a/include/aidge/backend/cpu/operator/AddImpl.hpp b/include/aidge/backend/cpu/operator/AddImpl.hpp index 5ec33e9767914034b92731fb12630f7994d34ac7..806bbb02d760dbdec58df137641d4c211443039e 100644 --- a/include/aidge/backend/cpu/operator/AddImpl.hpp +++ b/include/aidge/backend/cpu/operator/AddImpl.hpp @@ -31,13 +31,8 @@ class AddImplBackward_cpu class AddImpl_cpu : public OperatorImpl { -private: - const Add_Op& mOp; - std::vector<NbElts_t> mNbConsumedData; - std::array<NbElts_t, 1> mNbProducedData = {}; - public: - AddImpl_cpu(const Add_Op& op) : mOp(op), mNbConsumedData(std::vector<NbElts_t>(op.nbInputs())) {} + AddImpl_cpu(const Add_Op& op) : OperatorImpl(op) {} static std::unique_ptr<AddImpl_cpu> create(const Add_Op& op) { return std::make_unique<AddImpl_cpu>(op); diff --git a/include/aidge/backend/cpu/operator/ConcatImpl.hpp b/include/aidge/backend/cpu/operator/ConcatImpl.hpp index 880a2e6635ba8ab2ded6f934bf6eb2e3f6d38d5b..a5e0c56e856217509445844be2ae5631bad05728 100644 --- a/include/aidge/backend/cpu/operator/ConcatImpl.hpp +++ b/include/aidge/backend/cpu/operator/ConcatImpl.hpp @@ -39,13 +39,8 @@ class ConcatImplBackward_cpu class ConcatImpl_cpu : public OperatorImpl { -private: - const Concat_Op& mOp; - std::vector<NbElts_t> mNbConsumedData; - std::array<NbElts_t, 1> mNbProducedData = {}; - public: - ConcatImpl_cpu(const Concat_Op& op) : mOp(op), mNbConsumedData(std::vector<NbElts_t>(op.nbInputs())) {} + ConcatImpl_cpu(const Concat_Op& op) : OperatorImpl(op) {} static std::unique_ptr<ConcatImpl_cpu> create(const Concat_Op& op) { return std::make_unique<ConcatImpl_cpu>(op); diff --git a/include/aidge/backend/cpu/operator/SliceImpl.hpp b/include/aidge/backend/cpu/operator/SliceImpl.hpp index dddab386005114c67e72e8db4e1d56eebc11b08e..69dd88bb5e6e463b1122bc46b294f3220c88a905 100644 --- a/include/aidge/backend/cpu/operator/SliceImpl.hpp +++ b/include/aidge/backend/cpu/operator/SliceImpl.hpp @@ -43,13 +43,8 @@ class SliceImplBackward_cpu template <DimIdx_t DIM> class SliceImpl_cpu : public OperatorImpl { - private: - const Slice_Op<DIM>& mOp; - std::array<NbElts_t, 1> mNbConsumedData; - std::array<NbElts_t, 1> mNbProducedData; - public: - SliceImpl_cpu(const Slice_Op<DIM>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {} + SliceImpl_cpu(const Slice_Op<DIM>& op) : OperatorImpl(op) {} static std::unique_ptr<SliceImpl_cpu<DIM>> create(const Slice_Op<DIM>& op) { return std::make_unique<SliceImpl_cpu<DIM>>(op); @@ -57,10 +52,10 @@ class SliceImpl_cpu : public OperatorImpl { public: NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final { - assert(mOp.getInput(0) && "requires valid input"); + assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input"); // Requires the whole tensors - const auto& inputDims = mOp.getInput(0)->dims(); + const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(); return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1), std::multiplies<NbElts_t>()); @@ -70,7 +65,7 @@ class SliceImpl_cpu : public OperatorImpl { const std::vector<DimSize_t>& inputsSize) const override final { (void)outputIdx; (void)inputsSize; - const auto& outputDims = mOp.getOutput(0)->dims(); + const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(); return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1), std::multiplies<NbElts_t>()); } @@ -89,17 +84,17 @@ class SliceImpl_cpu : public OperatorImpl { void forward() { // FIXME: uncomment the following code once memory handling will work - assert(mOp.getInput(0) && "missing input #0"); + assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); // Find the correct kernel type auto kernelFunc = Registrar<SliceImplForward_cpu<DIM>>::create( - {mOp.getInput(0)->dataType()}); + {std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType()}); // Call kernel - kernelFunc(mOp.getInput(0)->template dims<DIM>(), - std::get<1>(mOp.getStaticAttributes()), - mOp.getInput(0)->getImpl()->rawPtr(), - mOp.getOutput(0)->getImpl()->rawPtr() + kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<DIM>(), + std::get<1>(std::static_pointer_cast<const Slice_Op<DIM>&>(mOp).getStaticAttributes()), + std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(), + std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr() ); // each input is consumed by the minimum amount for a forward pass @@ -115,19 +110,14 @@ class SliceImpl_cpu : public OperatorImpl { template <> class SliceImpl_cpu<1> : public OperatorImpl { - private: - const Slice_Op<1>& mOp; - std::array<NbElts_t, 1> mNbConsumedData; - std::array<NbElts_t, 1> mNbProducedData; - - public: - SliceImpl_cpu(const Slice_Op<1>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {} +public: + SliceImpl_cpu(const Slice_Op<1>& op) : OperatorImpl(op) {} static std::unique_ptr<SliceImpl_cpu<1>> create(const Slice_Op<1>& op) { return std::make_unique<SliceImpl_cpu<1>>(op); } - public: +public: NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final; NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; NbElts_t getRequiredMemory(const IOIndex_t outputIdx, @@ -144,13 +134,8 @@ class SliceImpl_cpu<1> : public OperatorImpl { template <> class SliceImpl_cpu<2> : public OperatorImpl { - private: - const Slice_Op<2>& mOp; - std::array<NbElts_t, 1> mNbConsumedData; - std::array<NbElts_t, 1> mNbProducedData; - public: - SliceImpl_cpu(const Slice_Op<2>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {} + SliceImpl_cpu(const Slice_Op<2>& op) : OperatorImpl(op) {} static std::unique_ptr<SliceImpl_cpu<2>> create(const Slice_Op<2>& op) { return std::make_unique<SliceImpl_cpu<2>>(op); @@ -173,13 +158,8 @@ class SliceImpl_cpu<2> : public OperatorImpl { template <> class SliceImpl_cpu<3> : public OperatorImpl { - private: - const Slice_Op<3>& mOp; - std::array<NbElts_t, 1> mNbConsumedData; - std::array<NbElts_t, 1> mNbProducedData; - public: - SliceImpl_cpu(const Slice_Op<3>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {} + SliceImpl_cpu(const Slice_Op<3>& op) : OperatorImpl(op) {} static std::unique_ptr<SliceImpl_cpu<3>> create(const Slice_Op<3>& op) { return std::make_unique<SliceImpl_cpu<3>>(op); @@ -202,13 +182,8 @@ class SliceImpl_cpu<3> : public OperatorImpl { template <> class SliceImpl_cpu<4> : public OperatorImpl { - private: - const Slice_Op<4>& mOp; - std::array<NbElts_t, 1> mNbConsumedData; - std::array<NbElts_t, 1> mNbProducedData; - public: - SliceImpl_cpu(const Slice_Op<4>& op) : mOp(op), mNbConsumedData({0}), mNbProducedData({0}) {} + SliceImpl_cpu(const Slice_Op<4>& op) : OperatorImpl(op) {} static std::unique_ptr<SliceImpl_cpu<4>> create(const Slice_Op<4>& op) { return std::make_unique<SliceImpl_cpu<4>>(op);