Commit 60319519 authored by Houssem ROUIS

fix ArgMax kernel

parent 9c8c9733
2 merge requests: !93 Release v0.3.0, !75 Learning backend cuda
@@ -38,39 +38,34 @@ void ArgMaxImpl_cpu_forward_kernel(std::int32_t axis_,
     const std::size_t axis = static_cast<std::size_t>(axis_);
-    const std::size_t nb_dims = inputDims.size();
-    auto stride_post = std::unique_ptr<std::size_t[]>(new std::size_t[nb_dims]);
-    stride_post[nb_dims - 1] = 1;
-    for (std::size_t i = nb_dims-2; i != static_cast<std::size_t>(-1); --i) {
-        stride_post[i] = stride_post[i+1]*inputDims[i+1];
+    std::size_t stride_post = 1;
+    for (std::size_t i = axis + 1; i < inputDims.size(); ++i) {
+        stride_post *= inputDims[i];
     }
-    auto stride_pre = std::unique_ptr<std::size_t[]>(new std::size_t[nb_dims]);
-    stride_pre[0] = 1;
-    for (std::size_t i = 1; i < nb_dims; ++i) {
-        stride_pre[i] = stride_pre[i-1]*inputDims[i-1];
+    std::size_t stride_pre = 1;
+    for (std::size_t i = 0; i < axis; ++i) {
+        stride_pre *= inputDims[i];
     }
     const std::size_t dim_i = inputDims[axis];
-    for (std::size_t pre = 0; pre < stride_pre[axis]; ++pre) {
-        for (std::size_t post = 0; post < stride_post[axis]; ++post) {
-            const std::size_t idx_i = pre * dim_i * stride_post[axis] + post;
-            const std::size_t idx_o = pre * stride_post[axis] + post;
+    for (std::size_t pre = 0; pre < stride_pre; ++pre) {
+        for (std::size_t post = 0; post < stride_post; ++post) {
+            const std::size_t idx_i = pre * dim_i * stride_post + post;
+            const std::size_t idx_o = pre * stride_post + post;
             I max = std::numeric_limits<I>::min();
             for (std::size_t i = 0; i < dim_i; ++i) {
+                I curr_value = input[idx_i + i*stride_post];
                 if (select_last_index) {
-                    if (input[idx_i]>=max)
-                    {
+                    if (curr_value >= max) {
                         output[idx_o] = i;
+                        max = curr_value;
                     }
                 }
                 else {
-                    if (input[idx_i] > max)
-                    {
+                    if (curr_value > max) {
                         output[idx_o] = i;
+                        max = curr_value;
                     }
                 }
             }
         }
     }
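For reference, a minimal standalone sketch of the stride logic this commit introduces (the argmax_axis helper and its signature are illustrative, not the Aidge kernel): the corrected inner loop actually steps along the reduced axis via idx_i + i*stride_post and keeps the running maximum up to date, which the previous version never did.

// Standalone sketch (illustrative names, not the Aidge API): stride-based ArgMax
// over one axis of a row-major buffer.
#include <cstddef>
#include <iostream>
#include <limits>
#include <vector>

std::vector<std::size_t> argmax_axis(const std::vector<float>& input,
                                     const std::vector<std::size_t>& dims,
                                     std::size_t axis,
                                     bool select_last_index = false) {
    // stride_post: elements between consecutive entries along `axis`;
    // stride_pre: number of outer slices preceding `axis`.
    std::size_t stride_post = 1;
    for (std::size_t i = axis + 1; i < dims.size(); ++i) stride_post *= dims[i];
    std::size_t stride_pre = 1;
    for (std::size_t i = 0; i < axis; ++i) stride_pre *= dims[i];

    const std::size_t dim_i = dims[axis];
    std::vector<std::size_t> output(stride_pre * stride_post, 0);

    for (std::size_t pre = 0; pre < stride_pre; ++pre) {
        for (std::size_t post = 0; post < stride_post; ++post) {
            const std::size_t idx_i = pre * dim_i * stride_post + post;
            const std::size_t idx_o = pre * stride_post + post;
            // lowest() (not min()) so negative inputs are handled in this sketch.
            float max = std::numeric_limits<float>::lowest();
            for (std::size_t i = 0; i < dim_i; ++i) {
                // Step along the reduced axis and keep the running maximum.
                const float curr = input[idx_i + i * stride_post];
                if (select_last_index ? (curr >= max) : (curr > max)) {
                    output[idx_o] = i;
                    max = curr;
                }
            }
        }
    }
    return output;
}

int main() {
    // 2x3 matrix; argmax over axis 1 is expected to print "2 0".
    const std::vector<float> x = {1.f, 2.f, 5.f,
                                  9.f, 0.f, 3.f};
    for (std::size_t idx : argmax_axis(x, {2, 3}, 1)) std::cout << idx << ' ';
    std::cout << '\n';
}

Compiled and run on its own, this sketch prints "2 0" for the 2x3 example, matching what the updated kernel computes along axis 1.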
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <memory>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/ArgMax.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/backend/cpu.hpp"
#include "aidge/utils/TensorUtils.hpp"
using namespace Aidge;
TEST_CASE("[cpu/operator] ArgMax(forward)", "[ArgMax][CPU]") {
    SECTION("3D Tensor") {
        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array3D<float,2,3,4> {
            {
                {
                    { 1.0, 2.0, 3.0, 4.0},
                    { 8.0, 0.0, 17.0, 1.0},
                    { 5.0, 10.0, 6.0, 0.0}
                },
                {
                    { 7.0, 1.0, 9.0, 4.0},
                    { 0.0, 8.0, 4.0, 2.0},
                    { 9.0, 2.0, 0.0, 5.0}
                }
            }
        });

        SECTION("Axis 2") {
            Tensor myOutput = Tensor(Array3D<float,2,3,1> {
                {
                    {
                        {3.0},
                        {2.0},
                        {1.0}
                    },
                    {
                        {2.0},
                        {1.0},
                        {0.0}
                    }
                }
            });

            std::shared_ptr<Node> myArgMax = ArgMax(2);
            auto op = std::static_pointer_cast<OperatorTensor>(myArgMax->getOperator());
            op->associateInput(0, myInput);
            op->setDataType(DataType::Float32);
            op->setBackend("cpu");
            myArgMax->forward();

            REQUIRE(*(op->getOutput(0)) == myOutput);
        }
        SECTION("Axis 2 with keep_dims false") {
            Tensor myOutput = Tensor(Array2D<float,2,3> {
                {
                    { 3.0, 2.0, 1.0 },
                    { 2.0, 1.0, 0.0 }
                }
            });

            std::shared_ptr<Node> myArgMax = ArgMax(2, 0);
            auto op = std::static_pointer_cast<OperatorTensor>(myArgMax->getOperator());
            op->associateInput(0, myInput);
            op->setDataType(DataType::Float32);
            op->setBackend("cpu");
            myArgMax->forward();

            REQUIRE(*(op->getOutput(0)) == myOutput);
        }
        SECTION("Axis 1") {
            Tensor myOutput = Tensor(Array3D<float,2,1,4> {
                {
                    {
                        { 1.0, 2.0, 1.0, 0.0 }
                    },
                    {
                        { 2.0, 1.0, 0.0, 2.0 }
                    }
                }
            });

            std::shared_ptr<Node> myArgMax = ArgMax(1);
            auto op = std::static_pointer_cast<OperatorTensor>(myArgMax->getOperator());
            op->associateInput(0, myInput);
            op->setDataType(DataType::Float32);
            op->setBackend("cpu");
            myArgMax->forward();

            REQUIRE(*(op->getOutput(0)) == myOutput);
        }
        SECTION("Axis 0") {
            Tensor myOutput = Tensor(Array3D<float,1,3,4> {
                {
                    {
                        { 1.0, 0.0, 1.0, 0.0 },
                        { 0.0, 1.0, 0.0, 1.0 },
                        { 1.0, 0.0, 0.0, 1.0 }
                    }
                }
            });

            std::shared_ptr<Node> myArgMax = ArgMax(0);
            auto op = std::static_pointer_cast<OperatorTensor>(myArgMax->getOperator());
            op->associateInput(0, myInput);
            op->setDataType(DataType::Float32);
            op->setBackend("cpu");
            std::cout << " ............... " << std::endl;
            myArgMax->forward();
            op->getOutput(0)->print();
            std::cout << "------" << std::endl;
            myOutput.print();

            REQUIRE(*(op->getOutput(0)) == myOutput);
        }
    }
    SECTION("Select_Last_Index") {
        std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array1D<float,10> {
            {
                1.0, 5.0, 9.0, 0.0, 6.0, 2.0, 9.0, 4.0, 3.0, 9.0
            }
        });
        std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array1D<float,1> {{9}});

        std::shared_ptr<Node> myArgMax = ArgMax(0, 1, 1);
        auto op = std::static_pointer_cast<OperatorTensor>(myArgMax->getOperator());
        op->associateInput(0, myInput);
        op->setDataType(DataType::Float32);
        op->setBackend("cpu");
        myArgMax->forward();
        op->getOutput(0)->print();

        REQUIRE(*(op->getOutput(0)) == *myOutput);
    }
}
\ No newline at end of file