From bad418ebf5547fd1100ffe231e98123d26d5a88c Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Tue, 18 Jun 2024 17:10:24 +0200
Subject: [PATCH] Removed bias and improved ConvDepthWiseImpl2D_cpu::forward()

---
 .../ConvDepthWiseImpl_forward_kernels.hpp     |  2 +-
 .../cpu/operator/ConvImpl_forward_kernels.hpp |  4 +-
 include/aidge/backend/cpu/operator/FCImpl.hpp |  8 ++--
 .../cpu/operator/FCImpl_backward_kernels.hpp  |  5 +-
 .../cpu/operator/FCImpl_forward_kernels.hpp   |  4 +-
 src/operator/ConvDepthWiseImpl.cpp            | 47 ++++++++++++++-----
 src/operator/ConvImpl.cpp                     |  6 +--
 src/operator/FCImpl.cpp                       | 26 +++++-----
 8 files changed, 63 insertions(+), 39 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/ConvDepthWiseImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ConvDepthWiseImpl_forward_kernels.hpp
index 720e331c..9537e34a 100644
--- a/include/aidge/backend/cpu/operator/ConvDepthWiseImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ConvDepthWiseImpl_forward_kernels.hpp
@@ -64,7 +64,7 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const ConvDepthWise_Op<2>::Attrs &at
     for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
         for (std::size_t ch = 0; ch < inputDims[1]; ++ch) {
             const std::size_t oIndex = (ch + batch*inputDims[1]) * oxSize * oySize;
-            B biasVal = ((!std::get<3>(attrs)) && biases != nullptr) ? biases[ch] : B(0);
+            B biasVal = (biases != nullptr) ? biases[ch] : B(0);
             std::fill(output + oIndex, output+(oIndex+oxSize*oySize), biasVal);
             const std::size_t iIndex = (ch + batch*inputDims[1]) * inputDims[2] * inputDims[3];
             const std::size_t wIndex = ch * std::get<2>(attrs)[0] * std::get<2>(attrs)[1];
diff --git a/include/aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp
index 0f171d79..c2e5e4ca 100644
--- a/include/aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ConvImpl_forward_kernels.hpp
@@ -106,8 +106,8 @@ void ConvImpl2D_cpu_forward_kernel(const Conv_Op<2>::Attrs &attrs, const std::ar
     for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
         for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
             const std::size_t oIndex = (outCh + batch*outChannels) * oxSize * oySize;
-            // If  NoBias or bias = nullptr, set B(0)
-            B biasVal = ((!std::get<3>(attrs)) && biases != nullptr) ? biases[outCh] : B(0);
+            // If bias = nullptr, set B(0)
+            B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
             std::fill(output + oIndex, output+(oIndex+oxSize*oySize), biasVal);
             for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
                 const std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3];
diff --git a/include/aidge/backend/cpu/operator/FCImpl.hpp b/include/aidge/backend/cpu/operator/FCImpl.hpp
index db5f7683..f9f97ffd 100644
--- a/include/aidge/backend/cpu/operator/FCImpl.hpp
+++ b/include/aidge/backend/cpu/operator/FCImpl.hpp
@@ -30,8 +30,8 @@ class FCImplForward_cpu : public Registrable<FCImplForward_cpu,
                                                         DataType,
                                                         DataType,
                                                         DataType>,
-                                             void(const FC_Op::Attrs&,
-                                                  const DimSize_t,
+                                             void(
+                                                const DimSize_t,
                                                   const DimSize_t,
                                                   const DimSize_t,
                                                   const void *,
@@ -43,8 +43,8 @@ class FCImplBackward_cpu : public Registrable<FCImplBackward_cpu,
                                                          DataType,
                                                          DataType,
                                                          DataType>,
-                                              void(const FC_Op::Attrs&,
-                                              const DimSize_t,
+                                              void(
+                                                const DimSize_t,
                                               const DimSize_t,
                                               const DimSize_t,
                                               const void *,
diff --git a/include/aidge/backend/cpu/operator/FCImpl_backward_kernels.hpp b/include/aidge/backend/cpu/operator/FCImpl_backward_kernels.hpp
index 9dd91eb8..c93a44d9 100644
--- a/include/aidge/backend/cpu/operator/FCImpl_backward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/FCImpl_backward_kernels.hpp
@@ -19,8 +19,7 @@
 
 namespace Aidge {
 template <class I, class O, class W, class B>
-void FCImpl_cpu_backward_kernel(const FC_Op::Attrs& attrs,
-                                const DimSize_t batchSize,
+void FCImpl_cpu_backward_kernel(const DimSize_t batchSize,
                                 const DimSize_t inputFeatureSize,
                                 const DimSize_t outputFeatureSize,
                                 const void* input_,
@@ -40,7 +39,5 @@ void FCImpl_cpu_backward_kernel(const FC_Op::Attrs& attrs,
 
 
     // bias grad
-    if (std::get<0>(attrs)) { // no bias
-        std::fill(biasesGrad, biasesGrad + outputFeatureSize, B(0));
-    } else {
+    if (biasesGrad != nullptr) { // bias is optional: only compute its gradient when present
         for (std::size_t o = 0; o < outputFeatureSize; ++o) { // nb outputs
diff --git a/include/aidge/backend/cpu/operator/FCImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/FCImpl_forward_kernels.hpp
index 2a1a86ba..a82e8850 100644
--- a/include/aidge/backend/cpu/operator/FCImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/FCImpl_forward_kernels.hpp
@@ -83,7 +83,7 @@ namespace Aidge {
 // }
 
 template <class I, class W, class B, class O>
-void FCImpl_cpu_forward_kernel(const FC_Op::Attrs& attrs, const DimSize_t batchSize, const DimSize_t inputFeatureSize,
+void FCImpl_cpu_forward_kernel(const DimSize_t batchSize, const DimSize_t inputFeatureSize,
                                     const DimSize_t outputFeatureSize,
                                    const void* input_, const void* weights_, const void* biases_, void* output_) {
     // FIXME: missing FC attributes as arguments
@@ -92,7 +92,7 @@ void FCImpl_cpu_forward_kernel(const FC_Op::Attrs& attrs, const DimSize_t batchS
     const B* biases = static_cast<const B*>(biases_);
     O* output = static_cast<O*>(output_);
 
-    if (std::get<0>(attrs)) {
+    if (biases == nullptr) {
         std::fill(output, output+(batchSize*outputFeatureSize), B(0));
     }
     else {
diff --git a/src/operator/ConvDepthWiseImpl.cpp b/src/operator/ConvDepthWiseImpl.cpp
index 5c8d2fe3..51677f05 100644
--- a/src/operator/ConvDepthWiseImpl.cpp
+++ b/src/operator/ConvDepthWiseImpl.cpp
@@ -28,23 +28,47 @@ Aidge::Elts_t Aidge::ConvDepthWiseImpl2D_cpu::getNbRequiredProtected(IOIndex_t /
 }
 
 void Aidge::ConvDepthWiseImpl2D_cpu::forward() {
+    const auto& opTensor = static_cast<const OperatorTensor&>(mOp);
+
     assert(mOp.getRawInput(0) && "missing input #0");
     assert(mOp.getRawInput(1) && "missing input #1");
-    assert(mOp.getRawInput(2) && "missing input #2");
 
-    assert((std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims() == 4) && "support for 4-dimensions tensors only");
+    assert((std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->nbDims() == 4) && "support for 4-dimensions tensors only");
 
     // Find the correct kernel type
-    auto kernelFunc =
-            Registrar<ConvDepthWiseImpl2DForward_cpu>::create({std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
-                                                               std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dataType(),
-                                                               std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->dataType(),
-                                                               std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
+    const auto outputDataType = opTensor.getOutput(0)->dataType();
+    const Registrar<ConvDepthWiseImpl2DForward_cpu>::registrar_key registrarKey = {
+        opTensor.getInput(0)->dataType(),
+        opTensor.getInput(1)->dataType(),
+        ((opTensor.getInput(2)) ? opTensor.getInput(2)->dataType() : opTensor.getInput(1)->dataType()),
+        outputDataType};
+
+    Registrar<ConvDepthWiseImpl2DForward_cpu>::registrar_type kernelFunc;
+    if (Registrar<ConvDepthWiseImpl2DForward_cpu>::exists(registrarKey)) {
+        // One exists with the right inputs/output types
+        kernelFunc = Registrar<ConvDepthWiseImpl2DForward_cpu>::create(registrarKey);
+    }
+    else {
+        // Otherwise, fallback to the kernel with all types matching output type
+        kernelFunc = Registrar<ConvDepthWiseImpl2DForward_cpu>::create({
+            outputDataType, outputDataType, outputDataType, outputDataType});
+    }
+
+    // Convert input data (no overhead if not needed!)
+    // TODO: right now, if needed, memory will be allocated/deallocated at each
+    // call to forward(). We might put the following shared_ptr as members of
+    // this class to avoid that.
+    std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
+    const auto& input0 = opTensor.getInput(0)->refCastFrom(input0Fallback, *opTensor.getOutput(0));
+    const auto& input1 = opTensor.getInput(1)->refCastFrom(input1Fallback, *opTensor.getOutput(0));
+    const auto& input2 = (opTensor.getInput(2)) ? opTensor.getInput(2)->refCastFrom(input2Fallback, *opTensor.getOutput(0)) : Tensor();
 
     // Call kernel
-    kernelFunc(dynamic_cast<const ConvDepthWise_Op<2>&>(mOp).getStaticAttributes(), std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->template dims<4>(),
-               getCPUPtr(mOp.getRawInput(0)),
-               getCPUPtr(mOp.getRawInput(1)),
-               getCPUPtr(mOp.getRawInput(2)),
-               getCPUPtr(mOp.getRawOutput(0)));
+    kernelFunc(dynamic_cast<const ConvDepthWise_Op<2>&>(mOp).getStaticAttributes(), // Conv attributes
+               opTensor.getInput(0)->template dims<4>(), // input dimensions
+               input0.getImpl()->rawPtr(), // input
+               input1.getImpl()->rawPtr(), // weight
+               (opTensor.getInput(2)) ? input2.getImpl()->rawPtr() : nullptr, // bias
+               getCPUPtr(mOp.getRawOutput(0)) // output
+            );
 }
diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp
index 27e2882c..b69bbc07 100644
--- a/src/operator/ConvImpl.cpp
+++ b/src/operator/ConvImpl.cpp
@@ -40,7 +40,7 @@ void Aidge::ConvImpl2D_cpu::forward() {
     const Registrar<ConvImpl2DForward_cpu>::registrar_key registrarKey = {
         opTensor.getInput(0)->dataType(),
         opTensor.getInput(1)->dataType(),
-        opTensor.getInput(2)->dataType(),
+        ((opTensor.getInput(2)) ? opTensor.getInput(2)->dataType() : opTensor.getInput(1)->dataType()),
         outputDataType};
 
     Registrar<ConvImpl2DForward_cpu>::registrar_type kernelFunc;
@@ -61,7 +61,7 @@ void Aidge::ConvImpl2D_cpu::forward() {
     std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
     const auto& input0 = opTensor.getInput(0)->refCastFrom(input0Fallback, *opTensor.getOutput(0));
     const auto& input1 = opTensor.getInput(1)->refCastFrom(input1Fallback, *opTensor.getOutput(0));
-    const auto& input2 = opTensor.getInput(2)->refCastFrom(input2Fallback, *opTensor.getOutput(0));
+    const auto& input2 = (opTensor.getInput(2)) ? opTensor.getInput(2)->refCastFrom(input2Fallback, *opTensor.getOutput(0)) : Tensor();
 
     // Call kernel
     kernelFunc(dynamic_cast<const Conv_Op<2>&>(mOp).getStaticAttributes(), // Conv attributes
@@ -69,7 +69,7 @@ void Aidge::ConvImpl2D_cpu::forward() {
                dynamic_cast<const Conv_Op<2>&>(mOp).outChannels(), // outChannels
                input0.getImpl()->rawPtr(), // input
                input1.getImpl()->rawPtr(), // weight
-               input2.getImpl()->rawPtr(), // bias
+               (opTensor.getInput(2)) ? input2.getImpl()->rawPtr() : nullptr, // bias
                getCPUPtr(mOp.getRawOutput(0)) // output
             );
 }
diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp
index 9ade5841..ad3727f6 100644
--- a/src/operator/FCImpl.cpp
+++ b/src/operator/FCImpl.cpp
@@ -34,9 +34,9 @@ void Aidge::FCImpl_cpu::forward()
     // Find the correct kernel type
     const auto outputDataType = op_.getOutput(0)->dataType();
     const Registrar<FCImplForward_cpu>::registrar_key registrarKey = {
-        outputDataType,
-        outputDataType,
-        outputDataType,
+        op_.getInput(0)->dataType(),
+        op_.getInput(1)->dataType(),
+        ((op_.getInput(2)) ? op_.getInput(2)->dataType() : op_.getInput(1)->dataType()),
         outputDataType};
 
     Registrar<FCImplForward_cpu>::registrar_type kernelFunc;
@@ -57,15 +57,16 @@ void Aidge::FCImpl_cpu::forward()
     std::shared_ptr<Tensor> input0Fallback, input1Fallback, input2Fallback;
     const auto& input0 = op_.getInput(0)->refCastFrom(input0Fallback, *(op_.getOutput(0)));
     const auto& input1 = op_.getInput(1)->refCastFrom(input1Fallback, *(op_.getOutput(0)));
-    const auto& input2 = op_.getInput(2)->refCastFrom(input2Fallback, *(op_.getOutput(0)));
+    const auto& input2 = (op_.getInput(2)) ? op_.getInput(2)->refCastFrom(input2Fallback, *(op_.getOutput(0))) : Tensor();
 
     // Call kernel
     const auto batchSize = (input0.dims().size() > 1) ? input0.dims()[0] : 1;
-    kernelFunc(dynamic_cast<const FC_Op&>(mOp).getStaticAttributes(),
-        batchSize,
+    kernelFunc(batchSize,
         input1.dims()[1], // nb input features
         input1.dims()[0], // nb output features
-        input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(),
+        input0.getImpl()->rawPtr(),
+        input1.getImpl()->rawPtr(),
+        (op_.getInput(2)) ? input2.getImpl()->rawPtr() : nullptr,
         getCPUPtr(mOp.getRawOutput(0)));
 }
 
@@ -81,9 +82,9 @@ void Aidge::FCImpl_cpu::backward()
     // Find the correct kernel type
     const Registrar<FCImplBackward_cpu>::registrar_key registrarKey = {
         fc_grad->dataType(),
-        op_.getInput(0)->grad()->dataType(),
         op_.getInput(1)->grad()->dataType(),
-        op_.getInput(2)->grad()->dataType()};
+        (op_.getInput(2)) ? op_.getInput(2)->grad()->dataType() : op_.getInput(1)->grad()->dataType(),
+        op_.getInput(0)->grad()->dataType()};
 
     Registrar<FCImplBackward_cpu>::registrar_type kernelFunc;
     if (Registrar<FCImplBackward_cpu>::exists(registrarKey)) {
@@ -103,12 +104,11 @@ void Aidge::FCImpl_cpu::backward()
     std::shared_ptr<Tensor> input0gradFallback, input1gradFallback, input2gradFallback;
     const auto& input0grad = op_.getInput(0)->grad()->refCastFrom(input0gradFallback, *(op_.getOutput(0)));
     const auto& input1grad = op_.getInput(1)->grad()->refCastFrom(input1gradFallback, *(op_.getOutput(0)));
-    const auto& input2grad = op_.getInput(2)->grad()->refCastFrom(input2gradFallback, *(op_.getOutput(0)));
+    const auto& input2grad = (op_.getInput(2)) ? op_.getInput(2)->grad()->refCastFrom(input2gradFallback, *(op_.getOutput(0))) : Tensor();
 
     // Call kernel
     const auto batchSize = (input0grad.dims().size() > 1) ? input0grad.dims()[0] : 1;
-    kernelFunc(dynamic_cast<const FC_Op&>(mOp).getStaticAttributes(),
-        batchSize,
+    kernelFunc(batchSize,
         input1grad.dims()[1], // nb input features
         input1grad.dims()[0], // nb output features
         getCPUPtr(fc_grad),
@@ -116,5 +116,5 @@ void Aidge::FCImpl_cpu::backward()
         getCPUPtr(mOp.getRawInput(1)),
         input0grad.getImpl()->rawPtr(),
         input1grad.getImpl()->rawPtr(),
-        input2grad.getImpl()->rawPtr());
+        (op_.getInput(2)) ? input2grad.getImpl()->rawPtr() : nullptr);
 }
-- 
GitLab