From 1b929c79842c709ac178f88da69a72acfdc779cd Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Wed, 22 Nov 2023 15:53:55 +0000
Subject: [PATCH] Multiple changes

- Remove setInput in Node
- Change setDatatype to setDataType in GraphView and Tensor binding
- Add namespace comment
- Update Node includes
- Run forwardDims() only if operators use Tensors
---
 include/aidge/graph/Node.hpp              |  2 +-
 include/aidge/recipies/Recipies.hpp       |  2 +-
 python_binding/data/pybind_Tensor.cpp     |  2 +-
 python_binding/graph/pybind_GraphView.cpp |  2 +-
 python_binding/graph/pybind_Node.cpp      |  6 +--
 src/graph/GraphView.cpp                   | 64 +++++++++++++----------
 src/graph/Node.cpp                        | 25 ++++-----
 7 files changed, 55 insertions(+), 48 deletions(-)

diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp
index b81f5288e..9717999da 100644
--- a/include/aidge/graph/Node.hpp
+++ b/include/aidge/graph/Node.hpp
@@ -169,7 +169,7 @@ public:
    * @param idx Input index.
    * @param tensor Constant Tensor to add as parent for specified index.
    */
-  void setInput(const IOIndex_t idx, const std::shared_ptr<Tensor> tensor);
+  // void setInput(const IOIndex_t idx, const std::shared_ptr<Tensor> tensor);
 
   /**
    * @brief Get the lowest index in the InputData Parent list equal to the
diff --git a/include/aidge/recipies/Recipies.hpp b/include/aidge/recipies/Recipies.hpp
index 97544937e..26f4cc9da 100644
--- a/include/aidge/recipies/Recipies.hpp
+++ b/include/aidge/recipies/Recipies.hpp
@@ -89,6 +89,6 @@ void fuseBatchNorm(std::shared_ptr<GraphView> graphView);
 // std::set<std::shared_ptr<Node>> getHorizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
 // void horizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
 
-}
+} // namespace Aidge
 
 #endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
diff --git a/python_binding/data/pybind_Tensor.cpp b/python_binding/data/pybind_Tensor.cpp
index 31470e0eb..babc534bd 100644
--- a/python_binding/data/pybind_Tensor.cpp
+++ b/python_binding/data/pybind_Tensor.cpp
@@ -35,7 +35,7 @@ void addCtor(py::class_<Tensor,
         /* Request a buffer descriptor from Python */
         py::buffer_info info = b.request();
         Tensor* newTensor = new Tensor();
-        newTensor->setDatatype(NativeType<T>::type);
+        newTensor->setDataType(NativeType<T>::type);
         const std::vector<DimSize_t> dims(info.shape.begin(), info.shape.end());
         newTensor->resize(dims);
         // TODO : Find a better way to choose backend
diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp
index 6ac2199b4..6a29c6941 100644
--- a/python_binding/graph/pybind_GraphView.cpp
+++ b/python_binding/graph/pybind_GraphView.cpp
@@ -89,7 +89,7 @@ void init_GraphView(py::module& m) {
           .def("get_node", &GraphView::getNode, py::arg("node_name"))
           .def("forward_dims", &GraphView::forwardDims)
           .def("__call__", &GraphView::operator(), py::arg("connectors"))
-          .def("set_datatype", &GraphView::setDatatype, py::arg("datatype"))
+          .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
           .def("set_backend", &GraphView::setBackend, py::arg("backend"))
           //   .def("__getitem__", [](Tensor& b, size_t idx)-> py::object {
           //      // TODO : Should return error if backend not compatible with get
diff --git a/python_binding/graph/pybind_Node.cpp b/python_binding/graph/pybind_Node.cpp
index e3666d247..3b63189c8 100644
--- a/python_binding/graph/pybind_Node.cpp
+++ b/python_binding/graph/pybind_Node.cpp
@@ -90,7 +90,7 @@ void init_Node(py::module& m) {
             .def("input", &Node::input, py::arg("in_id"),
             R"mydelimiter(
             Get the parent Node and the associated output index connected to the i-th input of the current Node.
-            
+
             :param in_id: input index of the current Node object.
             :type in_id: int
             :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
@@ -108,7 +108,7 @@ void init_Node(py::module& m) {
             .def("output", &Node::output, py::arg("out_id"),
             R"mydelimiter(
             Get a list of the children Node for a specific output and the associated input index connected to it.
-            
+
             :param out_id: input index of the current Node object.
             :type out_id: int
             :return: i-th connection. When an input is not linked to any parent, the default value is (None, default_index)
@@ -122,7 +122,7 @@ void init_Node(py::module& m) {
             :rtype: int
             )mydelimiter")
 
-            .def("get_nb_datainputs", &Node::nbDataInputs,
+            .def("get_nb_data", &Node::nbData,
             R"mydelimiter(
             Number of data inputs.
 
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index af3e24c20..2306ec8ab 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -17,6 +17,7 @@
 #include "aidge/utils/Types.h"
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/data/Tensor.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
 
 ///////////////////////////////////////////////////////
@@ -171,7 +172,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType
     setBackend(backend);
     // Data type
     // TODO: manage Datatype attribute in OperatorImpl
-    setDatatype(datatype);
+    setDataType(datatype);
     // Data Format
     // TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary
     // Forward dimensions
@@ -208,41 +209,46 @@ void Aidge::GraphView::forwardDims() {
 }
 
 void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
-  // TODO: support multi-inputs/outputs
-  std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>();
-  for (std::shared_ptr<Node> nodePtr : listNodes) {
-    if (!nodePtr->getOperator()->outputDimsForwarded()) {
-      nodePtr->getOperator()->computeOutputDims();
-    }
-    if (!nodePtr->getOperator()->outputDimsForwarded()) {
-      nextList.insert(nodePtr);
-    } else {
-      std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
-      nextList.insert(children.begin(), children.end());
+    // TODO: support multi-inputs/outputs
+    std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>();
+    for (std::shared_ptr<Node> nodePtr : listNodes) {
+        if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
+            const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
+            if (!op->outputDimsForwarded()) {
+                op->computeOutputDims();
+            }
+            if (!op->outputDimsForwarded()) { // try to compute output dimensions again later
+                nextList.insert(nodePtr);
+            } else { // compute output dimensions of children
+                std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
+                nextList.insert(children.begin(), children.end());
+            }
+        }
     }
-  }
-  if (nextList.empty()) {
-    for (std::shared_ptr<Node> nodePtr : getNodes()) {
-      if (!nodePtr->getOperator()->outputDimsForwarded()) {
-        nextList.insert(nodePtr);
-      }
+    if (nextList.empty()) {
+        for (std::shared_ptr<Node> nodePtr : getNodes()) {
+            if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
+                if (!std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator())->outputDimsForwarded()) {
+                    nextList.insert(nodePtr);
+                }
+            }
+        }
+    }
+    if (!nextList.empty()) {
+        _forwardDims(nextList);
     }
-  }
-  if (!nextList.empty()) {
-    _forwardDims(nextList);
-  }
 }
 
 void Aidge::GraphView::setBackend(const std::string &backend) {
-  for (auto node : getNodes()) {
-    node->getOperator()->setBackend(backend);
-  }
+    for (auto node : getNodes()) {
+        node->getOperator()->setBackend(backend);
+    }
 }
 
-void Aidge::GraphView::setDatatype(const Aidge::DataType &datatype) {
-  for (auto node : getNodes()) {
-    node->getOperator()->setDatatype(datatype);
-  }
+void Aidge::GraphView::setDataType(const Aidge::DataType &datatype) {
+    for (auto node : getNodes()) {
+        node->getOperator()->setDataType(datatype);
+    }
 }
 
 void Aidge::GraphView::updateOutputNodes() {
diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp
index 6dca8eaaf..5a7b05e46 100644
--- a/src/graph/Node.cpp
+++ b/src/graph/Node.cpp
@@ -15,6 +15,7 @@
 #include "aidge/operator/Producer.hpp"
 #include <memory>
 #include <vector>
+#include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/utils/Types.h"
 
 Aidge::Node::Node(std::shared_ptr<Operator> op, const std::string& name)
@@ -111,18 +112,18 @@ std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::No
     return res;
 }
 
-void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) {
-    assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound.");
-    if (mParents[idx] != nullptr) {
-        mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]);
-        removeParent(idx);
-    }
-    std::shared_ptr<Node> newConstantNode = Producer(tensor);
-    newConstantNode->addChild(shared_from_this(), 0, idx);
-    for (auto& graphPtr : views()) {
-        graphPtr->add(newConstantNode);
-    }
-}
+// void Aidge::Node::setInput(const Aidge::IOIndex_t idx, const std::shared_ptr<Aidge::Tensor> tensor) {
+//     assert(((idx != gk_IODefaultIndex) && (idx < nbInputs())) && "Parent index out of bound.");
+//     if (mParents[idx] != nullptr) {
+//         mParents[idx]->removeChild(shared_from_this(), mIdOutParents[idx]);
+//         removeParent(idx);
+//     }
+//     std::shared_ptr<Node> newConstantNode = Producer(tensor);
+//     newConstantNode->addChild(shared_from_this(), 0, idx);
+//     for (auto& graphPtr : views()) {
+//         graphPtr->add(newConstantNode);
+//     }
+// }
 
 std::vector<std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>>>
 Aidge::Node::outputs() const {
-- 
GitLab