diff --git a/include/aidge/backend/cpu/operator/GatherImpl.hpp b/include/aidge/backend/cpu/operator/GatherImpl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0f10c02a0fbfb5c08b190dff5eb238cf70db619c --- /dev/null +++ b/include/aidge/backend/cpu/operator/GatherImpl.hpp @@ -0,0 +1,50 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CPU_OPERATOR_GATHERIMPL_H_ +#define AIDGE_CPU_OPERATOR_GATHERIMPL_H_ + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Gather.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" +#include <memory> +#include <vector> + +namespace Aidge { +// class Gather_Op; + +// compute kernel registry for forward and backward +class GatherImplForward_cpu + : public Registrable<GatherImplForward_cpu, std::tuple<DataType, DataType>, void(const int, const std::vector<DimSize_t>, const std::vector<DimSize_t>, const void*, const void*, void*)> { +}; +class GatherImplBackward_cpu + : public Registrable<GatherImplBackward_cpu, std::tuple<DataType, DataType>, void(const int, const std::vector<DimSize_t>, const std::vector<DimSize_t>, const void*, const void*, void*)> { +}; + +class GatherImpl_cpu : public OperatorImpl { +public: + GatherImpl_cpu(const Gather_Op& op) : OperatorImpl(op) {} + + static std::unique_ptr<GatherImpl_cpu> create(const Gather_Op& op) { + return std::make_unique<GatherImpl_cpu>(op); + } + + NbElts_t getNbRequiredProtected(const IOIndex_t inputIdx) const override final; + void forward() override; +}; + +namespace { +static Registrar<Gather_Op> registrarGatherImpl_cpu("cpu", Aidge::GatherImpl_cpu::create); +} +} // namespace Aidge + +#endif /* AIDGE_CPU_OPERATOR_GATHERIMPL_H_ */ diff --git a/include/aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp new file mode 100644 index 0000000000000000000000000000000000000000..d8734d1ab47aa06ed8b6b48268d493d32b54c38a --- /dev/null +++ b/include/aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp @@ -0,0 +1,72 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CPU_OPERATOR_GATHERIMPL_FORWARD_KERNEL_H_ +#define AIDGE_CPU_OPERATOR_GATHERIMPL_FORWARD_KERNEL_H_ + +#include "aidge/utils/Registrar.hpp" +#include <cstddef> +#include <cmath> +#include "aidge/data/Data.hpp" +#include "aidge/utils/Types.h" + +#include "aidge/backend/cpu/operator/GatherImpl.hpp" + +namespace Aidge { +template <class I, class O> +void GatherImpl_cpu_forward_kernel(const int& axisIdx_, std::vector<DimSize_t> inputDims, const std::vector<DimSize_t> indicesDims, const void* input_, const void* indexes_, void* output_) +{ + const I* input = static_cast<const I*>(input_); + const int* indexes = static_cast<const int*>(indexes_); + const std::size_t axisIdx = axisIdx_; + O* output = static_cast<O*>(output_); + + // Calculate the total number of elements in the input array + size_t totalElements = 1; + for (size_t dimSize : inputDims) { + totalElements *= dimSize; + } + std::size_t nbElemAfterAxis = 1; + std::size_t nbElemBeforeAxis = 1; + + for (size_t d = 0; d < inputDims.size(); ++d) { + if( d < axisIdx ) + nbElemBeforeAxis *= inputDims[d]; + else if ( d > axisIdx ) + nbElemAfterAxis *= inputDims[d]; + } + + for (std::size_t i=0; i<nbElemBeforeAxis; ++i) + { + for(std::size_t idxRow=0; idxRow<indicesDims[0]; ++idxRow) + { + for(std::size_t idxCol=0; idxCol<indicesDims[1]; ++idxCol) + { + std::size_t idx = indexes[indicesDims[1] * idxRow + idxCol]; + const I* startPtr = std::next(input, i * nbElemAfterAxis * inputDims[axisIdx] + idx * nbElemAfterAxis); + std::copy_n(startPtr, nbElemAfterAxis, output); + output += nbElemAfterAxis; + } + } + } +} + +namespace { +static Registrar<GatherImplForward_cpu> registrarGatherImplForward_cpu_Float32( + {DataType::Float32, DataType::Float32}, Aidge::GatherImpl_cpu_forward_kernel<float, float>); +static Registrar<GatherImplForward_cpu> registrarGatherImplForward_cpu_Int32( + {DataType::Int32, DataType::Int32}, Aidge::GatherImpl_cpu_forward_kernel<int, int>); +static Registrar<GatherImplForward_cpu> registrarGatherImplForward_cpu_Float64( + {DataType::Float64, DataType::Float64}, Aidge::GatherImpl_cpu_forward_kernel<double, double>); +} // namespace +} // namespace Aidge + +#endif /* AIDGE_CPU_OPERATOR_GATHERIMPL_FORWARD_KERNEL_H_ */ diff --git a/src/operator/GatherImpl.cpp b/src/operator/GatherImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c5f4f1f37aea64b27913fea4af9885d01dc20f6b --- /dev/null +++ b/src/operator/GatherImpl.cpp @@ -0,0 +1,51 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <cassert> +#include <chrono> // std::chrono::milliseconds +#include <numeric> // std::accumulate +#include <thread> // std::this_thread::sleep_for +#include <vector> + +#include "aidge/operator/Gather.hpp" +#include "aidge/utils/Types.h" + +#include "aidge/backend/cpu/operator/GatherImpl.hpp" +#include "aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp" + +#include <iostream> + +Aidge::NbElts_t Aidge::GatherImpl_cpu::getNbRequiredProtected(const Aidge::IOIndex_t /*inputIdx*/) const { + // this implementation can be in-place + return 0; +} + +void Aidge::GatherImpl_cpu::forward() { + assert(mOp.getInput(0) && "missing input #0"); + assert(mOp.getInput(1) && "missing input #1"); + assert((mOp.getInput(0)->nbDims() == 2 && mOp.getInput(1)->nbDims() == 2 )&& "only 2D tensors are supported"); + + Gather_Op::Attrs attr = dynamic_cast<const Gather_Op&>(mOp).getStaticAttributes(); + const int& axisIdx = static_cast<const int&>(std::get<0>(attr)); + assert(mOp.getInput(0)->nbDims() > 1);// > axisIdx && "input dim must be bigger than "+std::to_strint(axisIdx) + + auto kernelFunc = Registrar<GatherImplForward_cpu>::create({ + mOp.getInput(0)->dataType(), + mOp.getOutput(0)->dataType()}); + + // Call kernel + kernelFunc(axisIdx, + mOp.getInput(0)->dims(), + mOp.getInput(1)->dims(), + mOp.getInput(0)->getImpl()->rawPtr(), + mOp.getInput(1)->getImpl()->rawPtr(), + mOp.getOutput(0)->getImpl()->rawPtr()); +} diff --git a/unit_tests/operator/Test_GatherImpl.cpp b/unit_tests/operator/Test_GatherImpl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c295bf7c612316c1382ea73ee2dadad10d42c2e7 --- /dev/null +++ b/unit_tests/operator/Test_GatherImpl.cpp @@ -0,0 +1,98 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/Gather.hpp" + +#include "aidge/backend/cpu.hpp" + +#include <memory> + + +using namespace Aidge; + +TEST_CASE("[cpu/operator] Gather(forward)") { + SECTION("2D Tensor axis 0") { + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<int,3,3> { + { + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9} + } + }); + std::shared_ptr<Tensor> indexes = std::make_shared<Tensor>(Array2D<int,1,2> { + { + {1, 2} + } + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,1,2,3> { + { + { + {4, 5, 6}, + {7, 8, 9} + } + } + }); + + std::shared_ptr<Node> myGather = Gather(0); + myGather->getOperator()->setDatatype(DataType::Int32); + myGather->getOperator()->setBackend("cpu"); + myGather->getOperator()->associateInput(0,input); + myGather->getOperator()->associateInput(1,indexes); + myGather->getOperator()->computeOutputDims(); + myGather->forward(); + + REQUIRE(myGather->getOperator()->output(0).dims() == expectedOutput->dims()); + REQUIRE(*(myGather->getOperator()->getOutput(0)) == *expectedOutput); + + } + SECTION("2D Tensor axis 1") { + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<int,3,3> { + { + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9} + } + }); + std::shared_ptr<Tensor> indexes = std::make_shared<Tensor>(Array2D<int,1,2> { + { + {0, 2} + } + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,3,1,2> { + { + { + {1, 3} + }, + { + {4, 6} + }, + { + {7, 9} + } + } + }); + + std::shared_ptr<Node> myGather = Gather(1); + myGather->getOperator()->setDatatype(DataType::Int32); + myGather->getOperator()->setBackend("cpu"); + myGather->getOperator()->associateInput(0,input); + myGather->getOperator()->associateInput(1,indexes); + myGather->getOperator()->computeOutputDims(); + myGather->forward(); + + REQUIRE(myGather->getOperator()->output(0).dims() == expectedOutput->dims()); + REQUIRE(*(myGather->getOperator()->getOutput(0)) == *expectedOutput); + + } +} \ No newline at end of file