From 1fdb9c996866d57bd8333f692d47c413c512128a Mon Sep 17 00:00:00 2001
From: Noam ZERAH <noam.zerah@cea.fr>
Date: Wed, 25 Sep 2024 09:13:14 +0000
Subject: [PATCH] Merge dev updates in feat_operator_bitshift (New Registrar
 System)

---
 .../backend/cpu/operator/BitShiftImpl.hpp     | 39 +++++++------------
 ...d_kernels.hpp => BitShiftImpl_kernels.hpp} | 22 +++++------
 src/operator/BitShiftImpl.cpp                 | 22 +++++------
 3 files changed, 35 insertions(+), 48 deletions(-)
 rename include/aidge/backend/cpu/operator/{BitShiftImpl_forward_kernels.hpp => BitShiftImpl_kernels.hpp} (75%)

diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp
index ad4bdc6a..6da67bb7 100644
--- a/include/aidge/backend/cpu/operator/BitShiftImpl.hpp
+++ b/include/aidge/backend/cpu/operator/BitShiftImpl.hpp
@@ -12,7 +12,7 @@
 #ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_
 #define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_
 
-#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/backend/cpu/operator/OperatorImpl.hpp"
 #include "aidge/operator/BitShift.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
@@ -21,31 +21,18 @@
 #include <vector>
 
 namespace Aidge {
-// class BitShift_Op;
-
-// compute kernel registry for forward and backward
-class BitShiftImplForward_cpu
-    : public Registrable<BitShiftImplForward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShift_Op::BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*,void*)> {
-};
-class BitShiftImplBackward_cpu
-    : public Registrable<BitShiftImplBackward_cpu, std::tuple<DataType, DataType, DataType>, void(const BitShift_Op::BitShiftDirection,const std::vector<std::size_t>&, const std::vector<std::size_t>&, const std::vector<std::size_t>&, const void*, const void*, void*)> {
-};
-
-class BitShiftImpl_cpu : public OperatorImpl {
-public:
-    BitShiftImpl_cpu(const BitShift_Op& op) : OperatorImpl(op, "cpu") {}
-
-    static std::unique_ptr<BitShiftImpl_cpu> create(const BitShift_Op& op) {
-        return std::make_unique<BitShiftImpl_cpu>(op);
-    }
-
-    Elts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
-    void forward() override;
-};
-
-namespace {
-static Registrar<BitShift_Op> registrarBitShiftImpl_cpu("cpu", Aidge::BitShiftImpl_cpu::create);
-}
+// Operator implementation entry point for the backend
+using BitShiftImpl_cpu = OperatorImpl_cpu<BitShift_Op,
+    void(const BitShift_Op::BitShiftDirection,
+    const std::vector<std::size_t>&, 
+    const std::vector<std::size_t>&, 
+    const std::vector<std::size_t>&, 
+    const void*, 
+    const void*,
+    void*)>;
+    
+    // Implementation entry point registration to Operator
+    REGISTRAR(BitShift_Op,"cpu",Aidge::BitShiftImpl_cpu::create);
 }  // namespace Aidge
 
 #endif /* AIDGE_CPU_OPERATOR_BITSHIFTIMPL_H_ */
diff --git a/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp
similarity index 75%
rename from include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp
rename to include/aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp
index 76018735..f815e946 100644
--- a/include/aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp
@@ -9,8 +9,8 @@
  *
  ********************************************************************************/
 
-#ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_FORWARD_KERNEL_H_
-#define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_FORWARD_KERNEL_H_
+#ifndef AIDGE_CPU_OPERATOR_BITSHIFTIMPL_KERNELS_H_
+#define AIDGE_CPU_OPERATOR_BITSHIFTIMPL_KERNELS_H_
 
 #include "aidge/utils/Registrar.hpp"
 
@@ -57,14 +57,14 @@ void BitShiftImpl_cpu_forward_kernel(
     }
 }
 
-namespace {
-static Registrar<BitShiftImplForward_cpu> registrarBitShiftImplForward_cpu_Int32(
-        {DataType::Int32, DataType::Int32, DataType::Int32},
-        Aidge::BitShiftImpl_cpu_forward_kernel<std::int32_t, std::int32_t, std::int32_t>);
-static Registrar<BitShiftImplForward_cpu> registrarBitShiftImplForward_cpu_Int64(
-        {DataType::Int64, DataType::Int64, DataType::Int64},
-        Aidge::BitShiftImpl_cpu_forward_kernel<std::int64_t, std::int64_t, std::int64_t>);
-}  // namespace
+REGISTRAR(BitShiftImpl_cpu,
+{DataType::Int32},
+{ProdConso::inPlaceModel,Aidge::BitShiftImpl_cpu_forward_kernel<std::int32_t, std::int32_t, std::int32_t>,nullptr});
+REGISTRAR(BitShiftImpl_cpu,
+{DataType::Int64},
+{ProdConso::inPlaceModel,Aidge::BitShiftImpl_cpu_forward_kernel<std::int64_t, std::int64_t, std::int64_t>,nullptr});
+
+
 }  // namespace Aidge
 
-#endif /* AIDGE_CPU_OPERATOR_BitShiftIMPL_FORWARD_KERNEL_H_ */
+#endif /* AIDGE_CPU_OPERATOR_BitShiftIMPL_KERNELS_H_ */
\ No newline at end of file
diff --git a/src/operator/BitShiftImpl.cpp b/src/operator/BitShiftImpl.cpp
index a3ffa89d..1e0f79fd 100644
--- a/src/operator/BitShiftImpl.cpp
+++ b/src/operator/BitShiftImpl.cpp
@@ -21,21 +21,16 @@
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
 
 #include "aidge/backend/cpu/operator/BitShiftImpl.hpp"
-#include "aidge/backend/cpu/operator/BitShiftImpl_forward_kernels.hpp"
-
-Aidge::Elts_t Aidge::BitShiftImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
-    // this implementation can be in-place
-    return Elts_t::DataElts(0);
-}
+#include "aidge/backend/cpu/operator/BitShiftImpl_kernels.hpp"
 
+template<>
 void Aidge::BitShiftImpl_cpu::forward() {
 
     const auto& op_ = dynamic_cast<const BitShift_Op&>(mOp);
 
-    auto kernelFunc = Registrar<BitShiftImplForward_cpu>::create({
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
+
+    const auto impl = Registrar<BitShiftImpl_cpu>::create(getBestMatch(getRequiredSpec()));
+
 
     const std::vector<std::size_t> inputDims0 = getBroadcastedDims(std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dims(),
                                                                    std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims());
@@ -45,7 +40,7 @@ void Aidge::BitShiftImpl_cpu::forward() {
     BitShift_Op::BitShiftDirection direction = op_.direction();
 
     // Call kernel
-    kernelFunc(
+    impl.forward(
         direction,
         inputDims0,
         inputDims1,
@@ -55,3 +50,8 @@ void Aidge::BitShiftImpl_cpu::forward() {
         getCPUPtr(mOp.getRawOutput(0)));
         
 }
+
+template <>
+void Aidge::BitShiftImpl_cpu::backward() {
+    AIDGE_THROW_OR_ABORT(std::runtime_error, "Backward not yet implemented for BitShift_Op on backend cpu");
+}
\ No newline at end of file
-- 
GitLab