diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp index 46dfae3d53b4b201507290bd538ea13737919c3e..6c92a925a592b69e7dc7b70c38f0f5a363d88601 100644 --- a/include/aidge/backend/cpu/data/TensorImpl.hpp +++ b/include/aidge/backend/cpu/data/TensorImpl.hpp @@ -41,7 +41,7 @@ public: std::size_t i = 0; for (; i < mNbElts && - *(mData.data()+i) == *static_cast<const T*>(typedOtherImpl.rawPtr(i)); + *static_cast<const T*>(rawPtr(i)) == *static_cast<const T*>(typedOtherImpl.rawPtr(i)); ++i) { } return i == mNbElts; diff --git a/include/aidge/backend/cpu/operator/MemorizeImpl.hpp b/include/aidge/backend/cpu/operator/MemorizeImpl.hpp index c003e7b5757740f4282e7a39300ccc118558b1c0..6569478001189b60795f21cf618c77c65aeefbfb 100644 --- a/include/aidge/backend/cpu/operator/MemorizeImpl.hpp +++ b/include/aidge/backend/cpu/operator/MemorizeImpl.hpp @@ -30,7 +30,8 @@ public: } NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final; - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final; + NbElts_t getRequiredMemory(const Aidge::IOIndex_t outputIdx, + const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const override final; void updateConsummerProducer() override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/PopImpl.hpp b/include/aidge/backend/cpu/operator/PopImpl.hpp index 6f4e583a70ddb046f6176484cf2d67688beaa49d..86c20349d5554e400c15a6e3488cb547f86abee2 100644 --- a/include/aidge/backend/cpu/operator/PopImpl.hpp +++ b/include/aidge/backend/cpu/operator/PopImpl.hpp @@ -39,7 +39,7 @@ public: return std::make_unique<PopImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/src/operator/MemorizeImpl.cpp b/src/operator/MemorizeImpl.cpp index 64cb3bcf237acd0aea706d8635eb4ab5e1b947b1..b2956231ec29784158ea27c68d4ec21a8c4ccc64 100644 --- a/src/operator/MemorizeImpl.cpp +++ b/src/operator/MemorizeImpl.cpp @@ -41,18 +41,19 @@ Aidge::DimSize_t Aidge::MemorizeImpl_cpu::getNbRequiredData( } } -Aidge::DimSize_t Aidge::MemorizeImpl_cpu::getNbProducedData( - Aidge::IOIndex_t outputIdx) const -{ +Aidge::NbElts_t Aidge::MemorizeImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t outputIdx, + const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const { + assert(mOp.getRawOutput(outputIdx) && "requires valid output"); + const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>(); const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>(); - if (outputIdx == 1 && scheduleStep >= endStep) { - return endStep * std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size(); + if (endStep > 0 && outputIdx == 1 && scheduleStep >= endStep) { + return 0; } else { - return scheduleStep * std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->size(); + return std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size(); } } @@ -62,14 +63,14 @@ void Aidge::MemorizeImpl_cpu::updateConsummerProducer() { const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); const unsigned int scheduleStep = op.template getAttr<MemorizeAttr::ScheduleStep>(); const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>(); - AIDGE_ASSERT(scheduleStep <= endStep, "cannot update consumer producer anymore, number of cycles exceeded"); + AIDGE_ASSERT(endStep == 0 || scheduleStep <= endStep, "cannot update consumer producer anymore, number of cycles exceeded"); } void Aidge::MemorizeImpl_cpu::forward() { const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); const unsigned int forwardStep = op.template getAttr<MemorizeAttr::ForwardStep>(); const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>(); - AIDGE_ASSERT(forwardStep <= endStep, "cannot forward anymore, number of cycles exceeded"); + AIDGE_ASSERT(endStep == 0 || forwardStep <= endStep, "cannot forward anymore, number of cycles exceeded"); if (forwardStep == 0) { op.getOutput(0)->getImpl()->copy(op.getInput(1)->getImpl()->rawPtr(), op.getInput(1)->size()); diff --git a/src/operator/PopImpl.cpp b/src/operator/PopImpl.cpp index 76c0a495f38673e61360184b157dc37c47b966b5..86850610c75f827d9c29e6a0506397c5a844cb00 100644 --- a/src/operator/PopImpl.cpp +++ b/src/operator/PopImpl.cpp @@ -21,9 +21,11 @@ #include "aidge/backend/cpu/operator/PopImpl.hpp" -Aidge::NbElts_t Aidge::PopImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return 0; +Aidge::NbElts_t Aidge::PopImpl_cpu::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { + assert(mOp.getRawInput(inputIdx) && "requires valid input"); + + return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size() + / std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->dims()[0]; } void Aidge::PopImpl_cpu::forward() {