diff --git a/include/aidge/backend/cpu/operator/AddImpl.hpp b/include/aidge/backend/cpu/operator/AddImpl.hpp index 57669c628b4fa650f137c2b28c8c0a4584bf6c35..6cb72e9100b1437afa13a23cb5933e77aabaaae8 100644 --- a/include/aidge/backend/cpu/operator/AddImpl.hpp +++ b/include/aidge/backend/cpu/operator/AddImpl.hpp @@ -39,7 +39,7 @@ public: return std::make_unique<AddImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/AvgPoolingImpl.hpp b/include/aidge/backend/cpu/operator/AvgPoolingImpl.hpp index bfb2b1947281fc30e38fd1fe1663bd5de415d3ee..38dbd4b528b4f0fbd24f7f8d2b53e7ea16bae5d0 100644 --- a/include/aidge/backend/cpu/operator/AvgPoolingImpl.hpp +++ b/include/aidge/backend/cpu/operator/AvgPoolingImpl.hpp @@ -44,7 +44,7 @@ public: return std::make_unique<AvgPoolingImpl2D_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/BatchNormImpl.hpp b/include/aidge/backend/cpu/operator/BatchNormImpl.hpp index a599aeb7b427161eb7541829242820c0306d0d31..92797ab09148f16255851f4bf51d7c62b7bd6f70 100644 --- a/include/aidge/backend/cpu/operator/BatchNormImpl.hpp +++ b/include/aidge/backend/cpu/operator/BatchNormImpl.hpp @@ -59,7 +59,7 @@ public: return std::make_unique<BatchNormImpl2D_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/ConcatImpl.hpp b/include/aidge/backend/cpu/operator/ConcatImpl.hpp index d0d3e06365c524da1af485583dda6d6208ef3fb9..02d52c850a5a3e628980fcc7502ffab8aa166e17 100644 --- a/include/aidge/backend/cpu/operator/ConcatImpl.hpp +++ b/include/aidge/backend/cpu/operator/ConcatImpl.hpp @@ -48,18 +48,6 @@ public: } public: - NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final; - - NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; - - NbElts_t getRequiredMemory(const IOIndex_t outputIdx, const std::vector<DimSize_t>& /*inputsSize*/) const override final; - - NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const override final; - - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final; - - void updateConsummerProducer() override final; - void forward() override; void backward() override; diff --git a/include/aidge/backend/cpu/operator/ConvDepthWiseImpl.hpp b/include/aidge/backend/cpu/operator/ConvDepthWiseImpl.hpp index f72890d8903ca4a9876809759587ed4b1ac22e67..44bc5da3fa752d9fd52e43366099d20de35d866e 100644 --- a/include/aidge/backend/cpu/operator/ConvDepthWiseImpl.hpp +++ b/include/aidge/backend/cpu/operator/ConvDepthWiseImpl.hpp @@ -46,7 +46,7 @@ public: return std::make_unique<ConvDepthWiseImpl2D_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/ConvImpl.hpp b/include/aidge/backend/cpu/operator/ConvImpl.hpp index 9bc2f27412f388a7fd03db06ac97c612044fab5f..2915210dbdeb9b32aca006a171efbca9ccc288b5 100644 --- a/include/aidge/backend/cpu/operator/ConvImpl.hpp +++ b/include/aidge/backend/cpu/operator/ConvImpl.hpp @@ -47,7 +47,7 @@ class ConvImpl2D_cpu : public OperatorImpl { } public: - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/DivImpl.hpp b/include/aidge/backend/cpu/operator/DivImpl.hpp index 710e288d8e0f95b69a2f4973679f1195e6d9cb6a..6bedf627548f63cf14626c69bf91fbd8c9434784 100644 --- a/include/aidge/backend/cpu/operator/DivImpl.hpp +++ b/include/aidge/backend/cpu/operator/DivImpl.hpp @@ -40,7 +40,7 @@ public: return std::make_unique<DivImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override final; }; diff --git a/include/aidge/backend/cpu/operator/ErfImpl.hpp b/include/aidge/backend/cpu/operator/ErfImpl.hpp index 5c0a6fd49f4e2d435eed8e8baa979f59dbd84e68..517eab354a7f44f1d4c7ebbc33efe12edd4159d1 100644 --- a/include/aidge/backend/cpu/operator/ErfImpl.hpp +++ b/include/aidge/backend/cpu/operator/ErfImpl.hpp @@ -38,7 +38,7 @@ public: return std::make_unique<ErfImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/GatherImpl.hpp b/include/aidge/backend/cpu/operator/GatherImpl.hpp index 1d235ff14ca01955c268a7b061e6ecb7b2bbbb2a..28c9a31db337977405b66cbca61d950160679fa1 100644 --- a/include/aidge/backend/cpu/operator/GatherImpl.hpp +++ b/include/aidge/backend/cpu/operator/GatherImpl.hpp @@ -38,7 +38,6 @@ public: return std::make_unique<GatherImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/LeakyReLUImpl.hpp b/include/aidge/backend/cpu/operator/LeakyReLUImpl.hpp index 4a1da034935e6b1f6c2069b4f91153b77a9f0636..b60143dba18a11f7521f265ca0816984b67c6920 100644 --- a/include/aidge/backend/cpu/operator/LeakyReLUImpl.hpp +++ b/include/aidge/backend/cpu/operator/LeakyReLUImpl.hpp @@ -39,7 +39,7 @@ public: return std::make_unique<LeakyReLUImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/MaxPoolingImpl.hpp b/include/aidge/backend/cpu/operator/MaxPoolingImpl.hpp index 6cde34d9b123b4f83cbfce412ffa62e0144af8d4..675f3c4a030a4f668da63fd10f9dc91d39e524dd 100644 --- a/include/aidge/backend/cpu/operator/MaxPoolingImpl.hpp +++ b/include/aidge/backend/cpu/operator/MaxPoolingImpl.hpp @@ -44,7 +44,7 @@ public: return std::make_unique<MaxPoolingImpl2D_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/MemorizeImpl.hpp b/include/aidge/backend/cpu/operator/MemorizeImpl.hpp index 6569478001189b60795f21cf618c77c65aeefbfb..af571a0cd49f80dd6c9a3abf87dae4ba586af5c4 100644 --- a/include/aidge/backend/cpu/operator/MemorizeImpl.hpp +++ b/include/aidge/backend/cpu/operator/MemorizeImpl.hpp @@ -29,8 +29,8 @@ public: return std::make_unique<MemorizeImpl_cpu>(op); } - NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final; - NbElts_t getRequiredMemory(const Aidge::IOIndex_t outputIdx, + Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override final; + Elts_t getRequiredMemory(const Aidge::IOIndex_t outputIdx, const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const override final; void updateConsummerProducer() override final; void forward() override; diff --git a/include/aidge/backend/cpu/operator/MulImpl.hpp b/include/aidge/backend/cpu/operator/MulImpl.hpp index a6f63ba284baf4cc12190d6b96a89f0baa821c95..6773b6f42497977679e5b6590c699aaf877bc3fc 100644 --- a/include/aidge/backend/cpu/operator/MulImpl.hpp +++ b/include/aidge/backend/cpu/operator/MulImpl.hpp @@ -39,7 +39,7 @@ public: return std::make_unique<MulImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/PadImpl.hpp b/include/aidge/backend/cpu/operator/PadImpl.hpp index 2320662710f9802878811e51ec4439bd812aea67..41032c7220411b29de828763499c8bb751805369 100644 --- a/include/aidge/backend/cpu/operator/PadImpl.hpp +++ b/include/aidge/backend/cpu/operator/PadImpl.hpp @@ -46,7 +46,7 @@ public: return std::make_unique<PadImpl2D_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/PopImpl.hpp b/include/aidge/backend/cpu/operator/PopImpl.hpp index 86c20349d5554e400c15a6e3488cb547f86abee2..d7e484a509c05e5d0e2796542d6a0a8d5acdd3a7 100644 --- a/include/aidge/backend/cpu/operator/PopImpl.hpp +++ b/include/aidge/backend/cpu/operator/PopImpl.hpp @@ -39,7 +39,7 @@ public: return std::make_unique<PopImpl_cpu>(op); } - NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredData(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/PowImpl.hpp b/include/aidge/backend/cpu/operator/PowImpl.hpp index c6e4cd36746141d7f1d1092c9bd45af41d8a9173..7d17b370dd53817fd5ed61cd21d527e2850d0125 100644 --- a/include/aidge/backend/cpu/operator/PowImpl.hpp +++ b/include/aidge/backend/cpu/operator/PowImpl.hpp @@ -39,7 +39,7 @@ public: return std::make_unique<PowImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/ReLUImpl.hpp b/include/aidge/backend/cpu/operator/ReLUImpl.hpp index 3338d0c40c057995fe37b1652966241bf4a96b59..d8f8272ff09ec49924fe47825f56ee72faf4a644 100644 --- a/include/aidge/backend/cpu/operator/ReLUImpl.hpp +++ b/include/aidge/backend/cpu/operator/ReLUImpl.hpp @@ -39,7 +39,7 @@ public: return std::make_unique<ReLUImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/ReduceMeanImpl.hpp b/include/aidge/backend/cpu/operator/ReduceMeanImpl.hpp index 9b85eb812caffca3820a711d46775e1134db863f..3c0fe6370ffdcdeff0702d9dbdff64b8297fd61e 100644 --- a/include/aidge/backend/cpu/operator/ReduceMeanImpl.hpp +++ b/include/aidge/backend/cpu/operator/ReduceMeanImpl.hpp @@ -64,7 +64,6 @@ class ReduceMeanImpl1D_cpu : public OperatorImpl { } public: - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; @@ -77,7 +76,6 @@ class ReduceMeanImpl2D_cpu : public OperatorImpl { } public: - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; @@ -90,7 +88,6 @@ class ReduceMeanImpl3D_cpu : public OperatorImpl { } public: - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; namespace { diff --git a/include/aidge/backend/cpu/operator/ReshapeImpl.hpp b/include/aidge/backend/cpu/operator/ReshapeImpl.hpp index d5754b34e952d52b2071744e9f8e863074ef9fa3..0a8b851fd8acf14c35434887d054d530eb1228bc 100644 --- a/include/aidge/backend/cpu/operator/ReshapeImpl.hpp +++ b/include/aidge/backend/cpu/operator/ReshapeImpl.hpp @@ -38,7 +38,7 @@ public: return std::make_unique<ReshapeImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/ScalingImpl.hpp b/include/aidge/backend/cpu/operator/ScalingImpl.hpp index bbcb4553d7aa4b17d733e0f455373bebb9c3581c..29b61704f6acd85db1c635547e17f5f002e620f0 100644 --- a/include/aidge/backend/cpu/operator/ScalingImpl.hpp +++ b/include/aidge/backend/cpu/operator/ScalingImpl.hpp @@ -40,7 +40,7 @@ public: return std::make_unique<ScalingImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/SigmoidImpl.hpp b/include/aidge/backend/cpu/operator/SigmoidImpl.hpp index 8678a5a56500ec9e37689df7a37ae72bfb3f74d4..a34650d6326331320c69befc790752cb4023e0ba 100644 --- a/include/aidge/backend/cpu/operator/SigmoidImpl.hpp +++ b/include/aidge/backend/cpu/operator/SigmoidImpl.hpp @@ -39,7 +39,7 @@ public: return std::make_unique<SigmoidImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/SliceImpl.hpp b/include/aidge/backend/cpu/operator/SliceImpl.hpp index 1cba5906064c51a4f0da2f1f3682b0828a080d43..e129c2e680cbe9bf12ec97c347768e73b7775cf0 100644 --- a/include/aidge/backend/cpu/operator/SliceImpl.hpp +++ b/include/aidge/backend/cpu/operator/SliceImpl.hpp @@ -46,14 +46,6 @@ public: return std::make_unique<SliceImpl_cpu>(op); } - NbElts_t getNbRequiredData(const IOIndex_t /*inputIdx*/) const override final; - NbElts_t getNbRequiredProtected(const IOIndex_t /*inputIdx*/) const override final; - NbElts_t getRequiredMemory(const IOIndex_t outputIdx, - const std::vector<DimSize_t>& inputsSize) const override final; - NbElts_t getNbConsumedData(const IOIndex_t /*inputIdx*/) const override final; - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final; - void updateConsummerProducer() override final; - void forward() override; void backward() override; }; diff --git a/include/aidge/backend/cpu/operator/SoftmaxImpl.hpp b/include/aidge/backend/cpu/operator/SoftmaxImpl.hpp index 005b52f646f9e9ddf14af09cc22d9e2a44ba6dd4..5625f7de7d65577c6829a1def514f8f69824dc9d 100644 --- a/include/aidge/backend/cpu/operator/SoftmaxImpl.hpp +++ b/include/aidge/backend/cpu/operator/SoftmaxImpl.hpp @@ -39,7 +39,7 @@ public: return std::make_unique<SoftmaxImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/SqrtImpl.hpp b/include/aidge/backend/cpu/operator/SqrtImpl.hpp index b3723f27b077b9d5ea7e69fd33bd012d02654ffe..f1848bde355c7b71e92395ef3901a69e7dca766f 100644 --- a/include/aidge/backend/cpu/operator/SqrtImpl.hpp +++ b/include/aidge/backend/cpu/operator/SqrtImpl.hpp @@ -39,7 +39,7 @@ public: return std::make_unique<SqrtImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/SubImpl.hpp b/include/aidge/backend/cpu/operator/SubImpl.hpp index b329ec6eb0ed7f450b62cdbe289d69acf4f4edc4..a9006a04bc3690429532c1a6b9cc76f9ef32880e 100644 --- a/include/aidge/backend/cpu/operator/SubImpl.hpp +++ b/include/aidge/backend/cpu/operator/SubImpl.hpp @@ -39,7 +39,7 @@ public: return std::make_unique<SubImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/TanhImpl.hpp b/include/aidge/backend/cpu/operator/TanhImpl.hpp index 3e88a3d00b5829fc24d8dc77ce53cb358551c7e4..b477d0bd0ee5434d942dfa1057968fa904300dde 100644 --- a/include/aidge/backend/cpu/operator/TanhImpl.hpp +++ b/include/aidge/backend/cpu/operator/TanhImpl.hpp @@ -39,7 +39,7 @@ public: return std::make_unique<TanhImpl_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/TransposeImpl.hpp b/include/aidge/backend/cpu/operator/TransposeImpl.hpp index 712e672752648f5ff8a3c073f6c81bbe7cc85d9d..a1b9d274d2c14064ed9305b5d6c969dfa544b26b 100644 --- a/include/aidge/backend/cpu/operator/TransposeImpl.hpp +++ b/include/aidge/backend/cpu/operator/TransposeImpl.hpp @@ -63,7 +63,6 @@ public: return std::make_unique<TransposeImpl2D_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; class TransposeImpl3D_cpu : public OperatorImpl { @@ -74,7 +73,6 @@ public: return std::make_unique<TransposeImpl3D_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; class TransposeImpl4D_cpu : public OperatorImpl { @@ -85,7 +83,6 @@ public: return std::make_unique<TransposeImpl4D_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; class TransposeImpl5D_cpu : public OperatorImpl { @@ -96,7 +93,6 @@ public: return std::make_unique<TransposeImpl5D_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; class TransposeImpl6D_cpu : public OperatorImpl { @@ -107,7 +103,6 @@ public: return std::make_unique<TransposeImpl6D_cpu>(op); } - NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; void forward() override; }; diff --git a/src/operator/AddImpl.cpp b/src/operator/AddImpl.cpp index 7355ebcb3e8fb68bf74dbd1ce831bf471d285cb7..98de9188dad5539275bba9ae7961153099fb1b9f 100644 --- a/src/operator/AddImpl.cpp +++ b/src/operator/AddImpl.cpp @@ -21,9 +21,9 @@ #include "aidge/backend/cpu/operator/AddImpl.hpp" #include "aidge/backend/cpu/operator/AddImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::AddImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::AddImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::AddImpl_cpu::forward() { diff --git a/src/operator/AvgPoolingImpl.cpp b/src/operator/AvgPoolingImpl.cpp index 9e0a77e3285c1e3701142828c74898cb9da5b405..8ba6751bf4068a69ed07e362924f59d0f4aca6c5 100644 --- a/src/operator/AvgPoolingImpl.cpp +++ b/src/operator/AvgPoolingImpl.cpp @@ -21,9 +21,9 @@ #include "aidge/backend/cpu/operator/AvgPoolingImpl.hpp" #include "aidge/backend/cpu/operator/AvgPoolingImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::AvgPoolingImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::AvgPoolingImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::AvgPoolingImpl2D_cpu::forward() { diff --git a/src/operator/BatchNormImpl.cpp b/src/operator/BatchNormImpl.cpp index c84f2cb6b09c707f68ed83cc7554624fc6489b84..96179d11850624f831333c9a4badaddf2221ecff 100644 --- a/src/operator/BatchNormImpl.cpp +++ b/src/operator/BatchNormImpl.cpp @@ -20,9 +20,9 @@ #include "aidge/backend/cpu/operator/BatchNormImpl.hpp" #include "aidge/backend/cpu/operator/BatchNormImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::BatchNormImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::BatchNormImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::BatchNormImpl2D_cpu::forward() { diff --git a/src/operator/ConcatImpl.cpp b/src/operator/ConcatImpl.cpp index e142b79a8aad5a99a65fdf38de630f3b5668c804..605f4a19ff3856924593b0e6d7815d5de1579c01 100644 --- a/src/operator/ConcatImpl.cpp +++ b/src/operator/ConcatImpl.cpp @@ -21,46 +21,6 @@ #include "aidge/backend/cpu/operator/ConcatImpl.hpp" #include "aidge/backend/cpu/operator/ConcatImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::ConcatImpl_cpu::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { - assert(mOp.getRawInput(inputIdx) && "requires valid input"); - - // Requires the whole tensors - const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->dims(); - return std::accumulate(inputDims.begin(), inputDims.end(), NbElts_t(1), std::multiplies<NbElts_t>()); -} - -Aidge::NbElts_t Aidge::ConcatImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { - // for the direct convolution algorithm, convolutions can be in-place, if there is no padding! - return 0; -} - -Aidge::NbElts_t Aidge::ConcatImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t outputIdx, const std::vector<Aidge::DimSize_t>& /*inputsSize*/) const { - // Requires the whole tensors, regardless of available data on inputs - assert(outputIdx == 0 && "operator has only one output"); - (void) outputIdx; - - const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(); - return std::accumulate(outputDims.begin(), outputDims.end(), NbElts_t(1), std::multiplies<NbElts_t>()); -} - -Aidge::NbElts_t Aidge::ConcatImpl_cpu::getNbConsumedData(const Aidge::IOIndex_t inputIdx) const { - assert(inputIdx < mNbConsumedData.size()); - return mNbConsumedData[inputIdx]; -} - -Aidge::NbElts_t Aidge::ConcatImpl_cpu::getNbProducedData(const Aidge::IOIndex_t outputIdx) const { - assert(outputIdx < mNbProducedData.size()); - return mNbProducedData[outputIdx]; -} - -void Aidge::ConcatImpl_cpu::updateConsummerProducer() { - for (IOIndex_t inputIdx = 0; static_cast<NbElts_t>(inputIdx) < mNbConsumedData.size(); ++inputIdx) - mNbConsumedData[inputIdx]+= getNbRequiredData(inputIdx); // each input is consumed by the minimum amount for a forward pass - - mNbProducedData[0]+= getRequiredMemory(0, {}); - -} - void Aidge::ConcatImpl_cpu::forward() { assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input in Concat operator"); DataType datatypeFirstInput = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(); diff --git a/src/operator/ConvDepthWiseImpl.cpp b/src/operator/ConvDepthWiseImpl.cpp index 1b4262e394f78ab0bda4a36440ac7b9cb15c164c..5c8d2fe307c70bd7ee3f64e14735417f7ffb0c67 100644 --- a/src/operator/ConvDepthWiseImpl.cpp +++ b/src/operator/ConvDepthWiseImpl.cpp @@ -22,9 +22,9 @@ #include "aidge/backend/cpu/operator/ConvDepthWiseImpl.hpp" #include "aidge/backend/cpu/operator/ConvDepthWiseImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::ConvDepthWiseImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::ConvDepthWiseImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::ConvDepthWiseImpl2D_cpu::forward() { diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp index b849142dd3abe0131fb0c6c448530a7669ce27dc..3beb2bcf72ed9e318733dce9e69d41c61bf11e5b 100644 --- a/src/operator/ConvImpl.cpp +++ b/src/operator/ConvImpl.cpp @@ -22,9 +22,9 @@ #include "aidge/backend/cpu/operator/ConvImpl.hpp" #include "aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::ConvImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::ConvImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::ConvImpl2D_cpu::forward() { diff --git a/src/operator/DivImpl.cpp b/src/operator/DivImpl.cpp index 729aff2452b46f00eb6d3e0b558c0b3d58ea2f0e..bfb2ae643a02d67ea1a289d0383b816b5a6ad110 100644 --- a/src/operator/DivImpl.cpp +++ b/src/operator/DivImpl.cpp @@ -19,9 +19,9 @@ #include "aidge/data/Tensor.hpp" #include "aidge/utils/Types.h" -Aidge::NbElts_t Aidge::DivImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::DivImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::DivImpl_cpu::forward() { diff --git a/src/operator/ErfImpl.cpp b/src/operator/ErfImpl.cpp index 06ec65008aee41215192cd05e126ac4f82388c1b..1e6d2766f49a0a0b65c1cdb974f42d2865ae59f5 100644 --- a/src/operator/ErfImpl.cpp +++ b/src/operator/ErfImpl.cpp @@ -21,9 +21,9 @@ #include "aidge/backend/cpu/operator/ErfImpl.hpp" #include "aidge/backend/cpu/operator/ErfImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::ErfImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::ErfImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::ErfImpl_cpu::forward() { diff --git a/src/operator/GatherImpl.cpp b/src/operator/GatherImpl.cpp index ce98627d95e0d05541db1ccaf4896abe756431b0..523cc0365884cb0496a46eb550aa90fa6f4c421d 100644 --- a/src/operator/GatherImpl.cpp +++ b/src/operator/GatherImpl.cpp @@ -21,11 +21,6 @@ #include "aidge/backend/cpu/operator/GatherImpl.hpp" #include "aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::GatherImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return 0; -} - void Aidge::GatherImpl_cpu::forward() { auto kernelFunc = Registrar<GatherImplForward_cpu>::create({ diff --git a/src/operator/LeakyReLUImpl.cpp b/src/operator/LeakyReLUImpl.cpp index 17912eb1dc75930eaf7595eb189af39df4d4fa2e..7d41163e6c0dc3e1bc7a4ca3075520243aac6958 100644 --- a/src/operator/LeakyReLUImpl.cpp +++ b/src/operator/LeakyReLUImpl.cpp @@ -22,9 +22,9 @@ #include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp" #include "aidge/backend/cpu/operator/LeakyReLUImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::LeakyReLUImpl_cpu::forward() { diff --git a/src/operator/MaxPoolingImpl.cpp b/src/operator/MaxPoolingImpl.cpp index e21dab07df4c20eb7253e680146042f205bc210b..94591eaa9848b24aeb7afa1e8b6b87a3e6e2b45f 100644 --- a/src/operator/MaxPoolingImpl.cpp +++ b/src/operator/MaxPoolingImpl.cpp @@ -21,9 +21,9 @@ #include "aidge/backend/cpu/operator/MaxPoolingImpl.hpp" #include "aidge/backend/cpu/operator/MaxPoolingImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::MaxPoolingImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::MaxPoolingImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::MaxPoolingImpl2D_cpu::forward() { diff --git a/src/operator/MemorizeImpl.cpp b/src/operator/MemorizeImpl.cpp index b2956231ec29784158ea27c68d4ec21a8c4ccc64..8a23bd35585c03c91567c0da5b0727fe1323b754 100644 --- a/src/operator/MemorizeImpl.cpp +++ b/src/operator/MemorizeImpl.cpp @@ -21,7 +21,7 @@ #include "aidge/backend/cpu/operator/MemorizeImpl.hpp" -Aidge::DimSize_t Aidge::MemorizeImpl_cpu::getNbRequiredData( +Aidge::Elts_t Aidge::MemorizeImpl_cpu::getNbRequiredData( Aidge::IOIndex_t inputIdx) const { const Memorize_Op& op = dynamic_cast<const Memorize_Op&>(mOp); @@ -30,18 +30,18 @@ Aidge::DimSize_t Aidge::MemorizeImpl_cpu::getNbRequiredData( if (scheduleStep == 0 && inputIdx == 0) { // No data input is required for the initial step. // Initialization data is required however. - return 0; + return Elts_t::NoneElts(); } else if (scheduleStep > 0 && inputIdx == 1) { // No initialization data is required after the initial step. - return 0; + return Elts_t::NoneElts(); } else { return OperatorImpl::getNbRequiredData(inputIdx); } } -Aidge::NbElts_t Aidge::MemorizeImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t outputIdx, +Aidge::Elts_t Aidge::MemorizeImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t outputIdx, const std::vector<Aidge::DimSize_t> &/*inputsSize*/) const { assert(mOp.getRawOutput(outputIdx) && "requires valid output"); @@ -50,10 +50,10 @@ Aidge::NbElts_t Aidge::MemorizeImpl_cpu::getRequiredMemory(const Aidge::IOIndex_ const unsigned int endStep = op.template getAttr<MemorizeAttr::EndStep>(); if (endStep > 0 && outputIdx == 1 && scheduleStep >= endStep) { - return 0; + return Elts_t::NoneElts(); } else { - return std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size(); + return Elts_t::DataElts(std::static_pointer_cast<Tensor>(mOp.getRawOutput(outputIdx))->size()); } } diff --git a/src/operator/MulImpl.cpp b/src/operator/MulImpl.cpp index 87d180b013e44a49cb887ce722533c50206f3889..d7feb9b76e25a0e874b3682cdc5b3e53bf8e9228 100644 --- a/src/operator/MulImpl.cpp +++ b/src/operator/MulImpl.cpp @@ -23,9 +23,9 @@ #include "aidge/backend/cpu/operator/MulImpl.hpp" #include "aidge/backend/cpu/operator/MulImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::MulImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::MulImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::MulImpl_cpu::forward() { diff --git a/src/operator/PadImpl.cpp b/src/operator/PadImpl.cpp index 219bf425fa34cdaaa378c49dd7c9837f9d94d97e..cd420a6241723c5d3fa5836838f84ce6bfe965d1 100644 --- a/src/operator/PadImpl.cpp +++ b/src/operator/PadImpl.cpp @@ -22,7 +22,7 @@ #include "aidge/backend/cpu/operator/PadImpl.hpp" #include "aidge/backend/cpu/operator/PadImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::PadImpl2D_cpu::getNbRequiredProtected(IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::PadImpl2D_cpu::getNbRequiredProtected(IOIndex_t inputIdx) const { assert(inputIdx == 0 && "operator has only one input"); (void) inputIdx; @@ -30,7 +30,7 @@ Aidge::NbElts_t Aidge::PadImpl2D_cpu::getNbRequiredProtected(IOIndex_t inputIdx) // We must ensure that we do not override data that has not been consummed yet. const auto inputSize = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(); const auto outputSize = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size(); - return (outputSize - inputSize); + return Elts_t::DataElts(outputSize - inputSize); } void Aidge::PadImpl2D_cpu::forward() { diff --git a/src/operator/PopImpl.cpp b/src/operator/PopImpl.cpp index 86850610c75f827d9c29e6a0506397c5a844cb00..02bbddbaed6d9d89e729d6c778a1765fcbab4b4f 100644 --- a/src/operator/PopImpl.cpp +++ b/src/operator/PopImpl.cpp @@ -21,11 +21,11 @@ #include "aidge/backend/cpu/operator/PopImpl.hpp" -Aidge::NbElts_t Aidge::PopImpl_cpu::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { +Aidge::Elts_t Aidge::PopImpl_cpu::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { assert(mOp.getRawInput(inputIdx) && "requires valid input"); - return std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size() - / std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->dims()[0]; + return Elts_t::DataElts(std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->size() + / std::static_pointer_cast<Tensor>(mOp.getRawInput(inputIdx))->dims()[0]); } void Aidge::PopImpl_cpu::forward() { diff --git a/src/operator/PowImpl.cpp b/src/operator/PowImpl.cpp index 22b4e27afd4e327c42be066bf7eeb6effdd8b2a9..782ca35706b5fd28e376f97651c847492b9bf755 100644 --- a/src/operator/PowImpl.cpp +++ b/src/operator/PowImpl.cpp @@ -23,9 +23,9 @@ #include "aidge/backend/cpu/operator/PowImpl.hpp" #include "aidge/backend/cpu/operator/PowImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::PowImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::PowImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::PowImpl_cpu::forward() { diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp index 8863be282ce0c7b7bfbfb938372cf304bc4cc4bd..81d1639daf02e8fac0bff3bc30de482b4f0a76d8 100644 --- a/src/operator/ReLUImpl.cpp +++ b/src/operator/ReLUImpl.cpp @@ -22,9 +22,9 @@ #include "aidge/backend/cpu/operator/ReLUImpl.hpp" #include "aidge/backend/cpu/operator/ReLUImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::ReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::ReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::ReLUImpl_cpu::forward() { diff --git a/src/operator/ReduceMeanImpl.cpp b/src/operator/ReduceMeanImpl.cpp index e31a53d84947e5b2ced14ee9ee6e2badaef07071..324daa9ea2cf49ad15bde0d6c41c6bbcd7eb0c45 100644 --- a/src/operator/ReduceMeanImpl.cpp +++ b/src/operator/ReduceMeanImpl.cpp @@ -20,18 +20,6 @@ #include "aidge/backend/cpu/operator/ReduceMeanImpl.hpp" #include "aidge/backend/cpu/operator/ReduceMeanImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::ReduceMeanImpl1D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return 0; -} -Aidge::NbElts_t Aidge::ReduceMeanImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return 0; -} -Aidge::NbElts_t Aidge::ReduceMeanImpl3D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return 0; -} void Aidge::ReduceMeanImpl1D_cpu::forward() { diff --git a/src/operator/ReshapeImpl.cpp b/src/operator/ReshapeImpl.cpp index 02dea1da3d4422abf37b62193bba83e83c87a83f..8cd71c4ed65a808b573736f13c4f64f61b2e4795 100644 --- a/src/operator/ReshapeImpl.cpp +++ b/src/operator/ReshapeImpl.cpp @@ -17,9 +17,9 @@ #include "aidge/backend/cpu/operator/ReshapeImpl.hpp" #include "aidge/backend/cpu/operator/ReshapeImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::ReshapeImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::ReshapeImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::ReshapeImpl_cpu::forward() { diff --git a/src/operator/ScalingImpl.cpp b/src/operator/ScalingImpl.cpp index 6b9aab31a9d61d2d7a5ff89961de3fa6a2b5ebd2..d0b58702c73f01fb62114d335f5c2342908542ea 100644 --- a/src/operator/ScalingImpl.cpp +++ b/src/operator/ScalingImpl.cpp @@ -21,9 +21,9 @@ #include "aidge/backend/cpu/data/GetCPUPtr.h" #include <vector> -Aidge::NbElts_t Aidge::ScalingImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::ScalingImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::ScalingImpl_cpu::forward() { diff --git a/src/operator/SigmoidImpl.cpp b/src/operator/SigmoidImpl.cpp index 7322e08ba01bfb931382cf17691e705dfaeeb6c1..dd7ec26cb36777f79d382c815b60d2381544a0bd 100644 --- a/src/operator/SigmoidImpl.cpp +++ b/src/operator/SigmoidImpl.cpp @@ -22,9 +22,9 @@ #include "aidge/backend/cpu/operator/SigmoidImpl.hpp" #include "aidge/backend/cpu/operator/SigmoidImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::SigmoidImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::SigmoidImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::SigmoidImpl_cpu::forward() { diff --git a/src/operator/SliceImpl.cpp b/src/operator/SliceImpl.cpp index c1a6480c1e7c0d681abef12f06a57e140d1e9efd..47b13c4694cea22421811c889b5627e9f1362ac0 100644 --- a/src/operator/SliceImpl.cpp +++ b/src/operator/SliceImpl.cpp @@ -22,42 +22,6 @@ #include <cassert> #include <tuple> -Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbRequiredData(const Aidge::IOIndex_t /*inputIdx*/) const { - assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "requires valid input"); - - // Requires the whole tensors - const auto& inputDims = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(); - - return std::accumulate(inputDims.begin(), inputDims.end(), static_cast<NbElts_t>(1), - std::multiplies<NbElts_t>()); -} - -Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { return 0; } - -Aidge::NbElts_t Aidge::SliceImpl_cpu::getRequiredMemory(const Aidge::IOIndex_t outputIdx, - const std::vector<Aidge::DimSize_t>& inputsSize) const { - (void)outputIdx; - (void)inputsSize; - const auto& outputDims = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(); - return std::accumulate(outputDims.begin(), outputDims.end(), static_cast<NbElts_t>(1), - std::multiplies<NbElts_t>()); -} - -Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbConsumedData(const Aidge::IOIndex_t /*inputIdx*/) const { - return mNbConsumedData[0]; -} - -Aidge::NbElts_t Aidge::SliceImpl_cpu::getNbProducedData(const Aidge::IOIndex_t /*outputIdx*/) const { - return mNbProducedData[0]; -} - -void Aidge::SliceImpl_cpu::updateConsummerProducer() { - // each input is consumed by the minimum amount for a forward pass - mNbConsumedData[0] += getNbRequiredData(0); - - mNbProducedData[0] += getRequiredMemory(0, {}); -} - void Aidge::SliceImpl_cpu::forward() { // FIXME: uncomment the following code once memory handling will work assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0"); diff --git a/src/operator/SoftmaxImpl.cpp b/src/operator/SoftmaxImpl.cpp index 5f5d7411b7bb28ae28480b39c8bfdf5674f877ed..240267613e557c20edcc00e81f4bf20d17d9962f 100644 --- a/src/operator/SoftmaxImpl.cpp +++ b/src/operator/SoftmaxImpl.cpp @@ -22,9 +22,9 @@ #include "aidge/backend/cpu/operator/SoftmaxImpl.hpp" #include "aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::SoftmaxImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::SoftmaxImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::SoftmaxImpl_cpu::forward() { diff --git a/src/operator/SqrtImpl.cpp b/src/operator/SqrtImpl.cpp index 2766e8ae21738775aadad86629a99d0a180e537e..8fcb2e9d05b859b5f572f614a72ff42b1f20d4dd 100644 --- a/src/operator/SqrtImpl.cpp +++ b/src/operator/SqrtImpl.cpp @@ -22,9 +22,9 @@ #include "aidge/backend/cpu/operator/SqrtImpl.hpp" #include "aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::SqrtImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::SqrtImpl_cpu::forward() { diff --git a/src/operator/SubImpl.cpp b/src/operator/SubImpl.cpp index 475f8cb8704739e091f0b8f01ffce680fd851e1f..ffddb59ee3373c4a0a6c2653747744a43fd471d9 100644 --- a/src/operator/SubImpl.cpp +++ b/src/operator/SubImpl.cpp @@ -23,9 +23,9 @@ #include "aidge/backend/cpu/operator/SubImpl.hpp" #include "aidge/backend/cpu/operator/SubImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::SubImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::SubImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::SubImpl_cpu::forward() { diff --git a/src/operator/TanhImpl.cpp b/src/operator/TanhImpl.cpp index c4658440ab00086be6a469c19d5ea89771857fb1..44e180739ed86e25d4be6d0beb693f73bdadbf35 100644 --- a/src/operator/TanhImpl.cpp +++ b/src/operator/TanhImpl.cpp @@ -22,9 +22,9 @@ #include "aidge/backend/cpu/operator/TanhImpl.hpp" #include "aidge/backend/cpu/operator/TanhImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::TanhImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { +Aidge::Elts_t Aidge::TanhImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { // this implementation can be in-place - return 0; + return Elts_t::DataElts(0); } void Aidge::TanhImpl_cpu::forward() { diff --git a/src/operator/TransposeImpl.cpp b/src/operator/TransposeImpl.cpp index 1fc4458ccb85e4776228a2bf9e1c73589c201a35..710e67b4f5aaa5261a111a8e131a0dd740694a4b 100644 --- a/src/operator/TransposeImpl.cpp +++ b/src/operator/TransposeImpl.cpp @@ -21,27 +21,6 @@ #include "aidge/backend/cpu/operator/TransposeImpl.hpp" #include "aidge/backend/cpu/operator/TransposeImpl_forward_kernels.hpp" -Aidge::NbElts_t Aidge::TransposeImpl2D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return 0; -} -Aidge::NbElts_t Aidge::TransposeImpl3D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return 0; -} -Aidge::NbElts_t Aidge::TransposeImpl4D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return 0; -} -Aidge::NbElts_t Aidge::TransposeImpl5D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return 0; -} -Aidge::NbElts_t Aidge::TransposeImpl6D_cpu::getNbRequiredProtected(IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return 0; -} - void Aidge::TransposeImpl2D_cpu::forward() { // Find the correct kernel type auto kernelFunc = diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index 5eea881d113d72ecdd7f1efd39e077b218736ef0..63a11d19a025b5560075c4b85123d645522da09e 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -245,10 +245,10 @@ TEST_CASE("[cpu/operator] MetaOperator", "[MetaOperator][CPU]") { auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler(); microGraphScheduler->saveSchedulingDiagram("lstm_scheduling"); - REQUIRE(op->getNbConsumedData(0) == 512); - REQUIRE(op->getNbConsumedData(1) == 32768); - REQUIRE(op->getNbProducedData(0) == 34816); - REQUIRE(op->getNbProducedData(1) == 34816); + REQUIRE(op->getNbConsumedData(0).data == 512); + REQUIRE(op->getNbConsumedData(1).data == 32768); + REQUIRE(op->getNbProducedData(0).data == 34816); + REQUIRE(op->getNbProducedData(1).data == 34816); REQUIRE(microGraphScheduler->getStaticScheduling(0).size() == 26); REQUIRE(microGraphScheduler->getStaticScheduling(1).size() == 24); REQUIRE(microGraphScheduler->getStaticScheduling(15).size() == 24);