From ade64013ff2f18e155e211ad6b2065c487547875 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Thu, 7 Mar 2024 14:33:40 +0100 Subject: [PATCH] switch slice attrs into inputs --- include/aidge/operator/Slice.hpp | 45 ++------------- python_binding/operator/pybind_Gather.cpp | 2 +- python_binding/operator/pybind_Slice.cpp | 3 +- src/operator/Slice.cpp | 69 +++++++++++++++-------- src/recipes/HorizontalTiling.cpp | 5 +- 5 files changed, 58 insertions(+), 66 deletions(-) diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp index 363c3c2b4..e71eaf40f 100644 --- a/include/aidge/operator/Slice.hpp +++ b/include/aidge/operator/Slice.hpp @@ -20,31 +20,16 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/utils/Registrar.hpp" -#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Types.h" namespace Aidge { -enum class SliceAttr { Starts, Ends, Axes }; - class Slice_Op : public OperatorTensor, - public Registrable<Slice_Op, std::string, std::shared_ptr<OperatorImpl>(const Slice_Op &)>, - public StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>> { + public Registrable<Slice_Op, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op &)>{ public: static const std::string Type; - Slice_Op() = delete; - - using Attributes_ = StaticAttributes<SliceAttr, std::vector<std::int64_t>, std::vector<std::int64_t>, std::vector<std::int64_t>>; - template <SliceAttr e> - using attr = typename Attributes_::template attr<e>; - - Slice_Op(const std::vector<std::int64_t>& starts, const std::vector<std::int64_t>& ends, const std::vector<std::int64_t>& axes) - : OperatorTensor(Type, 1, 0, 1), - Attributes_(attr<SliceAttr::Starts>(starts), - attr<SliceAttr::Ends>(ends), - attr<SliceAttr::Axes>(axes)) - {} + Slice_Op() : OperatorTensor(Type, 4, 0, 1) {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its @@ -52,8 +37,7 @@ public: * @param op Operator to copy. */ Slice_Op(const Slice_Op &op) - : OperatorTensor(op), - Attributes_(op) + : OperatorTensor(op) { if (op.mImpl){ SET_IMPL_MACRO(Slice_Op, *this, op.mOutputs[0]->getImpl()->backend()); @@ -77,7 +61,7 @@ public: } static const std::vector<std::string> getInputsName(){ - return {"data_input"}; + return {"data_input", "starts", "ends", "axes"}; } static const std::vector<std::string> getOutputsName(){ return {"data_output"}; @@ -86,29 +70,12 @@ public: /** * @brief Exract a sub-Tensor from a bigger original Tensor. - * @param starts Indexes for each dimension of the first element. - * Can be a negative value. Negative values start their reference from the last index. - * ``-1`` referes to the last index of a dimension. - * @param ends Indexes for each dimension of the last element. - * Can be a negative value. Negative values start their reference from the last index. - * ``-1`` referes to the last index of a dimension. - * @param axes Dimensions for which start/end indexes apply. Not specifying a dimensions - * means the whole dimensions is extracted. * @param name Name of the Operator. * @return std::shared_ptr<Node> A Node containing the Operator. */ -inline std::shared_ptr<Node> Slice(const std::vector<std::int64_t> starts, - const std::vector<std::int64_t> ends, - const std::vector<std::int64_t> axes, - const std::string &name = "") { - // FIXME: properly handle default w&b initialization in every cases - return std::make_shared<Node>(std::make_shared<Slice_Op>(starts, ends, axes), name); +inline std::shared_ptr<Node> Slice(const std::string &name = "") { + return std::make_shared<Node>(std::make_shared<Slice_Op>(), name); } } // namespace Aidge -namespace { -template <> -const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Starts", "Ends", "Axes" }; -} - #endif /* AIDGE_CORE_OPERATOR_RELU_H_ */ diff --git a/python_binding/operator/pybind_Gather.cpp b/python_binding/operator/pybind_Gather.cpp index 493c5c118..e999aa6ab 100644 --- a/python_binding/operator/pybind_Gather.cpp +++ b/python_binding/operator/pybind_Gather.cpp @@ -25,6 +25,6 @@ void init_Gather(py::module& m) { .def("attributes_name", &Gather_Op::staticGetAttrsName); declare_registrable<Gather_Op>(m, "GatherOp"); - m.def("Gather", &Gather, py::arg("axis")=0, py::arg("name") = ""); + m.def("Gather", &Gather, py::arg("axis") = 0, py::arg("name") = ""); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Slice.cpp b/python_binding/operator/pybind_Slice.cpp index 3bb1b082c..45baa1d9a 100644 --- a/python_binding/operator/pybind_Slice.cpp +++ b/python_binding/operator/pybind_Slice.cpp @@ -22,6 +22,7 @@ void init_Slice(py::module& m) { .def("get_inputs_name", &Slice_Op::getInputsName) .def("get_outputs_name", &Slice_Op::getOutputsName); declare_registrable<Slice_Op>(m, "SliceOp"); - m.def("Slice", &Slice, py::arg("starts"), py::arg("ends"), py::arg("axes"), py::arg("name") = ""); + + m.def("Slice", &Slice, py::arg("name") = ""); } } // namespace Aidge diff --git a/src/operator/Slice.cpp b/src/operator/Slice.cpp index 6d2670695..3062895b7 100644 --- a/src/operator/Slice.cpp +++ b/src/operator/Slice.cpp @@ -8,17 +8,16 @@ * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ -#include "aidge/operator/Slice.hpp" -#include "aidge/utils/Types.h" -#include "aidge/utils/ErrorHandling.hpp" #include <cassert> #include <cstddef> +#include <cstdint> #include <string> #include <utility> #include <vector> #include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/Slice.hpp" #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" @@ -26,28 +25,50 @@ const std::string Aidge::Slice_Op::Type = "Slice"; void Aidge::Slice_Op::computeOutputDims() { // check input have been associated - if (!getInput(0) || (getInput(0)->empty())) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); + if (!getInput(0) || !getInput(1) || !getInput(2) || !getInput(3)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Every input should be associated with a Tensor"); } - const DimSize_t nbAxes = this->template getAttr<SliceAttr::Axes>().size(); - std::vector<DimSize_t> outDims = getInput(0)->dims(); - for (std::size_t i = 0; i < nbAxes; ++i) { - // For each slice operation get the params and cast them to size_t - const std::int64_t axis_ = this->template getAttr<SliceAttr::Axes>()[i]; - const std::int64_t start_ = this->template getAttr<SliceAttr::Starts>()[i]; - const std::int64_t end_ = this->template getAttr<SliceAttr::Ends>()[i]; - const std::size_t axis = axis_ >= 0 ? static_cast<std::size_t>(axis_) : static_cast<std::size_t>(axis_) + getInput(0)->nbDims(); - const std::size_t start = start_ >= 0 ? static_cast<std::size_t>(start_) : static_cast<std::size_t>(start_) + getInput(0)->dims()[axis]; - const std::size_t end = end_ >= 0 ? static_cast<std::size_t>(end_) : static_cast<std::size_t>(end_) + getInput(0)->dims()[axis]; - - const std::size_t sliceLength = end - start + 1; - // Check if slice length is valid - if (sliceLength > getInput(0)->dims()[axis]) - { - AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds"); + if((!getInput(0)->empty()) && (!getInput(1)->empty()) && (!getInput(2)->empty()) && (!getInput(3)->empty())) + { + const auto starts = mInputs[1]->getImpl()->rawPtr(); + const auto ends = mInputs[2]->getImpl()->rawPtr(); + const auto axes = mInputs[3]->getImpl()->rawPtr(); + DimSize_t nbAxes = mInputs[1]->size(); + std::vector<DimSize_t> outDims = getInput(0)->dims(); + for (std::size_t i = 0; i < nbAxes; ++i) { + // For each slice operation get the params and cast them to size_t + std::size_t axis, start, end; //TODO find a better way to cast "starts", "ends" and "axes" + if (mInputs[1]->dataType() == DataType::Float32 && mInputs[2]->dataType() == DataType::Float32 && mInputs[3]->dataType() == DataType::Float32) + { + const float* axes_ = static_cast<float*>(axes); + axis = axes_[i] >= 0 ? static_cast<std::size_t>(axes_[i]) : static_cast<std::size_t>(axes_[i]) + getInput(0)->nbDims(); + const float* starts_ = static_cast<float*>(starts); + start = starts_[i] >= 0 ? static_cast<std::size_t>(starts_[i]) : static_cast<std::size_t>(starts_[i]) + getInput(0)->dims()[axis]; + const float* ends_ = static_cast<float*>(ends); + end = ends_[i] >= 0 ? static_cast<std::size_t>(ends_[i]) : static_cast<std::size_t>(ends_[i]) + getInput(0)->dims()[ends_[i]]; + } + else if(mInputs[1]->dataType() == DataType::Int32 && mInputs[2]->dataType() == DataType::Int32 && mInputs[3]->dataType() == DataType::Int32) + { + const std::int32_t* axes_ = static_cast<std::int32_t*>(axes); + axis = axes_[i] >= 0 ? static_cast<std::size_t>(axes_[i]) : static_cast<std::size_t>(axes_[i]) + getInput(0)->nbDims(); + const std::int32_t* starts_ = static_cast<std::int32_t*>(starts); + start = starts_[i] >= 0 ? static_cast<std::size_t>(starts_[i]) : static_cast<std::size_t>(starts_[i]) + getInput(0)->dims()[axis]; + const std::int32_t* ends_ = static_cast<std::int32_t*>(ends); + end = ends_[i] >= 0 ? static_cast<std::size_t>(ends_[i]) : static_cast<std::size_t>(ends_[i]) + getInput(0)->dims()[ends_[i]]; + } + else + { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Slice inputs type is not supported yet"); + } + const std::size_t sliceLength = end - start; + // Check if slice length is valid + if (sliceLength > getInput(0)->dims()[axis]) + { + AIDGE_THROW_OR_ABORT(std::runtime_error, "ROI of Slice operator out of bounds"); + } + outDims[axis] = sliceLength; } - outDims[axis] = sliceLength; + mOutputs[0]->resize(outDims); } - mOutputs[0]->resize(outDims); -} +} \ No newline at end of file diff --git a/src/recipes/HorizontalTiling.cpp b/src/recipes/HorizontalTiling.cpp index 8e27fea58..7e08457bc 100644 --- a/src/recipes/HorizontalTiling.cpp +++ b/src/recipes/HorizontalTiling.cpp @@ -93,7 +93,10 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: } std::vector<std::int64_t> usedDims(inputDimsEnd.size()); std::iota(usedDims.begin(), usedDims.end(), static_cast<std::int64_t>(0)); - auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis])); + Tensor(std::vector<std::size_t>({inputDimsStart.size()})); + // TODO create producer nodes for the attributes + // auto slice = Slice(inputDimsStart, inputDimsEnd, usedDims, "Slice_" + std::to_string(currentFirstDims[axis])); + auto slice = Slice("Slice_" + std::to_string(currentFirstDims[axis])); slice -> addChild(newNode, 0, 0); newNode -> addChild(concat, 0, i); -- GitLab