From 6cec7f27a4d5672515a5606d0664aaf61bb606ae Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Wed, 13 Dec 2023 16:55:19 +0100
Subject: [PATCH] switch shape input to attr for Reshape

---
 include/aidge/operator/Reshape.hpp         | 42 +++++++++++++++-------
 python_binding/operator/pybind_Reshape.cpp |  2 +-
 src/operator/Reshape.cpp                   | 16 +++++----
 3 files changed, 41 insertions(+), 19 deletions(-)

diff --git a/include/aidge/operator/Reshape.hpp b/include/aidge/operator/Reshape.hpp
index 368f4fff3..1ffa04596 100644
--- a/include/aidge/operator/Reshape.hpp
+++ b/include/aidge/operator/Reshape.hpp
@@ -16,30 +16,42 @@
 #include <memory>
 #include <vector>
 
-#include "aidge/utils/Registrar.hpp"
-#include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/backend/OperatorImpl.hpp"
-#include "aidge/data/Tensor.hpp"
-#include "aidge/data/Data.hpp"
 #include "aidge/graph/Node.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/utils/Registrar.hpp"
+#include "aidge/utils/StaticAttributes.hpp"
 #include "aidge/utils/Types.h"
 
 namespace Aidge {
 
+enum class ReshapeAttr { Shape };
+
 class Reshape_Op : public OperatorTensor,
-    public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)> {
+                   public Registrable<Reshape_Op, std::string, std::unique_ptr<OperatorImpl>(const Reshape_Op&)>,
+                   public StaticAttributes<ReshapeAttr, std::vector<std::int64_t>> {
 
 public:
-    static constexpr const char* Type = "Reshape";
+    static const std::string Type;
+
+    Reshape_Op() = delete;
 
-    Reshape_Op() : OperatorTensor(Type, 2, 0, 1) {}
+    using Attributes_ = StaticAttributes<ReshapeAttr, std::vector<std::int64_t>>;
+    template <ReshapeAttr e>
+    using attr = typename Attributes_::template attr<e>;
+
+    Reshape_Op(const std::vector<std::int64_t>& shape)
+        : OperatorTensor(Type, 1, 0, 1),
+          Attributes_(attr<ReshapeAttr::Shape>(shape))
+    {}
 
     /**
      * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated).
      * @param op Operator to copy.
      */
     Reshape_Op(const Reshape_Op& op)
-        : OperatorTensor(op)
+        : OperatorTensor(op),
+          Attributes_(op)
     {
         mImpl = op.mImpl ? Registrar<Reshape_Op>::create(mOutputs[0]->getImpl()->backend())(*this) : nullptr;
     }
@@ -60,20 +72,26 @@ public:
 
         // FIXME: temporary workaround
         getInput(0)->setBackend(name);
-        getInput(1)->setBackend(name);
     }
 
     static const std::vector<std::string> getInputsName(){
-        return {"data_input", "output_shape"};
+        return {"data_input"};
     }
     static const std::vector<std::string> getOutputsName(){
         return {"data_output"};
     }
 };
 
-inline std::shared_ptr<Node> Reshape(const std::string& name = "") {
-    return std::make_shared<Node>(std::make_shared<Reshape_Op>(), name);
+inline std::shared_ptr<Node> Reshape(const std::vector<std::int64_t>& shape,
+                                   		const std::string &name = "") {
+    // FIXME: properly handle default w&b initialization in every cases
+    return std::make_shared<Node>(std::make_shared<Reshape_Op>(shape), name);
 }
+}  // namespace Aidge
+
+namespace {
+template <>
+const char *const EnumStrings<Aidge::ReshapeAttr>::data[] = { "Shape" };
 }
 
 #endif /* AIDGE_CORE_OPERATOR_RESHAPE_H_ */
diff --git a/python_binding/operator/pybind_Reshape.cpp b/python_binding/operator/pybind_Reshape.cpp
index 35c26c09d..d34a411c7 100644
--- a/python_binding/operator/pybind_Reshape.cpp
+++ b/python_binding/operator/pybind_Reshape.cpp
@@ -22,6 +22,6 @@ void init_Reshape(py::module& m) {
     .def("get_inputs_name", &Reshape_Op::getInputsName)
     .def("get_outputs_name", &Reshape_Op::getOutputsName);
 
-    m.def("Reshape", &Reshape, py::arg("name") = "");
+    m.def("Reshape", &Reshape, py::arg("shape"), py::arg("name") = "");
 }
 }  // namespace Aidge
diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp
index f32e8b5af..2464d37a8 100644
--- a/src/operator/Reshape.cpp
+++ b/src/operator/Reshape.cpp
@@ -11,6 +11,7 @@
 
 #include <cassert>
 #include <cstddef>
+#include <string>
 #include <vector>
 #include <utility>
 
@@ -19,21 +20,24 @@
 #include "aidge/utils/Types.h"
 #include "aidge/utils/ErrorHandling.hpp"
 
+
+const std::string Aidge::Reshape_Op::Type = "Reshape";
+
 void Aidge::Reshape_Op::computeOutputDims() {
     // check inputs have been associated
-    if (!getInput(0) || !getInput(1)) {
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "At least one input was not connected");
+    if (!getInput(0)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
     }
 
+    DimSize_t nbOutDims = this->template getAttr<ReshapeAttr::Shape>().size();
     std::vector<DimSize_t> outDims;
     std::size_t outSize = 1;
-    int* shapeElem = static_cast<int*>(getInput(1)->getImpl()->rawPtr());
-    for(std::size_t i=0; i<mInputs[1]->size(); ++i)
+    for(std::size_t i=0; i<nbOutDims; ++i)
     {
-        int dimSize = shapeElem[i];
+        int dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
         if (dimSize < 1)
         {
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input");
+            AIDGE_THROW_OR_ABORT(std::runtime_error, "bad dimension value");
         }
         outDims.push_back(dimSize);
         outSize *= dimSize;
-- 
GitLab