Commit 9c8c9733 authored by Houssem ROUIS

add ArgMax Op

parent 69c994fe
2 merge requests: !93 Release v0.3.0, !75 Learning backend cuda
@@ -13,6 +13,7 @@
 #define AIDGE_CPU_IMPORTS_H_
 #include "aidge/backend/cpu/operator/AddImpl.hpp"
+#include "aidge/backend/cpu/operator/ArgMaxImpl.hpp"
 #include "aidge/backend/cpu/operator/AvgPoolingImpl.hpp"
 #include "aidge/backend/cpu/operator/MaxPoolingImpl.hpp"
 #include "aidge/backend/cpu/operator/BatchNormImpl.hpp"
 ...
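For readers unfamiliar with the operator, ArgMax returns the index of the maximum value along a given axis; the select_last_index attribute decides whether ties keep the first or the last occurrence. A minimal standalone sketch of that behaviour (plain C++, hypothetical helper name, not part of this commit):

#include <cstddef>
#include <iostream>
#include <vector>

// Illustration of ArgMax semantics on a flat vector (single-axis case).
// 'selectLastIndex' mirrors the operator's attribute: on ties, keep the
// last occurrence of the maximum instead of the first.
std::size_t argMax(const std::vector<float>& v, bool selectLastIndex = false) {
    std::size_t best = 0;
    for (std::size_t i = 1; i < v.size(); ++i) {
        const bool better = selectLastIndex ? (v[i] >= v[best]) : (v[i] > v[best]);
        if (better) { best = i; }
    }
    return best;
}

int main() {
    // Same data as the "Select_Last_Index" unit test further below.
    const std::vector<float> v{1.f, 5.f, 9.f, 0.f, 6.f, 2.f, 9.f, 4.f, 3.f, 9.f};
    std::cout << argMax(v) << '\n';        // 2: first maximum
    std::cout << argMax(v, true) << '\n';  // 9: last maximum
    return 0;
}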
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_ARGMAXIMPL_H_
#define AIDGE_CPU_OPERATOR_ARGMAXIMPL_H_
#include <array>
#include <memory>
#include <tuple>
#include <vector>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/ArgMax.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
namespace Aidge {
class ArgMaxImplForward_cpu
: public Registrable<ArgMaxImplForward_cpu,
std::tuple<DataType, DataType>,
void(std::int32_t,
DimSize_t,
const std::vector<DimSize_t>&,
const void *,
void *)> {};
class ArgMaxImplBackward_cpu
: public Registrable<ArgMaxImplBackward_cpu,
std::tuple<DataType, DataType>,
void(std::int32_t,
DimSize_t,
const std::vector<DimSize_t>&,
const void *,
void *)> {};
class ArgMaxImpl_cpu : public OperatorImpl {
public:
ArgMaxImpl_cpu(const ArgMax_Op& op) : OperatorImpl(op, "cpu") {}
static std::unique_ptr<ArgMaxImpl_cpu> create(const ArgMax_Op &op) {
return std::make_unique<ArgMaxImpl_cpu>(op);
}
public:
void forward() override;
};
namespace {
static Registrar<ArgMax_Op> registrarArgMaxImpl_cpu("cpu", Aidge::ArgMaxImpl_cpu::create);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_ARGMAXIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_ARGMAXIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_ARGMAXIMPL_FORWARD_KERNEL_H_
#include <cstddef> // std::size_t
#include <cstdint> // std::int32_t
#include <functional> //std::multiplies
#include <numeric> //std::accumulate
#include <vector>
#include <limits>
#include "aidge/backend/cpu/operator/ArgMaxImpl.hpp"
#include "aidge/data/Data.hpp"
#include "aidge/operator/ArgMax.hpp"
#include "aidge/utils/Registrar.hpp"
namespace Aidge {
template <class I, class O>
void ArgMaxImpl_cpu_forward_kernel(std::int32_t axis_,
DimSize_t select_last_index,
const std::vector<DimSize_t>& inputDims,
const void* input_,
void* output_) {
const I* input = static_cast<const I*>(input_);
O* output = static_cast<O*>(output_);
const std::size_t axis = static_cast<std::size_t>(axis_);
const std::size_t nb_dims = inputDims.size();
auto stride_post = std::unique_ptr<std::size_t[]>(new std::size_t[nb_dims]);
stride_post[nb_dims - 1] = 1;
for (std::size_t i = nb_dims-2; i != static_cast<std::size_t>(-1); --i) {
stride_post[i] = stride_post[i+1]*inputDims[i+1];
}
auto stride_pre = std::unique_ptr<std::size_t[]>(new std::size_t[nb_dims]);
stride_pre[0] = 1;
for (std::size_t i = 1; i < nb_dims; ++i) {
stride_pre[i] = stride_pre[i-1]*inputDims[i-1];
}
const std::size_t dim_i = inputDims[axis];
for (std::size_t pre = 0; pre < stride_pre[axis]; ++pre) {
for (std::size_t post = 0; post < stride_post[axis]; ++post) {
const std::size_t idx_i = pre * dim_i * stride_post[axis] + post;
const std::size_t idx_o = pre * stride_post[axis] + post;
                I max = std::numeric_limits<I>::lowest();  // lowest(), not min(): min() is the smallest positive value for floating types
                for (std::size_t i = 0; i < dim_i; ++i) {
                    const I curr_value = input[idx_i + i * stride_post[axis]];
                    if (select_last_index) {
                        // keep the last occurrence of the maximum on ties
                        if (curr_value >= max) {
                            output[idx_o] = i;
                            max = curr_value;
                        }
                    }
                    else {
                        // keep the first occurrence of the maximum on ties
                        if (curr_value > max) {
                            output[idx_o] = i;
                            max = curr_value;
                        }
                    }
                }
}
}
}
namespace {
static Registrar<ArgMaxImplForward_cpu> registrarArgMaxImplForward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::ArgMaxImpl_cpu_forward_kernel<float, float>);
static Registrar<ArgMaxImplForward_cpu> registrarArgMaxImplForward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::ArgMaxImpl_cpu_forward_kernel<int, int>);
static Registrar<ArgMaxImplForward_cpu> registrarArgMaxImplForward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::ArgMaxImpl_cpu_forward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_ARGMAXIMPL_FORWARD_KERNEL_H_ */
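The kernel above walks the tensor as stride_pre[axis] outer blocks and stride_post[axis] inner elements around the reduced axis, so the element visited at reduction step i sits at idx_i + i * stride_post[axis]. A standalone sketch of the same indexing scheme (not part of the commit) on the 2x3x4 shape used in the tests, reducing over axis 1:

#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    // Same pre/post decomposition as the kernel, on a 2x3x4 tensor reduced over axis 1.
    const std::vector<std::size_t> dims{2, 3, 4};
    const std::size_t axis = 1;

    std::size_t stridePre = 1;   // product of dims before the axis -> 2
    for (std::size_t d = 0; d < axis; ++d) { stridePre *= dims[d]; }
    std::size_t stridePost = 1;  // product of dims after the axis -> 4
    for (std::size_t d = axis + 1; d < dims.size(); ++d) { stridePost *= dims[d]; }

    // Fill the input with 0..23 so that, along axis 1, the last slice always wins.
    std::vector<float> input(2 * 3 * 4);
    for (std::size_t i = 0; i < input.size(); ++i) { input[i] = static_cast<float>(i); }

    std::vector<std::size_t> output(stridePre * stridePost);
    for (std::size_t pre = 0; pre < stridePre; ++pre) {
        for (std::size_t post = 0; post < stridePost; ++post) {
            const std::size_t idxIn  = pre * dims[axis] * stridePost + post;
            const std::size_t idxOut = pre * stridePost + post;
            float max = input[idxIn];
            output[idxOut] = 0;
            for (std::size_t i = 1; i < dims[axis]; ++i) {
                const float v = input[idxIn + i * stridePost];
                if (v > max) { max = v; output[idxOut] = i; }
            }
        }
    }

    for (std::size_t o : output) { std::cout << o << ' '; }  // prints "2" eight times
    std::cout << '\n';
    return 0;
}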
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/backend/cpu/operator/ArgMaxImpl.hpp"
#include <memory>
#include <vector>
#include "aidge/utils/Types.h"
#include "aidge/operator/ArgMax.hpp"
#include "aidge/backend/cpu/operator/ArgMaxImpl_forward_kernels.hpp"
void Aidge::ArgMaxImpl_cpu::forward() {
const ArgMax_Op& op_ = dynamic_cast<const ArgMax_Op&>(mOp);
// Find the correct kernel type
auto kernelFunc = Registrar<ArgMaxImplForward_cpu>::create({
op_.getInput(0)->dataType(),
op_.getOutput(0)->dataType()});
// Call kernel
kernelFunc(op_.axis(),
op_.selectLastIndex(),
op_.getInput(0)->dims(),
op_.getInput(0)->getImpl()->rawPtr(),
op_.getOutput(0)->getImpl()->rawPtr());
}
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <memory>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/ArgMax.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/backend/cpu.hpp"
#include "aidge/utils/TensorUtils.hpp"
using namespace Aidge;
TEST_CASE("[cpu/operator] ArgMax(forward)", "[ArgMax][CPU]") {
SECTION("3D Tensor") {
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array3D<float,2,3,4> {
{
{
{ 1.0, 2.0, 3.0, 4.0},
{ 8.0, 0.0, 17.0, 1.0},
{ 5.0, 10.0, 6.0, 0.0}
},
{
{ 7.0, 1.0, 9.0, 4.0},
{ 0.0, 8.0, 4.0, 2.0},
{ 9.0, 2.0, 0.0, 5.0}
}
}
});
SECTION("Axis 2") {
Tensor myOutput = Tensor(Array2D<float,2,3> {
{
{ 3.0, 2.0, 1.0 },
{ 2.0, 1.0, 0.0}
}
});
std::shared_ptr<Node> myArgMax = ArgMax(2);
auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myArgMax->forward();
op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == myOutput);
}
SECTION("Axis 1") {
Tensor myOutput = Tensor(Array2D<float,2,4> {
{
{ 1.0, 2.0, 1.0, 0.0 },
{ 2.0, 1.0, 0.0, 2.0 }
}
});
std::shared_ptr<Node> myArgMax = ArgMax(1);
auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myArgMax->forward();
myOutput.print();
op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == myOutput);
}
SECTION("Axis 0") {
Tensor myOutput = Tensor(Array2D<float,3,4> {
{
{ 1.0, 0.0, 1.0, 0.0 },
{ 0.0, 1.0, 0.0, 1.0 },
{ 1.0, 0.0, 0.0, 1.0 }
},
});
        std::shared_ptr<Node> myArgMax = ArgMax(0);
auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myArgMax->forward();
myOutput.print();
op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == myOutput);
}
}
SECTION("Select_Last_Index") {
std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array1D<float,10> {
{
1.0, 5.0, 9.0, 0.0, 6.0, 2.0, 9.0, 4.0, 3.0, 9.0
}
});
std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array1D<float,1> {{9}});
std::shared_ptr<Node> myArgMax = ArgMax(0, 1, 1);
auto op = std::static_pointer_cast<OperatorTensor>(myArgMax -> getOperator());
op->associateInput(0,myInput);
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myArgMax->forward();
op->getOutput(0)->print();
REQUIRE(*(op->getOutput(0)) == *myOutput);
}
}