From 6cec7f27a4d5672515a5606d0664aaf61bb606ae Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Wed, 13 Dec 2023 16:55:19 +0100 Subject: [PATCH] switch shape input to attr for Reshape --- include/aidge/operator/Reshape.hpp | 42 +++++++++++++++------- python_binding/operator/pybind_Reshape.cpp | 2 +- src/operator/Reshape.cpp | 16 +++++---- 3 files changed, 41 insertions(+), 19 deletions(-) diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp index 368f4fff3..1ffa04596 100644 --- a/include/aidge/operator/Reshape.hpp +++ b/include/aidge/operator/Reshape.hpp @@ -16,30 +16,42 @@ #include <memory> #include <vector> -#include "aidge/utils/Registrar.hpp" -#include "aidge/operator/OperatorTensor.hpp" #include "aidge/backend/OperatorImpl.hpp" -#include "aidge/data/Tensor.hpp" -#include "aidge/data/Data.hpp" #include "aidge/graph/Node.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/Types.h" namespace Aidge { +enum class ReshapeAttr { Shape }; + class Reshape_Op : public OperatorTensor, - public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)> { + public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)>, + public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> { public: - static constexpr const char* Type = "Reshape"; + static const std::string Type; + + Reshape_Op() = delete; - Reshape_Op() : OperatorTensor(Type, 2, 0, 1) {} + using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>; + template <ReshapeAttr e> + using attr = typename Attributes_::template attr<e>; + + Reshape_Op(const std::vector<std::int64_t>& shape) + : OperatorTensor(Type, 1, 0, 1), + Attributes_(attr<ReshapeAttr::Shape>(shape)) + {} /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ Reshape_Op(const Reshape_Op& op) - : OperatorTensor(op) + : OperatorTensor(op), + Attributes_(op) { mImpl = op.mImpl ? Registrar<Reshape_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr; } @@ -60,20 +72,26 @@ public: // FIXME: temporary workaround getInput(0)->setBackend(name); - getInput(1)->setBackend(name); } static const std::vector<std::string> getInputsName(){ - return {"data_input", "output_shape"}; + return {"data_input"}; } static const std::vector<std::string> getOutputsName(){ return {"data_output"}; } }; -inline std::shared_ptr<Node> Reshape(const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Reshape_Op>(), name); +inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape, + const std::string &name = "") { + // FIXME: properly handle default w&b initialization in every cases + return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name); } +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::ReshapeAttr>::data[] = { "Shape" }; } #endif /* AIDGE_CORE_OPERATOR_RESHAPE_H_ */ diff --git a/python_binding/operator/pybind_Reshape.cpp b/python_binding/operator/pybind_Reshape.cpp index 35c26c09d..d34a411c7 100644 --- a/python_binding/operator/pybind_Reshape.cpp +++ b/python_binding/operator/pybind_Reshape.cpp @@ -22,6 +22,6 @@ void init_Reshape(py::module& m) { .def("get_inputs_name", &Reshape_Op::getInputsName) .def("get_outputs_name", &Reshape_Op::getOutputsName); - m.def("Reshape", &Reshape, py::arg("name") = ""); + m.def("Reshape", &Reshape, py::arg("shape"), py::arg("name") = ""); } } // namespace Aidge diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index f32e8b5af..2464d37a8 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -11,6 +11,7 @@ #include <cassert> #include <cstddef> +#include <string> #include <vector> #include <utility> @@ -19,21 +20,24 @@ #include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" + +const std::string Aidge::Reshape_Op::Type = "Reshape"; + void Aidge::Reshape_Op::computeOutputDims() { // check inputs have been associated - if (!getInput(0) || !getInput(1)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected"); + if (!getInput(0)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); } + DimSize_t nbOutDims = this->template getAttr<ReshapeAttr::Shape>().size(); std::vector<DimSize_t> outDims; std::size_t outSize = 1; - int* shapeElem = static_cast<int*>(getInput(1)->getImpl()->rawPtr()); - for(std::size_t i=0; i<mInputs[1]->size(); ++i) + for(std::size_t i=0; i<nbOutDims; ++i) { - int dimSize = shapeElem[i]; + int dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; if (dimSize < 1) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input"); + AIDGE_THROW_OR_ABORT(std::runtime_error, "bad dimension value"); } outDims.push_back(dimSize); outSize *= dimSize; -- GitLab