diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp index ad4bdc6aace7225f1eab9a9bc1bdf9edff6d691a..6da67bb7dd4469b6ca609c5aea1ae70dfca3f939 100644 --- a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp +++ b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp @@ -12,7 +12,7 @@ #ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_ #define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_ -#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/backend/cpu/operator/OperatorImpl.hpp" #include "aidge/operator/BitShift.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" @@ -21,31 +21,18 @@ #include <vector> namespace Aidge { -// class BitShift_Op; - -// compute kernel registry for forward and backward -class BitShiftImplForward_cpu - : public Registrable<BitShiftImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShift_Op::BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> { -}; -class BitShiftImplBackward_cpu - : public Registrable<BitShiftImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShift_Op::BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> { -}; - -class BitShiftImpl_cpu : public OperatorImpl { -public: - BitShiftImpl_cpu(const BitShift_Op& op) : OperatorImpl(op, "cpu") {} - - static std::unique_ptr<BitShiftImpl_cpu> create(const BitShift_Op& op) { - return std::make_unique<BitShiftImpl_cpu>(op); - } - - Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; - void forward() override; -}; - -namespace { -static Registrar<BitShift_Op> registrarBitShiftImpl_cpu("cpu", Aidge::BitShiftImpl_cpu::create); -} +// Operator implementation entry point for the backend +using BitShiftImpl_cpu = OperatorImpl_cpu<BitShift_Op, + void(const BitShift_Op::BitShiftDirection, + const std::vector<std::size_t>&, + const std::vector<std::size_t>&, + const std::vector<std::size_t>&, + const void*, + const void*, + void*)>; + + // Implementation entry point registration to Operator + REGISTRAR(BitShift_Op,"cpu",Aidge::BitShiftImpl_cpu::create); } // namespace Aidge #endif /* AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_ */ diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp similarity index 75% rename from include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp rename to include/aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp index 760187351fb6ba476d60eda4370de6553655520c..f815e946ea2e4abaff48a6e5155368d564e88e8c 100644 --- a/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp +++ b/include/aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp @@ -9,8 +9,8 @@ * ********************************************************************************/ -#ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_FORWARD_KERNEL_H_ -#define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_FORWARD_KERNEL_H_ +#ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_KERNELS_H_ +#define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_KERNELS_H_ #include "aidge/utils/Registrar.hpp" @@ -57,14 +57,14 @@ void BitShiftImpl_cpu_forward_kernel( } } -namespace { -static Registrar<BitShiftImplForward_cpu> registrarBitShiftImplForward_cpu_Int32( - {DataType::Int32, DataType::Int32, DataType::Int32}, - Aidge::BitShiftImpl_cpu_forward_kernel<std::int32_t, std::int32_t, std::int32_t>); -static Registrar<BitShiftImplForward_cpu> registrarBitShiftImplForward_cpu_Int64( - {DataType::Int64, DataType::Int64, DataType::Int64}, - Aidge::BitShiftImpl_cpu_forward_kernel<std::int64_t, std::int64_t, std::int64_t>); -} // namespace +REGISTRAR(BitShiftImpl_cpu, +{DataType::Int32}, +{ProdConso::inPlaceModel,Aidge::BitShiftImpl_cpu_forward_kernel<std::int32_t, std::int32_t, std::int32_t>,nullptr}); +REGISTRAR(BitShiftImpl_cpu, +{DataType::Int64}, +{ProdConso::inPlaceModel,Aidge::BitShiftImpl_cpu_forward_kernel<std::int64_t, std::int64_t, std::int64_t>,nullptr}); + + } // namespace Aidge -#endif /* AIDGE_CPU_OPERATOR_BitShiftIMPL_FORWARD_KERNEL_H_ */ +#endif /* AIDGE_CPU_OPERATOR_BitShiftIMPL_KERNELS_H_ */ \ No newline at end of file diff --git a/src/operator/BitShiftImpl.cpp b/src/operator/BitShiftImpl.cpp index a3ffa89dc4904fbcc0b6a2e313b450e8d6f57ed9..1e0f79fd29fd140f0b41c64d245b9b240da80028 100644 --- a/src/operator/BitShiftImpl.cpp +++ b/src/operator/BitShiftImpl.cpp @@ -21,21 +21,16 @@ #include "aidge/backend/cpu/data/GetCPUPtr.h" #include "aidge/backend/cpu/operator/BitShiftImpl.hpp" -#include "aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp" - -Aidge::Elts_t Aidge::BitShiftImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { - // this implementation can be in-place - return Elts_t::DataElts(0); -} +#include "aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp" +template<> void Aidge::BitShiftImpl_cpu::forward() { const auto& op_ = dynamic_cast<const BitShift_Op&>(mOp); - auto kernelFunc = Registrar<BitShiftImplForward_cpu>::create({ - std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(), - std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(), - std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()}); + + const auto impl = Registrar<BitShiftImpl_cpu>::create(getBestMatch(getRequiredSpec())); + const std::vector<std::size_t> inputDims0 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims()); @@ -45,7 +40,7 @@ void Aidge::BitShiftImpl_cpu::forward() { BitShift_Op::BitShiftDirection direction = op_.direction(); // Call kernel - kernelFunc( + impl.forward( direction, inputDims0, inputDims1, @@ -55,3 +50,8 @@ void Aidge::BitShiftImpl_cpu::forward() { getCPUPtr(mOp.getRawOutput(0))); } + +template <> +void Aidge::BitShiftImpl_cpu::backward() { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Backward not yet implemented for BitShift_Op on backend cpu"); +} \ No newline at end of file