From 5d4b29671d6e9ccaa277fc80a9d2cbef973d691e Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 15 Feb 2024 14:37:46 +0000
Subject: [PATCH] [Add] backward kernel for ReLU, LeakyReLU & Producer

---
 .../backend/cpu/operator/LeakyReLUImpl.hpp    | 13 ++++--
 .../LeakyReLUImpl_backward_kernels.hpp        | 45 ++++++++++++++++++
 .../backend/cpu/operator/ProducerImpl.hpp     |  6 ++-
 .../aidge/backend/cpu/operator/ReLUImpl.hpp   | 13 ++++--
 .../operator/ReLUImpl_backward_kernels.hpp    | 45 ++++++++++++++++++
 .../aidge/backend/cpu/operator/SqrtImpl.hpp   | 14 ++++--
 .../operator/SqrtImpl_backward_kernels.hpp    | 46 +++++++++++++++++++
 .../cpu/operator/SqrtImpl_forward_kernels.hpp |  8 ++--
 src/operator/LeakyReLUImpl.cpp                | 34 +++++++++++---
 src/operator/ProducerImpl.cpp                 |  8 +---
 src/operator/ReLUImpl.cpp                     | 31 ++++++++++---
 src/operator/SqrtImpl.cpp                     | 37 +++++++++++----
 12 files changed, 251 insertions(+), 49 deletions(-)
 create mode 100644 include/aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp
 create mode 100644 include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp
 create mode 100644 include/aidge/backend/cpu/operator/SqrtImpl_backward_kernels.hpp

diff --git a/include/aidge/backend/cpu/operator/LeakyReLUImpl.hpp b/include/aidge/backend/cpu/operator/LeakyReLUImpl.hpp
index 4a1da034..a9c87b4d 100644
--- a/include/aidge/backend/cpu/operator/LeakyReLUImpl.hpp
+++ b/include/aidge/backend/cpu/operator/LeakyReLUImpl.hpp
@@ -12,17 +12,17 @@
 #ifndef AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_
 #define AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_H_
 
+#include <memory>
+#include <tuple>
+#include <vector>
+
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/operator/LeakyReLU.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
-#include <memory>
-#include <vector>
 
 namespace Aidge {
-// class LeakyReLU_Op;
-
 // compute kernel registry for forward and backward
 class LeakyReLUImplForward_cpu
     : public Registrable<LeakyReLUImplForward_cpu, std::tuple<DataType, DataType>, void(const LeakyReLU_Op::Attrs&, std::size_t, const void*, void*)> {
@@ -40,7 +40,10 @@ public:
     }
 
     NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
-    void forward() override;
+
+    void forward() override final;
+
+    void backward() override final;
 };
 
 namespace {
diff --git a/include/aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp
new file mode 100644
index 00000000..0e2fc400
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp
@@ -0,0 +1,45 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_BACKWARD_KERNEL_H_
+#define AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_BACKWARD_KERNEL_H_
+
+#include "aidge/utils/Registrar.hpp"
+
+#include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp"
+
+namespace Aidge {
+template <class I, class O>
+void LeakyReLUImpl_cpu_backward_kernel(const LeakyReLU_Op::Attrs& attrs,
+                                     std::size_t inputLenght,
+                                     const void* input_,
+                                     void* output_) {
+
+    const I* input = static_cast<const I*>(input_);
+    O* output = static_cast<O*>(output_);
+    I negativeSlope = static_cast<I>(std::get<0>(attrs));
+
+    for (std::size_t i = 0; i < inputLenght; ++i) {
+        output[i] = input[i] > 0 ? 1 : negativeSlope;
+    }
+}
+
+namespace {
+static Registrar<LeakyReLUImplBackward_cpu> registrarLeakyReLUImplBackward_cpu_Float32(
+        {DataType::Float32, DataType::Float32}, Aidge::LeakyReLUImpl_cpu_backward_kernel<float, float>);
+static Registrar<LeakyReLUImplBackward_cpu> registrarLeakyReLUImplBackward_cpu_Int32(
+        {DataType::Int32, DataType::Int32}, Aidge::LeakyReLUImpl_cpu_backward_kernel<int, int>);
+static Registrar<LeakyReLUImplBackward_cpu> registrarLeakyReLUImplBackward_cpu_Float64(
+        {DataType::Float64, DataType::Float64}, Aidge::LeakyReLUImpl_cpu_backward_kernel<double, double>);
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_LEAKYRELUIMPL_BACKWARD_KERNEL_H_ */
diff --git a/include/aidge/backend/cpu/operator/ProducerImpl.hpp b/include/aidge/backend/cpu/operator/ProducerImpl.hpp
index c1d27f7e..f1fc7a75 100644
--- a/include/aidge/backend/cpu/operator/ProducerImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ProducerImpl.hpp
@@ -18,7 +18,6 @@
 #include "aidge/operator/Producer.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
 
 namespace Aidge {
 class ProducerImpl_cpu : public OperatorImpl {
@@ -30,7 +29,10 @@ public:
     }
 
     NbElts_t getNbProducedData(const IOIndex_t outputIdx) const override final;
-    void forward() override;
+
+    inline void forward() noexcept override final {}
+
+    inline void backward() noexcept override final {}
 };
 
 namespace {
diff --git a/include/aidge/backend/cpu/operator/ReLUImpl.hpp b/include/aidge/backend/cpu/operator/ReLUImpl.hpp
index 3338d0c4..7aff2937 100644
--- a/include/aidge/backend/cpu/operator/ReLUImpl.hpp
+++ b/include/aidge/backend/cpu/operator/ReLUImpl.hpp
@@ -12,13 +12,15 @@
 #ifndef AIDGE_CPU_OPERATOR_RELUIMPL_H_
 #define AIDGE_CPU_OPERATOR_RELUIMPL_H_
 
+#include <cstddef>  // std::size_t
+#include <memory>
+#include <tuple>    // std::tuple
+#include <vector>
+
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/operator/ReLU.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
-#include <memory>
-#include <vector>
 
 namespace Aidge {
 // class ReLU_Op;
@@ -40,7 +42,10 @@ public:
     }
 
     NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
-    void forward() override;
+
+    void forward() override final;
+
+    void backward() override final;
 };
 
 namespace {
diff --git a/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp
new file mode 100644
index 00000000..47d95ac4
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp
@@ -0,0 +1,45 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_
+#define AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_
+
+#include <cstddef>  // std::size_t
+
+#include "aidge/utils/Registrar.hpp"
+
+#include "aidge/backend/cpu/operator/ReLUImpl.hpp"
+
+namespace Aidge {
+template <class I, class O>
+void ReLUImpl_cpu_backward_kernel(const std::size_t inputLenght,
+                                     const void* input_,
+                                     void* output_) {
+
+    const I* input = static_cast<const I*>(input_);
+    O* output = static_cast<O*>(output_);
+
+    for (std::size_t i = 0; i < inputLenght; ++i) {
+        output[i] = (input[i] > I(0)) ? O(1) : O(0);
+    }
+}
+
+namespace {
+static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float32(
+        {DataType::Float32, DataType::Float32}, Aidge::ReLUImpl_cpu_backward_kernel<float, float>);
+static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Int32(
+        {DataType::Int32, DataType::Int32}, Aidge::ReLUImpl_cpu_backward_kernel<int, int>);
+static Registrar<ReLUImplBackward_cpu> registrarReLUImplBackward_cpu_Float64(
+        {DataType::Float64, DataType::Float64}, Aidge::ReLUImpl_cpu_backward_kernel<double, double>);
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_RELUIMPL_BACKWARD_KERNEL_H_ */
diff --git a/include/aidge/backend/cpu/operator/SqrtImpl.hpp b/include/aidge/backend/cpu/operator/SqrtImpl.hpp
index b3723f27..a2c9a030 100644
--- a/include/aidge/backend/cpu/operator/SqrtImpl.hpp
+++ b/include/aidge/backend/cpu/operator/SqrtImpl.hpp
@@ -12,16 +12,17 @@
 #ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_H_
 #define AIDGE_CPU_OPERATOR_SQRTIMPL_H_
 
+#include <cstddef>  // std::size_t
+#include <memory>
+#include <tuple>
+#include <vector>
+
 #include "aidge/backend/OperatorImpl.hpp"
 #include "aidge/operator/Sqrt.hpp"
 #include "aidge/utils/Registrar.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
-#include <memory>
-#include <vector>
 
 namespace Aidge {
-// class Sqrt_Op;
 
 // compute kernel registry for forward and backward
 class SqrtImplForward_cpu
@@ -40,7 +41,10 @@ public:
     }
 
     NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
-    void forward() override;
+
+    void forward() override final;
+
+    void backward() override final;
 };
 
 namespace {
diff --git a/include/aidge/backend/cpu/operator/SqrtImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/SqrtImpl_backward_kernels.hpp
new file mode 100644
index 00000000..9cf5118a
--- /dev/null
+++ b/include/aidge/backend/cpu/operator/SqrtImpl_backward_kernels.hpp
@@ -0,0 +1,46 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_
+#define AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_
+
+#include <cmath>    // std::sqrt
+#include <cstddef>  // std::size_t
+
+#include "aidge/utils/Registrar.hpp"
+
+#include "aidge/backend/cpu/operator/SqrtImpl.hpp"
+
+namespace Aidge {
+template <class I, class O>
+void SqrtImpl_cpu_backward_kernel(const std::size_t inputLenght,
+                                     const void* input_,
+                                     void* output_) {
+
+    const I* input = static_cast<const I*>(input_);
+    O* output = static_cast<O*>(output_);
+
+    for (std::size_t i = 0; i < inputLenght; ++i) {
+        output[i] = static_cast<O>(0.5/(std::sqrt(static_cast<float>(input[i]))));
+    }
+}
+
+namespace {
+static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Float32(
+        {DataType::Float32, DataType::Float32}, Aidge::SqrtImpl_cpu_backward_kernel<float, float>);
+static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Int32(
+        {DataType::Int32, DataType::Int32}, Aidge::SqrtImpl_cpu_backward_kernel<int, int>);
+static Registrar<SqrtImplBackward_cpu> registrarSqrtImplBackward_cpu_Float64(
+        {DataType::Float64, DataType::Float64}, Aidge::SqrtImpl_cpu_backward_kernel<double, double>);
+}  // namespace
+}  // namespace Aidge
+
+#endif /* AIDGE_CPU_OPERATOR_SQRTIMPL_BACKWARD_KERNEL_H_ */
diff --git a/include/aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp
index a180fc2c..886b978c 100644
--- a/include/aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp
@@ -12,14 +12,16 @@
 #ifndef AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_
 #define AIDGE_CPU_OPERATOR_SQRTIMPL_FORWARD_KERNEL_H_
 
+#include <cmath>    // std::sqrt
+#include <cstddef>  // std::size_t
+
 #include "aidge/utils/Registrar.hpp"
-#include <cmath>
 
 #include "aidge/backend/cpu/operator/SqrtImpl.hpp"
 
 namespace Aidge {
 template <class I, class O>
-void SqrtImpl_cpu_forward_kernel(std::size_t inputLenght,
+void SqrtImpl_cpu_forward_kernel(const std::size_t inputLenght,
                                      const void* input_,
                                      void* output_) {
 
@@ -27,7 +29,7 @@ void SqrtImpl_cpu_forward_kernel(std::size_t inputLenght,
     O* output = static_cast<O*>(output_);
 
     for (std::size_t i = 0; i < inputLenght; ++i) {
-        output[i] = std::sqrt(input[i]);
+        output[i] = static_cast<O>(std::sqrt(static_cast<float>(input[i])));
     }
 }
 
diff --git a/src/operator/LeakyReLUImpl.cpp b/src/operator/LeakyReLUImpl.cpp
index 17912eb1..4ffb230d 100644
--- a/src/operator/LeakyReLUImpl.cpp
+++ b/src/operator/LeakyReLUImpl.cpp
@@ -10,17 +10,17 @@
  ********************************************************************************/
 
 #include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
 #include <vector>
 
+#include "aidge/data/Tensor.hpp"
 #include "aidge/operator/LeakyReLU.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/utils/Registrar.hpp"
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
 
 #include "aidge/backend/cpu/operator/LeakyReLUImpl.hpp"
 #include "aidge/backend/cpu/operator/LeakyReLUImpl_forward_kernels.hpp"
+#include "aidge/backend/cpu/operator/LeakyReLUImpl_backward_kernels.hpp"
 
 Aidge::NbElts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
@@ -28,16 +28,36 @@ Aidge::NbElts_t Aidge::LeakyReLUImpl_cpu::getNbRequiredProtected(const Aidge::IO
 }
 
 void Aidge::LeakyReLUImpl_cpu::forward() {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
+    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
+    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
+    AIDGE_ASSERT(in0, "missing input #0");
 
     // Find the correct kernel type
     auto kernelFunc = Registrar<LeakyReLUImplForward_cpu>::create({
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
+        in0->dataType(),
+        out0->dataType()});
 
     // Call kernel
     kernelFunc(dynamic_cast<const LeakyReLU_Op&>(mOp).getStaticAttributes(),
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
+        in0->size(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawOutput(0)));
 }
+
+void Aidge::LeakyReLUImpl_cpu::backward() {
+    // reversing in and out Data for backprop
+    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
+    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
+    AIDGE_ASSERT(in0, "missing input #0");
+
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<LeakyReLUImplForward_cpu>::create({
+        in0->dataType(),
+        out0->dataType()});
+
+    // Call kernel
+    kernelFunc(dynamic_cast<const LeakyReLU_Op&>(mOp).getStaticAttributes(),
+        in0->size(),
+        getCPUPtr(in0),
+        getCPUPtr(out0));
+}
\ No newline at end of file
diff --git a/src/operator/ProducerImpl.cpp b/src/operator/ProducerImpl.cpp
index 4c5883a9..d5432c0d 100644
--- a/src/operator/ProducerImpl.cpp
+++ b/src/operator/ProducerImpl.cpp
@@ -10,13 +10,11 @@
  ********************************************************************************/
 
 #include <cassert>
-#include <numeric> // std::accumulate
+#include <memory>
 #include <vector>
 
 #include "aidge/data/Tensor.hpp"
-#include "aidge/operator/Producer.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
 
 #include "aidge/backend/cpu/operator/ProducerImpl.hpp"
 
@@ -29,7 +27,3 @@ Aidge::DimSize_t Aidge::ProducerImpl_cpu::getNbProducedData(
 
     return std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->size();
 }
-
-void Aidge::ProducerImpl_cpu::forward()
-{
-}
diff --git a/src/operator/ReLUImpl.cpp b/src/operator/ReLUImpl.cpp
index 8863be28..84bb1045 100644
--- a/src/operator/ReLUImpl.cpp
+++ b/src/operator/ReLUImpl.cpp
@@ -9,18 +9,18 @@
  *
  ********************************************************************************/
 
-#include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
+#include <memory>
 #include <vector>
 
+#include "aidge/data/Tensor.hpp"
 #include "aidge/operator/ReLU.hpp"
 #include "aidge/utils/Types.h"
 #include "aidge/backend/cpu/data/GetCPUPtr.h"
+#include "aidge/utils/ErrorHandling.hpp"
 
 #include "aidge/backend/cpu/operator/ReLUImpl.hpp"
 #include "aidge/backend/cpu/operator/ReLUImpl_forward_kernels.hpp"
+#include "aidge/backend/cpu/operator/ReLUImpl_backward_kernels.hpp"
 
 Aidge::NbElts_t Aidge::ReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
@@ -28,15 +28,32 @@ Aidge::NbElts_t Aidge::ReLUImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex
 }
 
 void Aidge::ReLUImpl_cpu::forward() {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
+    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
+    AIDGE_ASSERT(in0, "missing input #0");
 
     // Find the correct kernel type
     auto kernelFunc = Registrar<ReLUImplForward_cpu>::create({
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
+        in0->dataType(),
         std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
 
     // Call kernel
-    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
+    kernelFunc(in0->size(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawOutput(0)));
 }
+
+void Aidge::ReLUImpl_cpu::backward() {
+    // reversing in and out Tensors
+    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->grad();
+    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->grad();
+    AIDGE_ASSERT(out0, "missing input #0");
+
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<ReLUImplBackward_cpu>::create({
+        in0->dataType(),
+        out0->dataType()
+    });
+
+    // Call kernel
+    kernelFunc(in0->size(), getCPUPtr(in0), getCPUPtr(out0));
+}
diff --git a/src/operator/SqrtImpl.cpp b/src/operator/SqrtImpl.cpp
index 2766e8ae..ba9b57e8 100644
--- a/src/operator/SqrtImpl.cpp
+++ b/src/operator/SqrtImpl.cpp
@@ -9,18 +9,18 @@
  *
  ********************************************************************************/
 
-#include <cassert>
-#include <chrono>  // std::chrono::milliseconds
-#include <numeric> // std::accumulate
-#include <thread>  // std::this_thread::sleep_for
+#include <memory>
 #include <vector>
 
+#include "aidge/backend/cpu/data/GetCPUPtr.h"
+#include "aidge/data/Tensor.hpp"
 #include "aidge/operator/Sqrt.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Types.h"
-#include "aidge/backend/cpu/data/GetCPUPtr.h"
 
 #include "aidge/backend/cpu/operator/SqrtImpl.hpp"
 #include "aidge/backend/cpu/operator/SqrtImpl_forward_kernels.hpp"
+#include "aidge/backend/cpu/operator/SqrtImpl_backward_kernels.hpp"
 
 Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
     // this implementation can be in-place
@@ -28,15 +28,34 @@ Aidge::NbElts_t Aidge::SqrtImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex
 }
 
 void Aidge::SqrtImpl_cpu::forward() {
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(0)) && "missing input #0");
+    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
+    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
+    AIDGE_ASSERT(in0, "missing input #0");
 
     // Find the correct kernel type
     auto kernelFunc = Registrar<SqrtImplForward_cpu>::create({
-        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
-        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
+        in0->dataType(),
+        out0->dataType()});
 
     // Call kernel
-    kernelFunc(std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->size(),
+    kernelFunc(in0->size(),
         getCPUPtr(mOp.getRawInput(0)),
         getCPUPtr(mOp.getRawOutput(0)));
+}
+
+void Aidge::SqrtImpl_cpu::backward() {
+    // reversing in and out Data for backprop
+    std::shared_ptr<Tensor> in0 = std::static_pointer_cast<Tensor>(mOp.getRawOutput(0));
+    std::shared_ptr<Tensor> out0 = std::static_pointer_cast<Tensor>(mOp.getRawInput(0));
+    AIDGE_ASSERT(out0, "missing output #0");
+
+    // Find the correct kernel type
+    auto kernelFunc = Registrar<SqrtImplForward_cpu>::create({
+        in0->dataType(),
+        out0->dataType()});
+
+    // Call kernel
+    kernelFunc(in0->size(),
+        getCPUPtr(in0),
+        getCPUPtr(out0));
 }
\ No newline at end of file
-- 
GitLab