Skip to content
Snippets Groups Projects
Commit 9c8c9733 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add ArgMax Op

parent 69c994fe
Branches
No related tags found
No related merge requests found
......@@ -13,6 +13,7 @@
#define AIDGE_CPU_IMPORTS_H_
#include "aidge/backend/cpu/operator/AddImpl.hpp"
#include "aidge/backend/cpu/operator/ArgMaxImpl.hpp"
#include "aidge/backend/cpu/operator/AvgPoolingImpl.hpp"
#include "aidge/backend/cpu/operator/MaxPoolingImpl.hpp"
#include "aidge/backend/cpu/operator/BatchNormImpl.hpp"
......
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_ARGMAXIMPL_H_
#define AIDGE_CPU_OPERATOR_ARGMAXIMPL_H_
#include <array>
#include <memory>
#include <tuple>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/ArgMax.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class ArgMaxImplForward_cpu
: public Registrable<ArgMaxImplForward_cpu,
std::tuple<DataType, DataType>,
void(std::int32_t,
DimSize_t,
const std::vector<DimSize_t>&,
const void *,
void *)> {};
class ArgMaxImplBackward_cpu
: public Registrable<ArgMaxImplBackward_cpu,
std::tuple<DataType, DataType>,
void(std::int32_t,
DimSize_t,
const std::vector<DimSize_t>&,
const void *,
void *)> {};
class ArgMaxImpl_cpu : public OperatorImpl {
public:
ArgMaxImpl_cpu(const ArgMax_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<ArgMaxImpl_cpu> create(const ArgMax_Op &op) {
return std::make_unique<ArgMaxImpl_cpu>(op);
}
public:
void forward() override;
};
namespace {
static Registrar<ArgMax_Op> registrarArgMaxImpl_cpu("cpu", Aidge::ArgMaxImpl_cpu::create);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_ARGMAXIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_ARGMAXIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_ARGMAXIMPL_FORWARD_KERNEL_H_
#include <algorithm> // std::for_each
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t
#include <functional> //std::multiplies
#include <numeric> //std::accumulate
#include <vector>
#include <limits>
#include "aidge/backend/cpu/operator/ArgMaxImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/operator/ArgMax.hpp"
#include "aidge/utils/Registrar.hpp"
namespace Aidge {
template <class I, class O>
void ArgMaxImpl_cpu_forward_kernel(std::int32_t axis_,
DimSize_t select_last_index,
const std::vector<DimSize_t>& inputDims,
const void* input_,
void* output_) {
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
const std::size_t axis = static_cast<std::size_t>(axis_);
const std::size_t nb_dims = inputDims.size();
auto stride_post = std::unique_ptr<std::size_t[]>(new std::size_t[nb_dims]);
stride_post[nb_dims - 1] = 1;
for (std::size_t i = nb_dims-2; i != static_cast<std::size_t>(-1); --i) {
stride_post[i] = stride_post[i+1]*inputDims[i+1];
}
auto stride_pre = std::unique_ptr<std::size_t[]>(new std::size_t[nb_dims]);
stride_pre[0] = 1;
for (std::size_t i = 1; i < nb_dims; ++i) {
stride_pre[i] = stride_pre[i-1]*inputDims[i-1];
}
const std::size_t dim_i = inputDims[axis];
for (std::size_t pre = 0; pre < stride_pre[axis]; ++pre) {
for (std::size_t post = 0; post < stride_post[axis]; ++post) {
const std::size_t idx_i = pre * dim_i * stride_post[axis] + post;
const std::size_t idx_o = pre * stride_post[axis] + post;
I max = std::numeric_limits<I>::min();
for (std::size_t i = 0; i < dim_i; ++i) {
if (select_last_index) {
if (input[idx_i]>=max)
{
output[idx_o] = i;
}
}
else {
if (input[idx_i] > max)
{
output[idx_o] = i;
}
}
}
}
}
}
namespace {
static Registrar<ArgMaxImplForward_cpu> registrarArgMaxImplForward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::ArgMaxImpl_cpu_forward_kernel<float, float>);
static Registrar<ArgMaxImplForward_cpu> registrarArgMaxImplForward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::ArgMaxImpl_cpu_forward_kernel<int, int>);
static Registrar<ArgMaxImplForward_cpu> registrarArgMaxImplForward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::ArgMaxImpl_cpu_forward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_ARGMAXIMPL_FORWARD_KERNEL_H_ */
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/backend/cpu/operator/ArgMaxImpl.hpp"
#include <memory>
#include <vector>
#include "aidge/utils/Types.h"
#include "aidge/operator/ArgMax.hpp"
#include "aidge/backend/cpu/operator/ArgMaxImpl_forward_kernels.hpp"
void Aidge::ArgMaxImpl_cpu::forward() {
const ArgMax_Op& op_ = dynamic_cast<const ArgMax_Op&>(mOp);
// Find the correct kernel type
auto kernelFunc = Registrar<ArgMaxImplForward_cpu>::create({
op_.getInput(0)->dataType(),
op_.getOutput(0)->dataType()});
// Call kernel
kernelFunc(op_.axis(),
op_.selectLastIndex(),
op_.getInput(0)->dims(),
op_.getInput(0)->getImpl()->rawPtr(),
op_.getOutput(0)->getImpl()->rawPtr());
}
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <memory>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/ArgMax.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/backend/cpu.hpp"
#include "aidge/utils/TensorUtils.hpp"
using namespace Aidge;
TEST_CASE("[cpu/operator] ArgMax(forward)", "[ArgMax][CPU]") {
SECTION("3D Tensor") {
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array3D<float,2,3,4> {
{
{
{ 1.0, 2.0, 3.0, 4.0},
{ 8.0, 0.0, 17.0, 1.0},
{ 5.0, 10.0, 6.0, 0.0}
},
{
{ 7.0, 1.0, 9.0, 4.0},
{ 0.0, 8.0, 4.0, 2.0},
{ 9.0, 2.0, 0.0, 5.0}
}
}
});
SECTION("Axis 2") {
Tensor myOutput = Tensor(Array2D<float,2,3> {
{
{ 3.0, 2.0, 1.0 },
{ 2.0, 1.0, 0.0}
}
});
std::shared_ptr<Node> myArgMax = ArgMax(2);
auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myArgMax->forward();
op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == myOutput);
}
SECTION("Axis 1") {
Tensor myOutput = Tensor(Array2D<float,2,4> {
{
{ 1.0, 2.0, 1.0, 0.0 },
{ 2.0, 1.0, 0.0, 2.0 }
}
});
std::shared_ptr<Node> myArgMax = ArgMax(1);
auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myArgMax->forward();
myOutput.print();
op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == myOutput);
}
SECTION("Axis 0") {
Tensor myOutput = Tensor(Array2D<float,3,4> {
{
{ 1.0, 0.0, 1.0, 0.0 },
{ 0.0, 1.0, 0.0, 1.0 },
{ 1.0, 0.0, 0.0, 1.0 }
},
});
std::shared_ptr<Node> myArgMax = ArgMax(1);
auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myArgMax->forward();
myOutput.print();
op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == myOutput);
}
}
SECTION("Select_Last_Index") {
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array1D<float,10> {
{
1.0, 5.0, 9.0, 0.0, 6.0, 2.0, 9.0, 4.0, 3.0, 9.0
}
});
std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array1D<float,1> {{9}});
std::shared_ptr<Node> myArgMax = ArgMax(0, 1, 1);
auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myArgMax->forward();
op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == *myOutput);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment