Skip to content
Snippets Groups Projects
Commit 0a694579 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add Gather operator

parent 10cfb6dd
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!20Vit operators
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_GATHERIMPL_H_
#define AIDGE_CPU_OPERATOR_GATHERIMPL_H_
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/Gather.hpp"
#include "aidge/utils/Registrar.hpp"
#include "aidge/utils/Types.h"
#include <memory>
#include <vector>
namespace Aidge {
// class Gather_Op;
// compute kernel registry for forward and backward
class GatherImplForward_cpu
: public Registrable<GatherImplForward_cpu, std::tuple<DataType, DataType>, void(const int, const std::vector<DimSize_t>, const std::vector<DimSize_t>, const void*, const void*, void*)> {
};
class GatherImplBackward_cpu
: public Registrable<GatherImplBackward_cpu, std::tuple<DataType, DataType>, void(const int, const std::vector<DimSize_t>, const std::vector<DimSize_t>, const void*, const void*, void*)> {
};
class GatherImpl_cpu : public OperatorImpl {
public:
GatherImpl_cpu(const Gather_Op& op) : OperatorImpl(op) {}
static std::unique_ptr<GatherImpl_cpu> create(const Gather_Op& op) {
return std::make_unique<GatherImpl_cpu>(op);
}
NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final;
void forward() override;
};
namespace {
static Registrar<Gather_Op> registrarGatherImpl_cpu("cpu", Aidge::GatherImpl_cpu::create);
}
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_GATHERIMPL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CPU_OPERATOR_GATHERIMPL_FORWARD_KERNEL_H_
#define AIDGE_CPU_OPERATOR_GATHERIMPL_FORWARD_KERNEL_H_
#include "aidge/utils/Registrar.hpp"
#include <cstddef>
#include <cmath>
#include "aidge/data/Data.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/operator/GatherImpl.hpp"
namespace Aidge {
template <class I, class O>
void GatherImpl_cpu_forward_kernel(const int& axisIdx_, std::vector<DimSize_t> inputDims, const std::vector<DimSize_t> indicesDims, const void* input_, const void* indexes_, void* output_)
{
const I* input = static_cast<const I*>(input_);
const int* indexes = static_cast<const int*>(indexes_);
const std::size_t axisIdx = axisIdx_;
O* output = static_cast<O*>(output_);
// Calculate the total number of elements in the input array
size_t totalElements = 1;
for (size_t dimSize : inputDims) {
totalElements *= dimSize;
}
std::size_t nbElemAfterAxis = 1;
std::size_t nbElemBeforeAxis = 1;
for (size_t d = 0; d < inputDims.size(); ++d) {
if( d < axisIdx )
nbElemBeforeAxis *= inputDims[d];
else if ( d > axisIdx )
nbElemAfterAxis *= inputDims[d];
}
for (std::size_t i=0; i<nbElemBeforeAxis; ++i)
{
for(std::size_t idxRow=0; idxRow<indicesDims[0]; ++idxRow)
{
for(std::size_t idxCol=0; idxCol<indicesDims[1]; ++idxCol)
{
std::size_t idx = indexes[indicesDims[1] * idxRow + idxCol];
const I* startPtr = std::next(input, i * nbElemAfterAxis * inputDims[axisIdx] + idx * nbElemAfterAxis);
std::copy_n(startPtr, nbElemAfterAxis, output);
output += nbElemAfterAxis;
}
}
}
}
namespace {
static Registrar<GatherImplForward_cpu> registrarGatherImplForward_cpu_Float32(
{DataType::Float32, DataType::Float32}, Aidge::GatherImpl_cpu_forward_kernel<float, float>);
static Registrar<GatherImplForward_cpu> registrarGatherImplForward_cpu_Int32(
{DataType::Int32, DataType::Int32}, Aidge::GatherImpl_cpu_forward_kernel<int, int>);
static Registrar<GatherImplForward_cpu> registrarGatherImplForward_cpu_Float64(
{DataType::Float64, DataType::Float64}, Aidge::GatherImpl_cpu_forward_kernel<double, double>);
} // namespace
} // namespace Aidge
#endif /* AIDGE_CPU_OPERATOR_GATHERIMPL_FORWARD_KERNEL_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <chrono> // std::chrono::milliseconds
#include <numeric> // std::accumulate
#include <thread> // std::this_thread::sleep_for
#include <vector>
#include "aidge/operator/Gather.hpp"
#include "aidge/utils/Types.h"
#include "aidge/backend/cpu/operator/GatherImpl.hpp"
#include "aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp"
#include <iostream>
Aidge::NbElts_t Aidge::GatherImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const {
// this implementation can be in-place
return 0;
}
void Aidge::GatherImpl_cpu::forward() {
assert(mOp.getInput(0) && "missing input #0");
assert(mOp.getInput(1) && "missing input #1");
assert((mOp.getInput(0)->nbDims() == 2 && mOp.getInput(1)->nbDims() == 2 )&& "only 2D tensors are supported");
Gather_Op::Attrs attr = dynamic_cast<const Gather_Op&>(mOp).getStaticAttributes();
const int& axisIdx = static_cast<const int&>(std::get<0>(attr));
assert(mOp.getInput(0)->nbDims() > 1);// > axisIdx && "input dim must be bigger than "+std::to_strint(axisIdx)
auto kernelFunc = Registrar<GatherImplForward_cpu>::create({
mOp.getInput(0)->dataType(),
mOp.getOutput(0)->dataType()});
// Call kernel
kernelFunc(axisIdx,
mOp.getInput(0)->dims(),
mOp.getInput(1)->dims(),
mOp.getInput(0)->getImpl()->rawPtr(),
mOp.getInput(1)->getImpl()->rawPtr(),
mOp.getOutput(0)->getImpl()->rawPtr());
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Gather.hpp"
#include "aidge/backend/cpu.hpp"
#include <memory>
using namespace Aidge;
TEST_CASE("[cpu/operator] Gather(forward)") {
SECTION("2D Tensor axis 0") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<int,3,3> {
{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9}
}
});
std::shared_ptr<Tensor> indexes = std::make_shared<Tensor>(Array2D<int,1,2> {
{
{1, 2}
}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,1,2,3> {
{
{
{4, 5, 6},
{7, 8, 9}
}
}
});
std::shared_ptr<Node> myGather = Gather(0);
myGather->getOperator()->setDatatype(DataType::Int32);
myGather->getOperator()->setBackend("cpu");
myGather->getOperator()->associateInput(0,input);
myGather->getOperator()->associateInput(1,indexes);
myGather->getOperator()->computeOutputDims();
myGather->forward();
REQUIRE(myGather->getOperator()->output(0).dims() == expectedOutput->dims());
REQUIRE(*(myGather->getOperator()->getOutput(0)) == *expectedOutput);
}
SECTION("2D Tensor axis 1") {
std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<int,3,3> {
{
{1, 2, 3},
{4, 5, 6},
{7, 8, 9}
}
});
std::shared_ptr<Tensor> indexes = std::make_shared<Tensor>(Array2D<int,1,2> {
{
{0, 2}
}
});
std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,3,1,2> {
{
{
{1, 3}
},
{
{4, 6}
},
{
{7, 9}
}
}
});
std::shared_ptr<Node> myGather = Gather(1);
myGather->getOperator()->setDatatype(DataType::Int32);
myGather->getOperator()->setBackend("cpu");
myGather->getOperator()->associateInput(0,input);
myGather->getOperator()->associateInput(1,indexes);
myGather->getOperator()->computeOutputDims();
myGather->forward();
REQUIRE(myGather->getOperator()->output(0).dims() == expectedOutput->dims());
REQUIRE(*(myGather->getOperator()->getOutput(0)) == *expectedOutput);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment