From ac8554fd6b4c50bc0e7919dc6fc185fa162f969e Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Tue, 26 Mar 2024 15:43:19 +0000
Subject: [PATCH] Upd FC, Pow, Sqrt implementation arguments

---
 include/aidge/backend/cpu/operator/FCImpl.hpp | 28 +++++++--
 .../aidge/backend/cpu/operator/PowImpl.hpp    |  1 +
 src/operator/FCImpl.cpp                       | 63 ++++++++++++++++---
 src/operator/PowImpl.cpp                      | 22 +++++++
 src/operator/SqrtImpl.cpp                     | 17 ++---
 5 files changed, 107 insertions(+), 24 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/FCImpl.hpp b/include/aidge/backend/cpu/operator/FCImpl.hpp
index 514cb999..71fdf8e2 100644
--- a/include/aidge/backend/cpu/operator/FCImpl.hpp
+++ b/include/aidge/backend/cpu/operator/FCImpl.hpp
@@ -26,13 +26,29 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class FCImplForward_cpu : public Registrable<FCImplForward_cpu,
-                                                 std::tuple<DataType, DataType, DataType, DataType>,
-                                                 void(const FC_Op::Attrs &, const DimSize_t, const DimSize_t,
-                                                      const void *, const void *, const void *, void *)> {};
+                                             std::tuple<DataType,
+                                                        DataType,
+                                                        DataType,
+                                                        DataType>,
+                                             void(const FC_Op::Attrs&,
+                                                  const DimSize_t,
+                                                  const DimSize_t,
+                                                  const void *,
+                                                  const void *,
+                                                  const void *,
+                                                  void *)> {};
 class FCImplBackward_cpu : public Registrable<FCImplBackward_cpu,
-                                                  std::tuple<DataType, DataType, DataType, DataType>,
-                                                  void(const FC_Op::Attrs &, const DimSize_t, const DimSize_t,
-                                                       const void *, const void *, const void *, void *)> {};
+                                              std::tuple<DataType,
+                                                         DataType,
+                                                         DataType,
+                                                         DataType>,
+                                              void(const FC_Op::Attrs&,
+                                              const DimSize_t,
+                                              const DimSize_t,
+                                              const void *,
+                                              const void *,
+                                              const void *,
+                                              void *)> {};
 
 class FCImpl_cpu : public OperatorImpl {
 public:
diff --git a/include/aidge/backend/cpu/operator/PowImpl.hpp b/include/aidge/backend/cpu/operator/PowImpl.hpp
index 3d63160a..f82b3dfd 100644
--- a/include/aidge/backend/cpu/operator/PowImpl.hpp
+++ b/include/aidge/backend/cpu/operator/PowImpl.hpp
@@ -41,6 +41,7 @@ public:
 
     NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
     void forward() override;
+    void backward() override;
 };
 
 namespace {
diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp
index 99524590..8b0ffca8 100644
--- a/src/operator/FCImpl.cpp
+++ b/src/operator/FCImpl.cpp
@@ -24,16 +24,17 @@
 
 void Aidge::FCImpl_cpu::forward()
 {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1)) && "missing input #1");
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(2)) && "missing input #2");
+    const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
+    assert((op_.getInput(0)) && "missing input #0");
+    assert((op_.getInput(1)) && "missing input #1");
+    assert((op_.getInput(2)) && "missing input #2");
 
     // Find the correct kernel type
-    const auto outputDataType = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType();
+    const auto outputDataType = op_.getOutput(0)->dataType();
     const Registrar<FCImplForward_cpu>::registrar_key registrarKey = {
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(),
+        op_.getInput(0)->dataType(),
+        op_.getInput(1)->dataType(),
+        op_.getInput(2)->dataType(),
         outputDataType};
 
     Registrar<FCImplForward_cpu>::registrar_type kernelFunc;
@@ -52,9 +53,9 @@ void Aidge::FCImpl_cpu::forward()
     // call to forward(). We might put the following shared_ptr as members of
     // this class to avoid that.
     std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
-    const auto& input0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->refCastFrom(input0Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input1 = std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->refCastFrom(input1Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
-    const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
+    const auto& input0 = op_.getInput(0)->refCastFrom(input0Fallback, *(op_.getOutput(0)));
+    const auto& input1 = op_.getInput(1)->refCastFrom(input1Fallback, *(op_.getOutput(0)));
+    const auto& input2 = op_.getInput(2)->refCastFrom(input2Fallback, *(op_.getOutput(0)));
 
     // Call kernel
     const auto batchSize = (input0.dims().size() > 1) ? input0.dims()[0] : 1;
@@ -64,3 +65,45 @@ void Aidge::FCImpl_cpu::forward()
         input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(),
         getCPUPtr(mOp.getRawOutput(0)));
 }
+
+// void Aidge::FCImpl_cpu::backward()
+// {
+//     const FC_Op& op_ = dynamic_cast<const FC_Op&>(mOp);
+//     const auto& fc_grad = op_.getOutput(0)->grad();
+//     assert(fc_grad && "missing output #0 gradient");
+
+//     // Find the correct kernel type
+//     const Registrar<FCImplBackward_cpu>::registrar_key registrarKey = {
+//         op_.getInput(0)->grad()->dataType(),
+//         op_.getInput(1)->grad()->dataType(),
+//         op_.getInput(2)->grad()->dataType(),
+//         fc_grad->dataType()};
+
+//     Registrar<FCImplBackward_cpu>::registrar_type kernelFunc;
+//     if (Registrar<FCImplBackward_cpu>::exists(registrarKey)) {
+//         // One exists with the right inputs/output types
+//         kernelFunc = Registrar<FCImplBackward_cpu>::create(registrarKey);
+//     }
+//     else {
+//         // Otherwise, fallback to the kernel with all types matching output type
+//         kernelFunc = Registrar<FCImplBackward_cpu>::create({
+//             fc_grad->dataType(), fc_grad->dataType(), fc_grad->dataType(), fc_grad->dataType()});
+//     }
+
+//     // Convert input data (no overhead if not needed!)
+//     // TODO: right now, if needed, memory will be allocated/deallocated at each
+//     // call to forward(). We might put the following shared_ptr as members of
+//     // this class to avoid that.
+//     std::shared_ptr<Tensor> input0gradFallback, input1gradFallback, input2gradFallback;
+//     const auto& input0grad = op_.getInput(0)->grad()->refCastFrom(input0gradFallback, *(op_.getOutput(0)));
+//     const auto& input1grad = op_.getInput(1)->grad()->refCastFrom(input1gradFallback, *(op_.getOutput(0)));
+//     const auto& input2grad = op_.getInput(2)->grad()->refCastFrom(input2gradFallback, *(op_.getOutput(0)));
+
+//     // Call kernel
+//     const auto batchSize = (input0grad.dims().size() > 1) ? input0grad.dims()[0] : 1;
+//     kernelFunc(dynamic_cast<const FC_Op&>(mOp).getStaticAttributes(),
+//         batchSize,
+//         input0grad.size() / batchSize,
+//         input0grad.getImpl()->rawPtr(), input1grad.getImpl()->rawPtr(), input2grad.getImpl()->rawPtr(),
+//         getCPUPtr(mOp.getRawOutput(0)));
+// }
diff --git a/src/operator/PowImpl.cpp b/src/operator/PowImpl.cpp
index 22b4e27a..de79e197 100644
--- a/src/operator/PowImpl.cpp
+++ b/src/operator/PowImpl.cpp
@@ -48,3 +48,25 @@ void Aidge::PowImpl_cpu::forward() {
         getCPUPtr(mOp.getRawInput(1)),
         getCPUPtr(mOp.getRawOutput(0)));
 }
+
+void Aidge::PowImpl_cpu::backward() {
+    // Find the correct kernel type
+    const Pow_Op& op_ = dynamic_cast<const Pow_Op&>(mOp);
+    auto kernelFunc = Registrar<PowImplForward_cpu>::create({
+        op_.getOutput(0)->grad()->dataType(),
+        op_.getInput(0)->grad()->dataType(),
+        op_.getInput(1)->grad()->dataType()});
+
+    const std::vector<std::size_t> input0gradDims = getBroadcastedDims(op_.getInput(0)->grad()->dims(),
+                                                                   op_.getOutput(0)->grad()->dims());
+    const std::vector<std::size_t> input1gradDims = getBroadcastedDims(op_.getInput(1)->grad()->dims(),
+                                                                   op_.getOutput(0)->grad()->dims());
+
+    // Call kernel
+    kernelFunc(op_.getOutput(0)->grad()->dims(),
+               input0gradDims,
+               input1gradDims,
+               getCPUPtr(mOp.getRawOutput(0)),
+               getCPUPtr(mOp.getRawInput(0)),
+               getCPUPtr(mOp.getRawInput(1)));
+}
\ No newline at end of file
diff --git a/src/operator/SqrtImpl.cpp b/src/operator/SqrtImpl.cpp
index ba9b57e8..cb635cce 100644
--- a/src/operator/SqrtImpl.cpp
+++ b/src/operator/SqrtImpl.cpp
@@ -45,17 +45,18 @@ void Aidge::SqrtImpl_cpu::forward() {
 
 void Aidge::SqrtImpl_cpu::backward() {
     // reversing in and out Data for backprop
-    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
-    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
-    AIDGE_ASSERT(out0, "missing output #0");
+    const Sqrt_Op& op_ = dynamic_cast<const Sqrt_Op&>(mOp);
+    std::shared_ptr<Tensor> out0grad  = op_.getOutput(0)->grad();
+    std::shared_ptr<Tensor> in0grad = op_.getInput(0)->grad();
+    AIDGE_ASSERT(out0grad, "missing output #0");
 
     // Find the correct kernel type
     auto kernelFunc = Registrar<SqrtImplForward_cpu>::create({
-        in0->dataType(),
-        out0->dataType()});
+        out0grad->dataType(),
+        in0grad->dataType()});
 
     // Call kernel
-    kernelFunc(in0->size(),
-        getCPUPtr(in0),
-        getCPUPtr(out0));
+    kernelFunc(out0grad->size(),
+        getCPUPtr(out0grad),
+        getCPUPtr(in0grad));
 }
\ No newline at end of file
-- 
GitLab