From e31dca3ae146db69cccc3a29fcff4e8e9e76880e Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Thu, 23 Nov 2023 15:30:24 +0100
Subject: [PATCH] Add axis attribute to Softmax

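The Softmax CPU forward kernel previously assumed a (batch, channel,
feature) layout and always normalized over the channel dimension. The
kernel now receives the operator's axis attribute together with the full
input dimensions, so the normalization can be computed along any axis.

Indexing splits the tensor around the chosen axis: preAxisElems is the
product of the dimensions before the axis and postAxisElems the product
of the dimensions after it. For input dims {N, C, F} and axis 1 this
gives preAxisElems = N and postAxisElems = F, and element (i, k, j) maps
to flat index i*C*F + k*F + j, which is the index used by both the
sum-of-exponentials pass and the normalization pass.

The registrable kernel signatures are updated accordingly, forward()
reads the axis from the operator's static attributes, and the unit tests
construct Softmax(1) and compare against the full expected output size
instead of hard-coded element counts.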
---
 .../backend/cpu/operator/SoftmaxImpl.hpp      |  4 +-
 .../operator/SoftmaxImpl_forward_kernels.hpp  | 39 ++++++++++---------
 src/operator/SoftmaxImpl.cpp                  | 12 +++---
 unit_tests/operator/Test_SoftmaxImpl.cpp      |  9 ++---
 4 files changed, 32 insertions(+), 32 deletions(-)

diff --git a/include/aidge/backend/cpu/operator/SoftmaxImpl.hpp b/include/aidge/backend/cpu/operator/SoftmaxImpl.hpp
index 995f57f7..2b15eb36 100644
--- a/include/aidge/backend/cpu/operator/SoftmaxImpl.hpp
+++ b/include/aidge/backend/cpu/operator/SoftmaxImpl.hpp
@@ -24,10 +24,10 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class SoftmaxImplForward_cpu
-    : public Registrable<SoftmaxImplForward_cpu, std::tuple<DataType, DataType>, void(const DimSize_t, const DimSize_t, const DimSize_t, const void*, void*)> {
+    : public Registrable<SoftmaxImplForward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const void*, void*)> {
 };
 class SoftmaxImplBackward_cpu
-    : public Registrable<SoftmaxImplBackward_cpu, std::tuple<DataType, DataType>, void(const std::size_t, const void*, void*)> {
+    : public Registrable<SoftmaxImplBackward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const void*, void*)> {
 };
 
 class SoftmaxImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp
index 297a3a32..fb264afd 100644
--- a/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/SoftmaxImpl_forward_kernels.hpp
@@ -22,30 +22,33 @@
 
 namespace Aidge {
 template <class I, class O>
-void SoftmaxImpl_cpu_forward_kernel(const DimSize_t batchSize,
-                                        const DimSize_t channelSize,
-                                        const DimSize_t featureSize,
-                                        const void* input_,
-                                        void* output_) {
-
+void SoftmaxImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSize_t>& inputDims,
+                                    const void* input_, void* output_) {
     const I* input = static_cast<const I*>(input_);
     O* output = static_cast<O*>(output_);
 
-    for (std::size_t batch = 0; batch < batchSize; ++batch) {
-        for (std::size_t feature = 0; feature < featureSize; ++feature) {
-            std::size_t ioIndex = batch*channelSize*featureSize + feature;
+    std::size_t postAxisElems = 1;
+    for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
+        postAxisElems *= inputDims[i];
+    }
+    std::size_t preAxisElems = 1;
+    for (std::size_t i = 0; i < axisIdx; ++i) {
+        preAxisElems *= inputDims[i];
+    }
 
-            I sum(0.0);
-            for (std::size_t ch = 0; ch < channelSize; ++ch) {
-                output[ioIndex] = std::exp(input[ioIndex]);
-                sum += output[ioIndex];
-                ioIndex+=featureSize;
+    for (std::size_t i = 0; i < preAxisElems; ++i) {
+        for (std::size_t j = 0; j < postAxisElems; ++j) {
+            // Calculate sum of exponentials within the axis
+            I sumExp = 0;
+            for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
+                std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
+                sumExp += std::exp(input[inIdx]);
             }
 
-            ioIndex = batch*channelSize*featureSize + feature;
-            for (std::size_t ch = 0; ch < channelSize; ++ch) {
-                output[ioIndex] /= sum;
-                ioIndex += featureSize;
+            // Calculate softmax for the current slice along the axis
+            for (std::size_t k = 0; k < inputDims[axisIdx]; ++k) {
+                std::size_t inIdx = i * inputDims[axisIdx] * postAxisElems + k * postAxisElems + j;
+                output[inIdx] = std::exp(input[inIdx]) / sumExp;
             }
         }
     }
diff --git a/src/operator/SoftmaxImpl.cpp b/src/operator/SoftmaxImpl.cpp
index 45b455a3..ae89090a 100644
--- a/src/operator/SoftmaxImpl.cpp
+++ b/src/operator/SoftmaxImpl.cpp
@@ -28,20 +28,18 @@ Aidge::NbElts_t Aidge::SoftmaxImpl_cpu::getNbRequiredProtected(const Aidge::IOIn
 
 void Aidge::SoftmaxImpl_cpu::forward() {
     assert(mOp.getInput(0) && "missing input #0");
-    assert(mOp.getInput(0)->nbDims()>1);
+    // assert(mOp.getInput(0)->nbDims() > 1);  // no longer required: softmax can run along any axis
 
     // Find the correct kernel type
     auto kernelFunc = Registrar<SoftmaxImplForward_cpu>::create({
         mOp.getInput(0)->dataType(),
         mOp.getOutput(0)->dataType()});
 
-    DimSize_t batchSize = mOp.getInput(0)->dims()[0];
-    DimSize_t channelSize = mOp.getInput(0)->dims()[1];
-    DimSize_t featureSize = mOp.getInput(0)->sizeM1()/channelSize;
+    Softmax_Op::Attrs attr = dynamic_cast<const Softmax_Op&>(mOp).getStaticAttributes();
+    const int axisIdx = static_cast<int>(std::get<0>(attr));
     // Call kernel
-    kernelFunc(batchSize,
-               channelSize,
-               featureSize,
+    kernelFunc(axisIdx,
+               mOp.getInput(0)->dims(),
                mOp.getInput(0)->getImpl()->rawPtr(),
                mOp.getOutput(0)->getImpl()->rawPtr());
 }
diff --git a/unit_tests/operator/Test_SoftmaxImpl.cpp b/unit_tests/operator/Test_SoftmaxImpl.cpp
index bad34102..81b73786 100644
--- a/unit_tests/operator/Test_SoftmaxImpl.cpp
+++ b/unit_tests/operator/Test_SoftmaxImpl.cpp
@@ -39,7 +39,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)") {
             }
         });
 
-        std::shared_ptr<Node> mySoftmax = Softmax();
+        std::shared_ptr<Node> mySoftmax = Softmax(1);
         mySoftmax->getOperator()->setDatatype(DataType::Float32);
         mySoftmax->getOperator()->setBackend("cpu");
         mySoftmax->getOperator()->associateInput(0,input);
@@ -48,7 +48,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)") {
 
         float* resPtr = static_cast<float*>(mySoftmax->getOperator()->getOutput(0)->getImpl()->rawPtr());
         float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
-        for (std::size_t i = 0; i< 20; ++i) {
+        for (std::size_t i = 0; i< expectedOutput->size(); ++i) {
             REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
         }
 
@@ -107,7 +107,7 @@ TEST_CASE("[cpu/operator] Softmax(forward)") {
             }
         });
 
-        std::shared_ptr<Node> mySoftmax = Softmax();
+        std::shared_ptr<Node> mySoftmax = Softmax(1);
         mySoftmax->getOperator()->setDatatype(DataType::Float32);
         mySoftmax->getOperator()->setBackend("cpu");
         mySoftmax->getOperator()->associateInput(0,input);
@@ -116,9 +116,8 @@ TEST_CASE("[cpu/operator] Softmax(forward)") {
 
         float* resPtr = static_cast<float*>(mySoftmax->getOperator()->getOutput(0)->getImpl()->rawPtr());
         float* expectedPtr = static_cast<float*>(expectedOutput->getImpl()->rawPtr());
-        for (std::size_t i = 0; i< 54; ++i) {
+        for (std::size_t i = 0; i< expectedOutput->size(); ++i) {
             REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001);
         }
-        // REQUIRE(*mySoftmax->getOperator()->getOutput(0) == *expectedOutput);
     }
 }
\ No newline at end of file
-- 
GitLab