diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index 20082eed28825ade9d62fb5d4e081840d3bd4442..f6647f99151304d0cf083aed109cc642c9f1ecc2 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -27,25 +27,26 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class GatherAttr { Axis }; +enum class GatherAttr { Indices, GatheredShape, Axis }; class Gather_Op : public OperatorTensor, public Registrable<Gather_Op, std::string, std::unique_ptr<OperatorImpl>(const Gather_Op&)>, - public StaticAttributes<GatherAttr, int> { + public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> { public: static const std::string Type; Gather_Op() = delete; - - using Attributes_ = StaticAttributes<GatherAttr, int>; + using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t>; template <GatherAttr e> using attr = typename Attributes_::template attr<e>; - Gather_Op(int axis) - : OperatorTensor(Type, 2, 0, 1), + Gather_Op(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis) + : OperatorTensor(Type, 1, 0, 1), Attributes_( + attr<GatherAttr::Indices>(indices), + attr<GatherAttr::GatheredShape>(gatheredShape), attr<GatherAttr::Axis>(axis)) {} @@ -76,21 +77,21 @@ public: } static const std::vector<std::string> getInputsName(){ - return {"data_input", "indexes"}; + return {"data_input"}; } static const std::vector<std::string> getOutputsName(){ return {"data_output"}; } }; -inline std::shared_ptr<Node> Gather(int axis = 0, const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Gather_Op>(axis), name); +inline std::shared_ptr<Node> Gather( const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis = 0, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Axis"}; +const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Indices", "GatheredShape", "Axis"}; } #endif /* AIDGE_CORE_OPERATOR_GATHER_H_ */ diff --git a/python_binding/operator/pybind_Gather.cpp b/python_binding/operator/pybind_Gather.cpp index f9768e38fbdceef4a15cc74430bc2205bb32cb6a..4369d4d22b205a40140cf5160d999743b2e9b4c1 100644 --- a/python_binding/operator/pybind_Gather.cpp +++ b/python_binding/operator/pybind_Gather.cpp @@ -23,6 +23,6 @@ void init_Gather(py::module& m) { .def("get_inputs_name", &Gather_Op::getInputsName) .def("get_outputs_name", &Gather_Op::getOutputsName); - m.def("Gather", &Gather, py::arg("axis"), py::arg("name") = ""); + m.def("Gather", &Gather, py::arg("indices"), py::arg("gathered_shape"), py::arg("axis"), py::arg("name") = ""); } } // namespace Aidge diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index 30804994b6084a5a5558f106a38a6087e54471bc..2bba696e90f7239dd1079f9a91875be3bcab8a11 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -22,18 +22,22 @@ const std::string Aidge::Gather_Op::Type = "Gather"; void Aidge::Gather_Op::computeOutputDims() { // check inputs have been associated - if (!getInput(0) || !getInput(1)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); - } - - if (getInput(1)->nbDims()!=2){ - AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 2D Tensor"); + if (!getInput(0)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); } std::vector<DimSize_t> outDims = getInput(0)->dims(); - std::vector<DimSize_t> indexesDims = getInput(1)->dims(); - int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size(); + std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>(); + // TODO: check indices and gatheredShape + + std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ? + this->template getAttr<GatherAttr::Axis>() : + this->template getAttr<GatherAttr::Axis>() + outDims.size(); outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); - outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(),indexesDims.end()); + if (!gatheredShape.empty()) + { + outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), gatheredShape.begin(),gatheredShape.end()); + } + mOutputs[0]->resize(outDims); } \ No newline at end of file diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 139e84b561a48c2f6a5ecd14ed9d6905d66dec20..40f74adfb0e786af17e8fc14720f9098db387ddc 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -34,17 +34,19 @@ void Aidge::Slice_Op::computeOutputDims() { std::vector<DimSize_t> outDims = getInput(0)->dims(); for (std::size_t i = 0; i < nbAxes; ++i) { // For each slice operation get the params and cast them to size_t - const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i]; - const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i]; - const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i]; - const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : axis_ + getInput(0)->nbDims(); - const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : start_ + getInput(0)->dims()[axis]; - const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : end_ + getInput(0)->dims()[axis]; + const std::int32_t axis_ = this->template getAttr<SliceAttr::Axes>()[i]; + const std::int32_t start_ = this->template getAttr<SliceAttr::Starts>()[i]; + const std::int32_t end_ = this->template getAttr<SliceAttr::Ends>()[i]; + const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : static_cast<std::size_t>(axis_) + getInput(0)->nbDims(); + const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : static_cast<std::size_t>(start_) + getInput(0)->dims()[axis]; + const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : static_cast<std::size_t>(end_) + getInput(0)->dims()[axis]; const std::size_t sliceLength = end - start + 1; // Check if slice length is valid if (sliceLength > getInput(0)->dims()[axis]) + { AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds"); + } outDims[axis] = sliceLength; } mOutputs[0]->resize(outDims);