diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index 5a067f2e94dbfe9826149037036d62686220049c..6680f2e1d6de5157024f9e7ca65b14256e53eae2 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -31,24 +31,26 @@ public: void forward() override; }; -enum class GatherAttr { Axis }; +enum class GatherAttr { Axis, Indices, GatheredShape }; class Gather_Op : public OperatorTensor, public Registrable<Gather_Op, std::string, std::shared_ptr<OperatorImpl>(const Gather_Op&)>, - public StaticAttributes<GatherAttr, std::int8_t> { + public StaticAttributes<GatherAttr, std::int8_t, std::vector<int64_t>, std::vector<DimSize_t>> { public: static const std::string Type; Gather_Op() = delete; - using Attributes_ = StaticAttributes<GatherAttr, std::int8_t>; + using Attributes_ = StaticAttributes<GatherAttr, std::int8_t, std::vector<int64_t>, std::vector<DimSize_t>>; template <GatherAttr e> using attr = typename Attributes_::template attr<e>; - Gather_Op(std::int8_t axis) + Gather_Op(std::int8_t axis, const std::vector<int64_t>& indices, const std::vector<DimSize_t>& gatheredShape) : OperatorTensor(Type, 2, 0, 1), - Attributes_(attr<GatherAttr::Axis>(axis)) + Attributes_(attr<GatherAttr::Axis>(axis), + attr<GatherAttr::Indices>(indices), + attr<GatherAttr::GatheredShape>(gatheredShape)) { mImpl = std::make_shared<Gather_OpImpl>(*this); } @@ -89,14 +91,14 @@ public: } }; -inline std::shared_ptr<Node> Gather(std::int8_t axis = 0, const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Gather_Op>(axis), name); +inline std::shared_ptr<Node> Gather(std::int8_t axis = 0, const std::vector<int64_t>& indices = {}, const std::vector<DimSize_t>& gatheredShape = {}, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Gather_Op>(axis, indices, gatheredShape), name); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Axis"}; +const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Axis", "Indices", "GatheredShape"}; } #endif /* AIDGE_CORE_OPERATOR_GATHER_H_ */ diff --git a/python_binding/operator/pybind_Gather.cpp b/python_binding/operator/pybind_Gather.cpp index 56f589dbed9dfa04471d43d7021e0ca540344ce4..e5507e670c1ec0bf4758169a9ea9864ff3fe29be 100644 --- a/python_binding/operator/pybind_Gather.cpp +++ b/python_binding/operator/pybind_Gather.cpp @@ -11,6 +11,7 @@ #include <pybind11/pybind11.h> #include <string> +#include <vector> #include "aidge/data/Tensor.hpp" #include "aidge/operator/Gather.hpp" @@ -26,6 +27,6 @@ void init_Gather(py::module& m) { .def("attributes_name", &Gather_Op::staticGetAttrsName); declare_registrable<Gather_Op>(m, "GatherOp"); - m.def("Gather", &Gather, py::arg("axis") = 0, py::arg("name") = ""); + m.def("Gather", &Gather, py::arg("axis") = 0, py::arg("indices") = std::vector<std::int64_t>(), py::arg("gathered_shape") = std::vector<std::size_t>(), py::arg("name") = ""); } } // namespace Aidge diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index 8b931667adf0a21a025aaf6ffb023a32c89b4970..4e5bd2573a0e1b0cc78256a68dad88332877067b 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -35,15 +35,14 @@ void Aidge::Gather_OpImpl::forward() { preAxisElems *= op.getInput(0)->dims()[i]; } - const auto indices = static_cast<const int*>(op.getInput(1)->getImpl()->rawPtr()); std::size_t outputOffset = 0; for (std::size_t i=0; i<preAxisElems; ++i) { - for(std::size_t j=0; j<op.getInput(1)->size(); ++j) + for(std::size_t j=0; j<op.template getAttr<std::vector<int64_t>>("Indices").size(); ++j) { - const std::size_t idx = indices[j] >= 0 ? - static_cast<std::size_t>(indices[j]) : - static_cast<std::size_t>(indices[j] + static_cast<int>(op.getInput(0)->dims()[axisIdx])); + const std::size_t idx = op.template getAttr<std::vector<int64_t>>("Indices")[j] >= 0 ? + static_cast<std::size_t>(op.template getAttr<std::vector<int64_t>>("Indices")[j]) : + static_cast<std::size_t>(op.template getAttr<std::vector<int64_t>>("Indices")[j] + static_cast<int>(op.getInput(0)->dims()[axisIdx])); op.getOutput(0)->getImpl()->copy(op.getInput(0)->getImpl()->rawPtr(i * postAxisElems * op.getInput(0)->dims()[axisIdx] + idx * postAxisElems), postAxisElems, outputOffset); outputOffset += postAxisElems; } @@ -53,25 +52,57 @@ void Aidge::Gather_OpImpl::forward() { const std::string Aidge::Gather_Op::Type = "Gather"; bool Aidge::Gather_Op::forwardDims(bool /*allowDataDependency*/) { - // check inputs have been associated - for(int i=0; i<2; ++i){ - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i); - } + // check data input has been associated + if (!getInput(0)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); } - - if (!getInput(0)->empty() && !getInput(1)->empty()) { + if (!getInput(0)->empty()) { + if (this->template getAttr<GatherAttr::Indices>().empty()) + { + if(getInput(1)->empty()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Either indices input or attribute must be provided", type()); + } + this->template getAttr<GatherAttr::GatheredShape>() = getInput(1)->dims(); + this->template getAttr<GatherAttr::Indices>().clear(); // If both are provided input would override attrs + this->template getAttr<GatherAttr::Indices>().reserve(getInput(1)->size()); + switch (mInputs[1]->dataType()) { + case DataType::Float64: + std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<GatherAttr::Indices>())); + break; + case DataType::Float32: + std::copy_n(static_cast<float*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<GatherAttr::Indices>())); + break; + case DataType::Int64: + std::copy_n(static_cast<std::int64_t*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<GatherAttr::Indices>())); + break; + case DataType::Int32: + std::copy_n(static_cast<std::int32_t*>(mInputs[1]->getImpl()->rawPtr()), + getInput(1)->size(), + std::back_inserter(this->template getAttr<GatherAttr::Indices>())); + break; + default: + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: Indices input DataType is not supported.", type()); + break; + } + } std::vector<DimSize_t> outDims = getInput(0)->dims(); - std::vector<DimSize_t> indicesDims = getInput(1)->dims(); std::int8_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0? this->template getAttr<GatherAttr::Axis>(): this->template getAttr<GatherAttr::Axis>()+outDims.size(); outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); - if( indicesDims[0]>0 ) // In case indices is a scalar indicesDims is a 0 + if( !this->template getAttr<GatherAttr::GatheredShape>().empty()) { - outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indicesDims.begin(),indicesDims.end()); + outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), + this->template getAttr<GatherAttr::GatheredShape>().begin(), + this->template getAttr<GatherAttr::GatheredShape>().end()); } mOutputs[0]->resize(outDims); return true; diff --git a/unit_tests/operator/Test_GatherImpl.cpp b/unit_tests/operator/Test_GatherImpl.cpp index bd77c774a890fc63c5517263a2d413a63ee37926..02e8e74890918726212e09fdd9f969ce0863af83 100644 --- a/unit_tests/operator/Test_GatherImpl.cpp +++ b/unit_tests/operator/Test_GatherImpl.cpp @@ -94,4 +94,36 @@ TEST_CASE("[cpu/operator] Gather(forward)", "[Gather][CPU]") { REQUIRE(*(op->getOutput(0)) == *expectedOutput); } + SECTION("Init with attributes") { + std::shared_ptr<Tensor> input = std::make_shared<Tensor>(Array2D<int,3,3> { + { + {1, 2, 3}, + {4, 5, 6}, + {7, 8, 9} + } + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array3D<int,3,1,2> { + { + { + {1, 3} + }, + { + {4, 6} + }, + { + {7, 9} + } + } + }); + + std::shared_ptr<Node> myGather = Gather(1, {0, 2}, {1, 2}); + auto op = std::static_pointer_cast<OperatorTensor>(myGather -> getOperator()); + op->associateInput(0,input); + op->setDataType(DataType::Int32); + op->setBackend("cpu"); + myGather->forward(); + + REQUIRE(*(op->getOutput(0)) == *expectedOutput); + + } } \ No newline at end of file