From 603195199ea6efb4746a378b44b3db36598b7675 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Wed, 31 Jul 2024 16:36:43 +0200
Subject: [PATCH] fix ArgMax kernel

---
 .../operator/ArgMaxImpl_forward_kernels.hpp   |  35 ++--
 unit_tests/operator/Test_ArgMaxImpl.cpp       | 149 ++++++++++++++++++
 2 files changed, 164 insertions(+), 20 deletions(-)
 create mode 100644 unit_tests/operator/Test_ArgMaxImpl.cpp

diff --git a/include/aidge/backend/cpu/operator/ArgMaxImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ArgMaxImpl_forward_kernels.hpp
index a03a0244..cea7d973 100644
--- a/include/aidge/backend/cpu/operator/ArgMaxImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ArgMaxImpl_forward_kernels.hpp
@@ -38,39 +38,34 @@ void ArgMaxImpl_cpu_forward_kernel(std::int32_t axis_,
 
     const std::size_t axis = static_cast<std::size_t>(axis_);
 
-    const std::size_t nb_dims = inputDims.size();
-
-    auto stride_post = std::unique_ptr<std::size_t[]>(new std::size_t[nb_dims]);
-    stride_post[nb_dims - 1] = 1;
-    for (std::size_t i = nb_dims-2; i != static_cast<std::size_t>(-1); --i) {
-        stride_post[i] = stride_post[i+1]*inputDims[i+1];
+    std::size_t stride_post = 1;
+    for (std::size_t i = axis + 1; i < inputDims.size(); ++i) {
+        stride_post *= inputDims[i];
     }
-    auto stride_pre = std::unique_ptr<std::size_t[]>(new std::size_t[nb_dims]);
-    stride_pre[0] = 1;
-    for (std::size_t i = 1; i < nb_dims; ++i) {
-        stride_pre[i] = stride_pre[i-1]*inputDims[i-1];
+    std::size_t stride_pre = 1;
+    for (std::size_t i = 0; i < axis; ++i) {
+        stride_pre *= inputDims[i];
     }
-
     const std::size_t dim_i = inputDims[axis];
-    for (std::size_t pre = 0; pre < stride_pre[axis]; ++pre) {
-        for (std::size_t post = 0; post < stride_post[axis]; ++post) {
-            const std::size_t idx_i = pre * dim_i * stride_post[axis] + post;
-            const std::size_t idx_o = pre * stride_post[axis] + post;
+    for (std::size_t pre = 0; pre < stride_pre; ++pre) {
+        for (std::size_t post = 0; post < stride_post; ++post) {
+            const std::size_t idx_i = pre * dim_i * stride_post + post;
+            const std::size_t idx_o = pre * stride_post + post;
             I max = std::numeric_limits<I>::min();
             for (std::size_t i = 0; i < dim_i; ++i) {
+                I curr_value = input[idx_i + i*stride_post];
                 if (select_last_index) {
-                    if (input[idx_i]>=max)
-                    {
+                    if (curr_value>=max) {
                         output[idx_o] = i;
+                        max = curr_value;
                     }
                 }
                 else {
-                    if (input[idx_i] > max)
-                    {
+                    if (curr_value > max) {
                         output[idx_o] = i;
+                        max = curr_value;
                     }
                 }
-                
             }
         }
     }
diff --git a/unit_tests/operator/Test_ArgMaxImpl.cpp b/unit_tests/operator/Test_ArgMaxImpl.cpp
new file mode 100644
index 00000000..c8873e48
--- /dev/null
+++ b/unit_tests/operator/Test_ArgMaxImpl.cpp
@@ -0,0 +1,149 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <catch2/catch_test_macros.hpp>
+#include <memory>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/ArgMax.hpp"
+#include "aidge/operator/Conv.hpp"
+
+#include "aidge/backend/cpu.hpp"
+#include "aidge/utils/TensorUtils.hpp"
+
+using namespace Aidge;
+
+TEST_CASE("[cpu/operator] ArgMax(forward)", "[ArgMax][CPU]") {
+    SECTION("3D Tensor") {
+            std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array3D<float,2,3,4> {
+                {
+                    {
+                        { 1.0, 2.0, 3.0, 4.0},
+                        { 8.0, 0.0, 17.0, 1.0},
+                        { 5.0, 10.0, 6.0, 0.0}
+                    },
+                    {
+                        { 7.0, 1.0, 9.0, 4.0},
+                        { 0.0, 8.0, 4.0, 2.0},
+                        { 9.0, 2.0, 0.0, 5.0}
+                    }
+                }
+            });
+        SECTION("Axis 2") {
+
+            Tensor myOutput = Tensor(Array3D<float,2,3, 1> {
+               { 
+                    { 
+                        {3.0},
+                        {2.0},
+                        {1.0}
+                    },
+                    {
+                        {2.0},
+                        {1.0},
+                        {0.0}
+                    }
+               }
+            });
+
+            std::shared_ptr<Node> myArgMax = ArgMax(2);
+            auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
+            op->associateInput(0,myInput);
+            op->setDataType(DataType::Float32);
+            op->setBackend("cpu");
+            myArgMax->forward();
+
+            REQUIRE(*(op->getOutput(0)) == myOutput);
+        }
+        SECTION("Axis 2 with keep_dims false") {
+
+            Tensor myOutput = Tensor(Array2D<float,2,3> {
+               { 
+                    { 3.0, 2.0, 1.0 },
+                    { 2.0, 1.0, 0.0 }
+               }
+            });
+
+            std::shared_ptr<Node> myArgMax = ArgMax(2,0);
+            auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
+            op->associateInput(0,myInput);
+            op->setDataType(DataType::Float32);
+            op->setBackend("cpu");
+            myArgMax->forward();
+
+            REQUIRE(*(op->getOutput(0)) == myOutput);
+        }
+        SECTION("Axis 1") {
+            Tensor myOutput = Tensor(Array3D<float,2,1,4> {
+                {
+                    {
+                        { 1.0, 2.0, 1.0, 0.0 }
+                    },
+                    {
+                        { 2.0, 1.0, 0.0, 2.0 }
+                    }
+                }
+            });
+
+            std::shared_ptr<Node> myArgMax = ArgMax(1);
+            auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
+            op->associateInput(0,myInput);
+            op->setDataType(DataType::Float32);
+            op->setBackend("cpu");
+            myArgMax->forward();
+
+            REQUIRE(*(op->getOutput(0)) == myOutput);
+        }
+        SECTION("Axis 0") {
+            Tensor myOutput = Tensor(Array3D<float,1,3,4> {
+                {
+                    {
+                        { 1.0, 0.0, 1.0, 0.0 },
+                        { 0.0, 1.0, 0.0, 1.0 },
+                        { 1.0, 0.0, 0.0, 1.0 }
+                    }
+                }
+            });
+
+            std::shared_ptr<Node> myArgMax = ArgMax(0);
+            auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
+            op->associateInput(0,myInput);
+            op->setDataType(DataType::Float32);
+            op->setBackend("cpu");
+            std::cout << " ...............  "<< std::endl;
+            myArgMax->forward();
+            op->getOutput(0)->print();
+            std::cout <<"------"<<std::endl;
+            myOutput.print();
+
+            REQUIRE(*(op->getOutput(0)) == myOutput);
+        }
+    }
+    SECTION("Select_Last_Index") {
+        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array1D<float,10> {
+            {
+                1.0, 5.0, 9.0, 0.0, 6.0, 2.0, 9.0, 4.0, 3.0, 9.0
+            }
+        });
+        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array1D<float,1> {{9}});
+
+        std::shared_ptr<Node> myArgMax = ArgMax(0, 1, 1);
+        auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
+        op->associateInput(0,myInput);
+        op->setDataType(DataType::Float32);
+        op->setBackend("cpu");
+        myArgMax->forward();
+        op->getOutput(0)->print();
+
+        REQUIRE(*(op->getOutput(0)) == *myOutput);
+
+    }
+}
\ No newline at end of file
-- 
GitLab