diff --git a/include/aidge/operator/Gather.hpp b/include/aidge/operator/Gather.hpp index 20082eed28825ade9d62fb5d4e081840d3bd4442..f6647f99151304d0cf083aed109cc642c9f1ecc2 100644 --- a/include/aidge/operator/Gather.hpp +++ b/include/aidge/operator/Gather.hpp @@ -27,25 +27,26 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class GatherAttr { Axis }; +enum class GatherAttr { Indices, GatheredShape, Axis }; class Gather_Op : public OperatorTensor, public Registrable<Gather_Op, std::string, std::unique_ptr<OperatorImpl>(const Gather_Op&)>, - public StaticAttributes<GatherAttr, int> { + public StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t> { public: static const std::string Type; Gather_Op() = delete; - - using Attributes_ = StaticAttributes<GatherAttr, int>; + using Attributes_ = StaticAttributes<GatherAttr, std::vector<std::int64_t>, std::vector<DimSize_t>, std::int64_t>; template <GatherAttr e> using attr = typename Attributes_::template attr<e>; - Gather_Op(int axis) - : OperatorTensor(Type, 2, 0, 1), + Gather_Op(const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis) + : OperatorTensor(Type, 1, 0, 1), Attributes_( + attr<GatherAttr::Indices>(indices), + attr<GatherAttr::GatheredShape>(gatheredShape), attr<GatherAttr::Axis>(axis)) {} @@ -76,21 +77,21 @@ public: } static const std::vector<std::string> getInputsName(){ - return {"data_input", "indexes"}; + return {"data_input"}; } static const std::vector<std::string> getOutputsName(){ return {"data_output"}; } }; -inline std::shared_ptr<Node> Gather(int axis = 0, const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Gather_Op>(axis), name); +inline std::shared_ptr<Node> Gather( const std::vector<std::int64_t>& indices, const std::vector<DimSize_t>& gatheredShape, std::int64_t axis = 0, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Gather_Op>(indices, gatheredShape, axis), name); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Axis"}; +const char *const EnumStrings<Aidge::GatherAttr>::data[] = {"Indices", "GatheredShape", "Axis"}; } #endif /* AIDGE_CORE_OPERATOR_GATHER_H_ */ diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index 12a7425f3339b7fbc0ae010639aacf23d97b0f5f..4a073bc525640846c28d718d09741a67d499830e 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -29,17 +29,17 @@ enum class SliceAttr { Starts, Ends, Axes }; class Slice_Op : public OperatorTensor, public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>, - public StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>> { + public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>> { public: static const std::string Type; Slice_Op() = delete; - using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int32_t>, std::vector<std::int32_t>, std::vector<std::int32_t>>; + using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>>; template <SliceAttr e> using attr = typename Attributes_::template attr<e>; - Slice_Op(const std::vector<std::int32_t>& starts, const std::vector<std::int32_t>& ends, const std::vector<std::int32_t>& axes) + Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int64_t>& axes) : OperatorTensor(Type, 1, 0, 1), Attributes_(attr<SliceAttr::Starts>(starts), attr<SliceAttr::Ends>(ends), @@ -94,9 +94,9 @@ public: * @param name Name of the Operator. * @return std::shared_ptr<Node> A Node containing the Operator. */ -inline std::shared_ptr<Node> Slice(const std::vector<std::int32_t> starts, - const std::vector<std::int32_t> ends, - const std::vector<std::int32_t> axes, +inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t> starts, + const std::vector<std::int64_t> ends, + const std::vector<std::int64_t> axes, const std::string &name = "") { // FIXME: properly handle default w&b initialization in every cases return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name); diff --git a/python_binding/operator/pybind_Gather.cpp b/python_binding/operator/pybind_Gather.cpp index f9768e38fbdceef4a15cc74430bc2205bb32cb6a..4369d4d22b205a40140cf5160d999743b2e9b4c1 100644 --- a/python_binding/operator/pybind_Gather.cpp +++ b/python_binding/operator/pybind_Gather.cpp @@ -23,6 +23,6 @@ void init_Gather(py::module& m) { .def("get_inputs_name", &Gather_Op::getInputsName) .def("get_outputs_name", &Gather_Op::getOutputsName); - m.def("Gather", &Gather, py::arg("axis"), py::arg("name") = ""); + m.def("Gather", &Gather, py::arg("indices"), py::arg("gathered_shape"), py::arg("axis"), py::arg("name") = ""); } } // namespace Aidge diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index 30804994b6084a5a5558f106a38a6087e54471bc..b5f9d738a0280b3bacdb2ce201c8303b2b4d0a1f 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -9,8 +9,8 @@ * ********************************************************************************/ -#include <cassert> #include <cstddef> +#include <cstdint> #include <string> #include <vector> @@ -22,18 +22,26 @@ const std::string Aidge::Gather_Op::Type = "Gather"; void Aidge::Gather_Op::computeOutputDims() { // check inputs have been associated - if (!getInput(0) || !getInput(1)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); + if (!getInput(0)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); } - if (getInput(1)->nbDims()!=2){ - AIDGE_THROW_OR_ABORT(std::runtime_error, "Indices input must be a 2D Tensor"); - } + if (!getInput(0)->empty()) { + std::vector<DimSize_t> outDims = getInput(0)->dims(); + const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>(); + // TODO: check indices and gatheredShape + + const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ? + this->template getAttr<GatherAttr::Axis>() : + this->template getAttr<GatherAttr::Axis>() + outDims.size(); + outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); + if (!gatheredShape.empty()) + { + outDims.insert(outDims.cbegin() + static_cast<std::size_t>(axisIdx), + gatheredShape.cbegin(), + gatheredShape.cend()); + } - std::vector<DimSize_t> outDims = getInput(0)->dims(); - std::vector<DimSize_t> indexesDims = getInput(1)->dims(); - int axisIdx = this->template getAttr<GatherAttr::Axis>()>=0?this->template getAttr<GatherAttr::Axis>():this->template getAttr<GatherAttr::Axis>()+outDims.size(); - outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); - outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), indexesDims.begin(),indexesDims.end()); - mOutputs[0]->resize(outDims); + mOutputs[0]->resize(outDims); + } } \ No newline at end of file diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index c1a7c35e395418995a720efd49c7cfce0801863e..30b060cd2a58d7995a7447bd9b85b9bc0026a7f7 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -27,31 +27,32 @@ void Aidge::Reshape_Op::computeOutputDims() { AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); } - std::vector<DimSize_t> outDims; - - // variables to handle a negative dimension - bool foundNegativeDimension = false; - std::size_t outSize = 1; - DimIdx_t negativeIndex = 0; - - for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) - { - std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; - if (dimSize < 0) { - if (foundNegativeDimension) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator."); + if (!getInput(0)->empty()) { + std::vector<DimSize_t> outDims; + // variables to handle a negative dimension + bool foundNegativeDimension = false; + std::size_t outSize = 1; + DimIdx_t negativeIndex = 0; + + for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) + { + std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; + if (dimSize < 0) { + if (foundNegativeDimension) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator."); + } + foundNegativeDimension = true; + dimSize = 1; + negativeIndex = static_cast<DimIdx_t>(i); } - foundNegativeDimension = true; - dimSize = 1; - negativeIndex = static_cast<DimIdx_t>(i); + outDims.push_back(static_cast<DimSize_t>(dimSize)); + outSize *= static_cast<DimSize_t>(dimSize); } - outDims.push_back(static_cast<DimSize_t>(dimSize)); - outSize *= static_cast<DimSize_t>(dimSize); - } - if (foundNegativeDimension) { - outDims[negativeIndex] = (getInput(0) -> size()) / outSize; - } + if (foundNegativeDimension) { + outDims[negativeIndex] = (getInput(0) -> size()) / outSize; + } - mOutputs[0]->resize(outDims); + mOutputs[0]->resize(outDims); + } } \ No newline at end of file diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 139e84b561a48c2f6a5ecd14ed9d6905d66dec20..11d91a1fcd4c1d4ee6bcc5f9d830870fa6e732e5 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -30,21 +30,23 @@ void Aidge::Slice_Op::computeOutputDims() { AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); } - DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); + const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); std::vector<DimSize_t> outDims = getInput(0)->dims(); for (std::size_t i = 0; i < nbAxes; ++i) { // For each slice operation get the params and cast them to size_t const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i]; const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i]; const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i]; - const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : axis_ + getInput(0)->nbDims(); - const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : start_ + getInput(0)->dims()[axis]; - const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : end_ + getInput(0)->dims()[axis]; + const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : static_cast<std::size_t>(axis_) + getInput(0)->nbDims(); + const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : static_cast<std::size_t>(start_) + getInput(0)->dims()[axis]; + const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : static_cast<std::size_t>(end_) + getInput(0)->dims()[axis]; const std::size_t sliceLength = end - start + 1; // Check if slice length is valid if (sliceLength > getInput(0)->dims()[axis]) + { AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds"); + } outDims[axis] = sliceLength; } mOutputs[0]->resize(outDims); diff --git a/src/recipies/HorizontalTiling.cpp b/src/recipies/HorizontalTiling.cpp index 6cc34eba076934b884b336ce40081a855d917182..7d3fafc0a15d1b797fdfb1a2884b62d2d8d766c5 100644 --- a/src/recipies/HorizontalTiling.cpp +++ b/src/recipies/HorizontalTiling.cpp @@ -82,16 +82,16 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: clonedInputs[1] -> addChild(newNode, 0, 1); clonedInputs[2] -> addChild(newNode, 0, 2); // Slice for input and each parameter - std::vector<std::int32_t> inputDimsEnd(inputDims[0].first.size()); + std::vector<std::int64_t> inputDimsEnd(inputDims[0].first.size()); for (std::size_t dim = 0; dim < inputDimsEnd.size(); ++dim) { - inputDimsEnd[dim] = static_cast<std::int32_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1; + inputDimsEnd[dim] = static_cast<std::int64_t>(inputDims[0].first[dim] + inputDims[0].second[dim]) - 1; } - std::vector<std::int32_t> inputDimsStart(inputDims[0].first.size()); + std::vector<std::int64_t> inputDimsStart(inputDims[0].first.size()); for (std::size_t dim = 0; dim < inputDimsStart.size(); ++dim) { - inputDimsStart[dim] = static_cast<std::int32_t>(inputDims[0].first[dim]); + inputDimsStart[dim] = static_cast<std::int64_t>(inputDims[0].first[dim]); } - std::vector<std::int32_t> usedDims(inputDimsEnd.size()); - std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int32_t>(0)); + std::vector<std::int64_t> usedDims(inputDimsEnd.size()); + std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0)); auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis])); slice -> addChild(newNode, 0, 0); newNode -> addChild(concat, 0, i);