diff --git a/include/aidge/backend/cpu/operator/AbsImpl.hpp b/include/aidge/backend/cpu/operator/AbsImpl.hpp index faba3ef69ff27fbfb92393e1e0dacaebd5d69b07..e53b3154b13aac67f0b97f5d2993da292f89cca8 100644 --- a/include/aidge/backend/cpu/operator/AbsImpl.hpp +++ b/include/aidge/backend/cpu/operator/AbsImpl.hpp @@ -24,10 +24,10 @@ namespace Aidge { // compute kernel registry for forward and backward class AbsImplForward_cpu - : public Registrable<AbsImplForward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> { + : public Registrable<AbsImplForward_cpu, std::tuple<DataType, DataType>, std::function<void(const std::size_t, const void*, void*)>> { }; class AbsImplBackward_cpu - : public Registrable<AbsImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> { + : public Registrable<AbsImplBackward_cpu, std::tuple<DataType, DataType>, std::function<void(const std::size_t, const void*, void*)>> { }; class AbsImpl_cpu : public OperatorImpl { @@ -38,7 +38,7 @@ public: return std::make_unique<AbsImpl_cpu>(op); } - Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + std::shared_ptr<ProdConso> getProdConso() const override { return std::make_unique<ProdConso>(mOp, true); }; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/AndImpl.hpp b/include/aidge/backend/cpu/operator/AndImpl.hpp index 139b1f08e4c4e2900e07d2bb470cb27fb878807f..fd8cebbc5b55ffd6cf6abf22b757ebb040eb07d4 100644 --- a/include/aidge/backend/cpu/operator/AndImpl.hpp +++ b/include/aidge/backend/cpu/operator/AndImpl.hpp @@ -23,10 +23,10 @@ namespace Aidge { // compute kernel registry for forward and backward class AndImplForward_cpu - : public Registrable<AndImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> { + : public Registrable<AndImplForward_cpu, std::tuple<DataType, DataType, DataType>, std::function<void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)>> { }; class AndImplBackward_cpu - : public Registrable<AndImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> { + : public Registrable<AndImplBackward_cpu, std::tuple<DataType, DataType, DataType>, std::function<void(const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)>> { }; class AndImpl_cpu : public OperatorImpl { @@ -37,7 +37,7 @@ public: return std::make_unique<AndImpl_cpu>(op); } - Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + std::shared_ptr<ProdConso> getProdConso() const override { return std::make_unique<ProdConso>(mOp, true); }; void forward() override; }; diff --git a/include/aidge/backend/cpu/operator/ArgMaxImpl.hpp b/include/aidge/backend/cpu/operator/ArgMaxImpl.hpp index f93abbbca9630a8b60290bae7660f6428b41b0b3..72b9cc40ea33dcb20c65e20bdc1d1bd21d0dbf19 100644 --- a/include/aidge/backend/cpu/operator/ArgMaxImpl.hpp +++ b/include/aidge/backend/cpu/operator/ArgMaxImpl.hpp @@ -26,19 +26,19 @@ namespace Aidge { class ArgMaxImplForward_cpu : public Registrable<ArgMaxImplForward_cpu, std::tuple<DataType, DataType>, - void(std::int32_t, + std::function<void(std::int32_t, DimSize_t, const std::vector<DimSize_t>&, const void *, - void *)> {}; + void *)>> {}; class ArgMaxImplBackward_cpu : public Registrable<ArgMaxImplBackward_cpu, std::tuple<DataType, DataType>, - void(std::int32_t, + std::function<void(std::int32_t, DimSize_t, const std::vector<DimSize_t>&, const void *, - void *)> {}; + void *)>> {}; class ArgMaxImpl_cpu : public OperatorImpl { public: diff --git a/include/aidge/backend/cpu/operator/ReduceSumImpl.hpp b/include/aidge/backend/cpu/operator/ReduceSumImpl.hpp index 3b265e134e7282a81476b5aa562237ecbc93141e..0f37ef5d2ae4c2752d0930ddbf082f87e0bfe825 100644 --- a/include/aidge/backend/cpu/operator/ReduceSumImpl.hpp +++ b/include/aidge/backend/cpu/operator/ReduceSumImpl.hpp @@ -26,19 +26,19 @@ namespace Aidge { class ReduceSumImplForward_cpu : public Registrable<ReduceSumImplForward_cpu, std::tuple<DataType, DataType>, - void(const std::vector<std::int32_t>&, + std::function<void(const std::vector<std::int32_t>&, DimSize_t, const std::vector<DimSize_t>&, const void *, - void *)> {}; + void *)>> {}; class ReduceSumImpl1DBackward_cpu : public Registrable<ReduceSumImpl1DBackward_cpu, std::tuple<DataType, DataType>, - void(const std::vector<std::int32_t>&, + std::function<void(const std::vector<std::int32_t>&, DimSize_t, const std::vector<DimSize_t>&, const void *, - void *)> {}; + void *)>> {}; class ReduceSumImpl_cpu : public OperatorImpl { public: diff --git a/src/operator/AbsImpl.cpp b/src/operator/AbsImpl.cpp index 1eb86c91289d3000f9cc0792e5ca2da29d4d8c24..589de2d8685443cdb567da92a694e2db25c99d2d 100644 --- a/src/operator/AbsImpl.cpp +++ b/src/operator/AbsImpl.cpp @@ -19,11 +19,6 @@ #include "aidge/operator/Abs.hpp" #include "aidge/utils/Types.h" -Aidge::Elts_t Aidge::AbsImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return Elts_t::DataElts(0); -} - void Aidge::AbsImpl_cpu::forward() { const Abs_Op& op = static_cast<const Abs_Op&>(mOp); diff --git a/src/operator/AndImpl.cpp b/src/operator/AndImpl.cpp index bc447e74a1af797a69c942eab9ff816bc195388a..78a273d0ab66bfc82e36a69926eeccc74730e87f 100644 --- a/src/operator/AndImpl.cpp +++ b/src/operator/AndImpl.cpp @@ -23,11 +23,6 @@ #include "aidge/backend/cpu/operator/AndImpl.hpp" #include "aidge/backend/cpu/operator/AndImpl_forward_kernels.hpp" -Aidge::Elts_t Aidge::AndImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return Elts_t::DataElts(0); -} - void Aidge::AndImpl_cpu::forward() { // Find the correct kernel type auto kernelFunc = Registrar<AndImplForward_cpu>::create({