From 603195199ea6efb4746a378b44b3db36598b7675 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Wed, 31 Jul 2024 16:36:43 +0200
Subject: [PATCH] fix ArgMax kernel

The ArgMax CPU kernel compared the same input element on every step of
the reduced axis and never updated the running maximum, so it returned
wrong indices. Rework it to compute scalar pre-/post-axis strides, read
input[idx_i + i*stride_post] at each step and track the maximum, and
add unit tests for the CPU ArgMax implementation.
---
 .../operator/ArgMaxImpl_forward_kernels.hpp |  33 ++--
 unit_tests/operator/Test_ArgMaxImpl.cpp     | 149 ++++++++++++++++++
 2 files changed, 164 insertions(+), 18 deletions(-)
 create mode 100644 unit_tests/operator/Test_ArgMaxImpl.cpp

diff --git a/include/aidge/backend/cpu/operator/ArgMaxImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ArgMaxImpl_forward_kernels.hpp
index a03a0244..cea7d973 100644
--- a/include/aidge/backend/cpu/operator/ArgMaxImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ArgMaxImpl_forward_kernels.hpp
@@ -38,37 +38,34 @@ void ArgMaxImpl_cpu_forward_kernel(std::int32_t axis_,
 
     const std::size_t axis = static_cast<std::size_t>(axis_);
 
-    const std::size_t nb_dims = inputDims.size();
-
-    auto stride_post = std::unique_ptr<std::size_t[]>(new std::size_t[nb_dims]);
-    stride_post[nb_dims - 1] = 1;
-    for (std::size_t i = nb_dims-2; i != static_cast<std::size_t>(-1); --i) {
-        stride_post[i] = stride_post[i+1]*inputDims[i+1];
+    std::size_t stride_post = 1;
+    for (std::size_t i = axis + 1; i < inputDims.size(); ++i) {
+        stride_post *= inputDims[i];
     }
-    auto stride_pre = std::unique_ptr<std::size_t[]>(new std::size_t[nb_dims]);
-    stride_pre[0] = 1;
-    for (std::size_t i = 1; i < nb_dims; ++i) {
-        stride_pre[i] = stride_pre[i-1]*inputDims[i-1];
+    std::size_t stride_pre = 1;
+    for (std::size_t i = 0; i < axis; ++i) {
+        stride_pre *= inputDims[i];
     }
     const std::size_t dim_i = inputDims[axis];
-    for (std::size_t pre = 0; pre < stride_pre[axis]; ++pre) {
-        for (std::size_t post = 0; post < stride_post[axis]; ++post) {
-            const std::size_t idx_i = pre * dim_i * stride_post[axis] + post;
-            const std::size_t idx_o = pre * stride_post[axis] + post;
+    for (std::size_t pre = 0; pre < stride_pre; ++pre) {
+        for (std::size_t post = 0; post < stride_post; ++post) {
+            const std::size_t idx_i = pre * dim_i * stride_post + post;
+            const std::size_t idx_o = pre * stride_post + post;
             I max = std::numeric_limits<I>::min();
             for (std::size_t i = 0; i < dim_i; ++i) {
+                I curr_value = input[idx_i + i*stride_post];
                 if (select_last_index) {
-                    if (input[idx_i]>=max)
-                    {
+                    if (curr_value>=max) {
                         output[idx_o] = i;
+                        max = curr_value;
                     }
                 } else {
-                    if (input[idx_i] > max)
-                    {
+                    if (curr_value > max) {
                         output[idx_o] = i;
+                        max = curr_value;
                     }
                 }
             }
         }
     }
 }
diff --git a/unit_tests/operator/Test_ArgMaxImpl.cpp b/unit_tests/operator/Test_ArgMaxImpl.cpp
new file mode 100644
index 00000000..c8873e48
--- /dev/null
+++ b/unit_tests/operator/Test_ArgMaxImpl.cpp
@@ -0,0 +1,149 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <memory> + +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/ArgMax.hpp" +#include "aidge/operator/Conv.hpp" + +#include "aidge/backend/cpu.hpp" +#include "aidge/utils/TensorUtils.hpp" + +using namespace Aidge; + +TEST_CASE("[cpu/operator] ArgMax(forward)", "[ArgMax][CPU]") { + SECTION("3D Tensor") { + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array3D<float,2,3,4> { + { + { + { 1.0, 2.0, 3.0, 4.0}, + { 8.0, 0.0, 17.0, 1.0}, + { 5.0, 10.0, 6.0, 0.0} + }, + { + { 7.0, 1.0, 9.0, 4.0}, + { 0.0, 8.0, 4.0, 2.0}, + { 9.0, 2.0, 0.0, 5.0} + } + } + }); + SECTION("Axis 2") { + + Tensor myOutput = Tensor(Array3D<float,2,3, 1> { + { + { + {3.0}, + {2.0}, + {1.0} + }, + { + {2.0}, + {1.0}, + {0.0} + } + } + }); + + std::shared_ptr<Node> myArgMax = ArgMax(2); + auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator()); + op->associateInput(0,myInput); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + myArgMax->forward(); + + REQUIRE(*(op->getOutput(0)) == myOutput); + } + SECTION("Axis 2 with keep_dims false") { + + Tensor myOutput = Tensor(Array2D<float,2,3> { + { + { 3.0, 2.0, 1.0 }, + { 2.0, 1.0, 0.0 } + } + }); + + std::shared_ptr<Node> myArgMax = ArgMax(2,0); + auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator()); + op->associateInput(0,myInput); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + myArgMax->forward(); + + REQUIRE(*(op->getOutput(0)) == myOutput); + } + SECTION("Axis 1") { + Tensor myOutput = Tensor(Array3D<float,2,1,4> { + { + { + { 1.0, 2.0, 1.0, 0.0 } + }, + { + { 2.0, 1.0, 0.0, 2.0 } + } + } + }); + + std::shared_ptr<Node> myArgMax = ArgMax(1); + auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator()); + op->associateInput(0,myInput); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + myArgMax->forward(); + + REQUIRE(*(op->getOutput(0)) == myOutput); + } + SECTION("Axis 0") { + Tensor myOutput = Tensor(Array3D<float,1,3,4> { + { + { + { 1.0, 0.0, 1.0, 0.0 }, + { 0.0, 1.0, 0.0, 1.0 }, + { 1.0, 0.0, 0.0, 1.0 } + } + } + }); + + std::shared_ptr<Node> myArgMax = ArgMax(0); + auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator()); + op->associateInput(0,myInput); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + std::cout << " ............... "<< std::endl; + myArgMax->forward(); + op->getOutput(0)->print(); + std::cout <<"------"<<std::endl; + myOutput.print(); + + REQUIRE(*(op->getOutput(0)) == myOutput); + } + } + SECTION("Select_Last_Index") { + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array1D<float,10> { + { + 1.0, 5.0, 9.0, 0.0, 6.0, 2.0, 9.0, 4.0, 3.0, 9.0 + } + }); + std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array1D<float,1> {{9}}); + + std::shared_ptr<Node> myArgMax = ArgMax(0, 1, 1); + auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator()); + op->associateInput(0,myInput); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + myArgMax->forward(); + op->getOutput(0)->print(); + + REQUIRE(*(op->getOutput(0)) == *myOutput); + + } +} \ No newline at end of file -- GitLab
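
For reference, below is a minimal standalone sketch of the pre-/post-stride traversal that the patched kernel implements (illustrative only, not Aidge code; it seeds the running maximum from the first element of each slice rather than std::numeric_limits, but the indexing is the same). Built with any C++11 compiler, it reproduces the "Axis 1" expectations of the unit test above.

#include <cstddef>
#include <cstdio>
#include <vector>

// Returns, for every (pre, post) slice, the index of the maximum along `axis`.
std::vector<std::size_t> argmax(const std::vector<float>& input,
                                const std::vector<std::size_t>& dims,
                                std::size_t axis) {
    std::size_t stride_post = 1;                 // product of the dims after `axis`
    for (std::size_t i = axis + 1; i < dims.size(); ++i) stride_post *= dims[i];
    std::size_t stride_pre = 1;                  // product of the dims before `axis`
    for (std::size_t i = 0; i < axis; ++i) stride_pre *= dims[i];
    const std::size_t dim_i = dims[axis];        // length of the reduced axis

    std::vector<std::size_t> out(stride_pre * stride_post);
    for (std::size_t pre = 0; pre < stride_pre; ++pre) {
        for (std::size_t post = 0; post < stride_post; ++post) {
            const std::size_t idx_i = pre * dim_i * stride_post + post; // first element of the slice
            const std::size_t idx_o = pre * stride_post + post;         // flat output position
            float max = input[idx_i];
            std::size_t arg = 0;
            for (std::size_t i = 1; i < dim_i; ++i) {
                const float v = input[idx_i + i * stride_post];         // step along the reduced axis
                if (v > max) { max = v; arg = i; }
            }
            out[idx_o] = arg;
        }
    }
    return out;
}

int main() {
    // Same 2x3x4 input as the "Axis 1" section of Test_ArgMaxImpl.cpp.
    const std::vector<float> x = { 1, 2, 3, 4,  8, 0, 17, 1,  5, 10, 6, 0,
                                   7, 1, 9, 4,  0, 8,  4, 2,  9,  2, 0, 5 };
    for (std::size_t idx : argmax(x, {2, 3, 4}, 1)) std::printf("%zu ", idx);
    std::printf("\n");                           // prints: 1 2 1 0 2 1 0 2
    return 0;
}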