diff --git a/include/aidge/backend/cpu/operator/GatherImpl.hpp b/include/aidge/backend/cpu/operator/GatherImpl.hpp
index d22e484e3a80f70753bbc083a5d89562774a3870..1d235ff14ca01955c268a7b061e6ecb7b2bbbb2a 100644
--- a/include/aidge/backend/cpu/operator/GatherImpl.hpp
+++ b/include/aidge/backend/cpu/operator/GatherImpl.hpp
@@ -24,10 +24,10 @@ namespace Aidge {
 
 // compute kernel registry for forward and backward
 class GatherImplForward_cpu
-    : public Registrable<GatherImplForward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const void*, const void*, void*)> {
+    : public Registrable<GatherImplForward_cpu, std::tuple<DataType, DataType>, void(const typename Gather_Op::Attrs&, const std::vector<DimSize_t>&, const void*, void*)> {
 };
 class GatherImplBackward_cpu
-    : public Registrable<GatherImplBackward_cpu, std::tuple<DataType, DataType>, void(std::size_t, const std::vector<DimSize_t>&, const std::vector<DimSize_t>&, const void*, const void*, void*)> {
+    : public Registrable<GatherImplBackward_cpu, std::tuple<DataType, DataType>, void(const typename Gather_Op::Attrs&, const std::vector<DimSize_t>&, const void*, void*)> {
 };
 
 class GatherImpl_cpu : public OperatorImpl {
diff --git a/include/aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp
index 31119e27f213a8698a8ff07f7690155d4753a05a..591985e88ef337f69d79463935ea0d2d258f49e3 100644
--- a/include/aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/GatherImpl_forward_kernels.hpp
@@ -22,12 +22,13 @@
 
 namespace Aidge {
 template <class I, class O>
-void GatherImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSize_t>& inputDims, const std::vector<DimSize_t>& indicesDims, const void* input_, const void* indexes_, void* output_)
+void GatherImpl_cpu_forward_kernel(const typename Gather_Op::Attrs& attrs, const std::vector<DimSize_t>& inputDims, const void* input_, void* output_)
 {
     const I* input = static_cast<const I*>(input_);
-    const int* indexes = static_cast<const int*>(indexes_);
     O* output = static_cast<O*>(output_);
 
+    std::size_t axisIdx = std::get<2>(attrs)>=0 ? std::get<2>(attrs) : static_cast<std::size_t>(std::get<2>(attrs)) + inputDims.size();
+
     std::size_t postAxisElems = 1;
     for (std::size_t i = axisIdx + 1; i < inputDims.size(); ++i) {
         postAxisElems *= inputDims[i];
@@ -37,17 +38,15 @@ void GatherImpl_cpu_forward_kernel(std::size_t axisIdx, const std::vector<DimSiz
         preAxisElems *= inputDims[i];
     }
 
+    std::vector<std::int64_t> indices = std::get<0>(attrs);
     for (std::size_t i=0; i<preAxisElems; ++i)
     {
-        for(std::size_t idxRow=0; idxRow<indicesDims[0]; ++idxRow)
+        for(std::size_t j=0; j<indices.size(); ++j)
         {
-            for(std::size_t idxCol=0; idxCol<indicesDims[1]; ++idxCol)
-            {
-                std::size_t idx = indexes[indicesDims[1] * idxRow + idxCol];
-                const I* startPtr = std::next(input, i * postAxisElems * inputDims[axisIdx] + idx * postAxisElems);
-                std::copy_n(startPtr, postAxisElems, output);
-                output += postAxisElems;
-            }
+            std::size_t idx = indices[j] >= 0 ? indices[j] : indices[j] + inputDims[axisIdx];
+            const I* startPtr = std::next(input, i * postAxisElems * inputDims[axisIdx] + idx * postAxisElems);
+            std::copy_n(startPtr, postAxisElems, output);
+            output += postAxisElems;
         }
     }
 }
diff --git a/src/operator/GatherImpl.cpp b/src/operator/GatherImpl.cpp
index fd5e755be6ba898d916243c7f8a67bc2c6baca1e..ce98627d95e0d05541db1ccaf4896abe756431b0 100644
--- a/src/operator/GatherImpl.cpp
+++ b/src/operator/GatherImpl.cpp
@@ -27,19 +27,14 @@ Aidge::NbElts_t Aidge::GatherImpl_cpu::getNbRequiredProtected(const Aidge::IOInd
 }
 
 void Aidge::GatherImpl_cpu::forward() {
-    Gather_Op::Attrs attr = dynamic_cast<const Gather_Op&>(mOp).getStaticAttributes();
-    const int& axisIdx = static_cast<const int&>(std::get<0>(attr));
-    assert(std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->nbDims() > 1);// > axisIdx && "input dim must be bigger than "+std::to_strint(axisIdx)
 
     auto kernelFunc = Registrar<GatherImplForward_cpu>::create({
        std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dataType(),
        std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->dataType()});
 
     // Call kernel
-    kernelFunc(axisIdx,
+    kernelFunc(dynamic_cast<const Gather_Op&>(mOp).getStaticAttributes(),
               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->dims(),
-              std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->dims(),
               std::static_pointer_cast<Tensor>(mOp.getRawInput(0))->getImpl()->rawPtr(),
-              std::static_pointer_cast<Tensor>(mOp.getRawInput(1))->getImpl()->rawPtr(),
               std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))->getImpl()->rawPtr());
 }
diff --git a/unit_tests/operator/Test_GatherImpl.cpp b/unit_tests/operator/Test_GatherImpl.cpp
index e0903aa77696467db3708cb26dfda4ce90a4344a..a8345917ab0a141065e86638c09b2689902679ec 100644
--- a/unit_tests/operator/Test_GatherImpl.cpp
+++ b/unit_tests/operator/Test_GatherImpl.cpp
@@ -44,14 +44,16 @@ TEST_CASE("[cpu/operator] Gather(forward)") {
         }
     });
 
-    std::shared_ptr<Node> myGather = Gather(0);
+    std::shared_ptr<Node> myGather = Gather({1, 2}, {1, 2}, 0);
     auto op = std::static_pointer_cast<OperatorTensor>(myGather -> getOperator());
     op->associateInput(0,input);
-    op->associateInput(1,indexes);
+    // op->associateInput(1,indexes);
     op->setDataType(DataType::Int32);
     op->setBackend("cpu");
     op->computeOutputDims();
     myGather->forward();
+    op->getOutput(0)->print();
+    expectedOutput->print();
 
     REQUIRE(*(op->getOutput(0)) == *expectedOutput);
 
@@ -83,10 +85,10 @@ TEST_CASE("[cpu/operator] Gather(forward)") {
         }
     });
 
-    std::shared_ptr<Node> myGather = Gather(1);
+    std::shared_ptr<Node> myGather = Gather({0, 2}, {1, 2}, 1);
     auto op = std::static_pointer_cast<OperatorTensor>(myGather -> getOperator());
     op->associateInput(0,input);
-    op->associateInput(1,indexes);
+    // op->associateInput(1,indexes);
     op->setDataType(DataType::Int32);
     op->setBackend("cpu");
     op->computeOutputDims();
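
Reviewer note: the sketch below is not part of the patch. It is a minimal standalone C++ reproduction of the indexing scheme the reworked kernel uses, namely that the axis and the gather indices now come from the operator's static attributes instead of a second input tensor, and that both may be negative, in which case they wrap around. All names here (gatherRef, the example data) are illustrative and assume the same row-major layout and wrap-around semantics as the added kernel code above.

// gather_sketch.cpp -- illustrative only; nothing here is Aidge API.
#include <cstddef>
#include <cstdint>
#include <iostream>
#include <vector>

// Gathers `indices` along `axis` of a row-major tensor with shape `dims`,
// mirroring the pre-axis / post-axis stride arithmetic of the patched kernel.
std::vector<int> gatherRef(const std::vector<int>& input,
                           const std::vector<std::size_t>& dims,
                           const std::vector<std::int64_t>& indices,
                           std::int64_t axis)
{
    // Negative axis wraps around, as in the added axisIdx computation.
    const std::size_t axisIdx = axis >= 0
        ? static_cast<std::size_t>(axis)
        : static_cast<std::size_t>(axis + static_cast<std::int64_t>(dims.size()));

    std::size_t postAxisElems = 1;
    for (std::size_t i = axisIdx + 1; i < dims.size(); ++i) { postAxisElems *= dims[i]; }
    std::size_t preAxisElems = 1;
    for (std::size_t i = 0; i < axisIdx; ++i) { preAxisElems *= dims[i]; }

    std::vector<int> output;
    output.reserve(preAxisElems * indices.size() * postAxisElems);
    for (std::size_t i = 0; i < preAxisElems; ++i) {
        for (std::int64_t rawIdx : indices) {
            // Negative indices also wrap around the gathered dimension.
            const std::size_t idx = rawIdx >= 0
                ? static_cast<std::size_t>(rawIdx)
                : static_cast<std::size_t>(rawIdx + static_cast<std::int64_t>(dims[axisIdx]));
            const std::size_t start = i * postAxisElems * dims[axisIdx] + idx * postAxisElems;
            output.insert(output.end(),
                          input.begin() + static_cast<std::ptrdiff_t>(start),
                          input.begin() + static_cast<std::ptrdiff_t>(start + postAxisElems));
        }
    }
    return output;
}

int main()
{
    // Gather rows {1, 2} of a 3x3 matrix along axis 0: keeps the last two rows.
    const std::vector<int> in = {1, 2, 3, 4, 5, 6, 7, 8, 9};
    for (int v : gatherRef(in, {3, 3}, {1, 2}, 0)) { std::cout << v << ' '; }
    std::cout << '\n';   // prints: 4 5 6 7 8 9
    return 0;
}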