diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index 142f6582a3afbc85ccd951fcfeff2a924a35e718..43077b73181c70d1a057c4f5af68b473c5540fa9 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -27,27 +27,24 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class GatherAttr { Indices, GatheredShape, Axis }; +enum class GatherAttr { Axis }; class Gather_Op : public OperatorTensor, public Registrable<Gather_Op, std::string, - std::shared_ptr<OperatorImpl>(const Gather_Op&)>, - public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> { + std::unique_ptr<OperatorImpl>(const Gather_Op&)>, + public StaticAttributes<GatherAttr, std::int64_t> { public: static const std::string Type; Gather_Op() = delete; - using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t>; + using Attributes_ = StaticAttributes<GatherAttr, std::int64_t>; template <GatherAttr e> using attr = typename Attributes_::template attr<e>; - Gather_Op(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis) - : OperatorTensor(Type, 1, 0, 1), - Attributes_( - attr<GatherAttr::Indices>(indices), - attr<GatherAttr::GatheredShape>(gatheredShape), - attr<GatherAttr::Axis>(axis)) + Gather_Op(std::int64_t axis) + : OperatorTensor(Type, 2, 0, 1), + Attributes_(attr<GatherAttr::Axis>(axis)) {} /** @@ -81,21 +78,21 @@ public: } static const std::vector<std::string> getInputsName(){ - return {"data_input"}; + return {"data_input", "indices"}; } static const std::vector<std::string> getOutputsName(){ return {"data_output"}; } }; -inline std::shared_ptr<Node> Gather( const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis = 0, const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name); +inline std::shared_ptr<Node> Gather(std::int64_t axis = 0, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Gather_Op>(axis), name); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Indices", "GatheredShape", "Axis"}; +const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Axis"}; } #endif /* AIDGE_CORE_OPERATOR_GATHER_H_ */ diff --git a/python_binding/operator/pybind_Gather.cpp b/python_binding/operator/pybind_Gather.cpp index db6bdb15a2e6288b5f775d538a5e14f15d79d2c1..493c5c118c92d17902ab7ffd69a4fcb70964219a 100644 --- a/python_binding/operator/pybind_Gather.cpp +++ b/python_binding/operator/pybind_Gather.cpp @@ -24,6 +24,7 @@ void init_Gather(py::module& m) { .def("get_outputs_name", &Gather_Op::getOutputsName) .def("attributes_name", &Gather_Op::staticGetAttrsName); declare_registrable<Gather_Op>(m, "GatherOp"); - m.def("Gather", &Gather, py::arg("indices"), py::arg("gathered_shape"), py::arg("axis")= 0, py::arg("name") = ""); + + m.def("Gather", &Gather, py::arg("axis")=0, py::arg("name") = ""); } } // namespace Aidge diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index b5f9d738a0280b3bacdb2ce201c8303b2b4d0a1f..920f161e4c6d916bcd127bc26e8cda49e5d592a7 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -15,33 +15,29 @@ #include <vector> #include "aidge/operator/Gather.hpp" -#include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" const std::string Aidge::Gather_Op::Type = "Gather"; void Aidge::Gather_Op::computeOutputDims() { // check inputs have been associated - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); + if (!getInput(0) || !getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); } - if (!getInput(0)->empty()) { + if (!getInput(0)->empty() && !getInput(1)->empty()) { std::vector<DimSize_t> outDims = getInput(0)->dims(); - const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>(); - // TODO: check indices and gatheredShape + std::vector<DimSize_t> indicesDims = getInput(1)->dims(); - const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ? - this->template getAttr<GatherAttr::Axis>() : - this->template getAttr<GatherAttr::Axis>() + outDims.size(); + std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>()>=0? + this->template getAttr<GatherAttr::Axis>(): + this->template getAttr<GatherAttr::Axis>()+outDims.size(); outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); - if (!gatheredShape.empty()) + if( indicesDims[0]>0 ) // In case indices is a scalar indicesDims is a 0 { - outDims.insert(outDims.cbegin() + static_cast<std::size_t>(axisIdx), - gatheredShape.cbegin(), - gatheredShape.cend()); + outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indicesDims.begin(),indicesDims.end()); } - mOutputs[0]->resize(outDims); } } \ No newline at end of file