From cea8cbcd21a40736fb82559bf548efbcb3b38c93 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Wed, 3 Apr 2024 11:33:57 +0000
Subject: [PATCH] update 'GraphView::compile()' member function

---
 include/aidge/graph/GraphView.hpp         |  5 ++++-
 python_binding/graph/pybind_GraphView.cpp |  2 +-
 src/graph/GraphView.cpp                   | 12 ++++++------
 3 files changed, 11 insertions(+), 8 deletions(-)

diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 06f73c97f..845599fd3 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -201,7 +201,10 @@ public:
      * If not, add a Transpose Operator.
      * 4 - Propagate Tensor dimensions through the consecutive Operators.
      */
-    void compile(const std::string& backend = "cpu", const Aidge::DataType datatype = DataType::Float32, DeviceIdx_t device = 0);
+    void compile(const std::string& backend = "cpu",
+                 const Aidge::DataType datatype = DataType::Float32,
+                 DeviceIdx_t device = 0,
+                 const std::vector<std::vector<DimSize_t>> dims = {});
 
     /**
      * @brief Compute dimensions of input/output Tensors for each Operator of the
diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp
index f06a70f32..953ec981e 100644
--- a/python_binding/graph/pybind_GraphView.cpp
+++ b/python_binding/graph/pybind_GraphView.cpp
@@ -118,7 +118,7 @@ void init_GraphView(py::module& m) {
           .def("get_nodes", &GraphView::getNodes)
           .def("get_node", &GraphView::getNode, py::arg("node_name"))
           .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>())
-          .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0)
+          .def("compile", &GraphView::compile, py::arg("backend"), py::arg("datatype"), py::arg("device") = 0, py::arg("dims")=std::vector<std::vector<DimSize_t>>())
           .def("__call__", &GraphView::operator(), py::arg("connectors"))
           .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
           .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index f498d5e82..dcd7a06ef 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -378,7 +378,7 @@ Aidge::GraphView::inputs(const std::string& name) const {
   return mNodeRegistry.at(name)->inputs();
 }
 
-void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device) {
+void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype, DeviceIdx_t device, const std::vector<std::vector<DimSize_t>> dims) {
     // Backend
     // TODO: add Backend attribute to Operator
     setBackend(backend, device);
@@ -388,7 +388,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType
     // Data Format
     // TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary
     // Forward dimensions
-    forwardDims();
+    forwardDims(dims);
 }
 
 void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) {
@@ -913,14 +913,14 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
     // keep in memory every node related to the node to replace :
     // Parent
     for (std::size_t i = 0; i < oldOIn.size(); ++i) {
-        std::pair<NodePtr, IOIndex_t> inputParent = 
+        std::pair<NodePtr, IOIndex_t> inputParent =
                   oldOIn[i].first -> input(oldOIn[i].second);
         inputParents[i]= inputParent;
         // inputParent.first -> addChild(newOI[i].first, inputParent.second, newOI[i].second);
     }
     // Children
     for (std::size_t i = 0; i < oldOOut.size();) {
-        std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild = 
+        std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> outputChild =
               oldOOut[i].first -> output(oldOOut[i].second);
         if (outputChild.empty()) {
             outputChildren[i] = std::pair<std::shared_ptr<Node>, IOIndex_t>({nullptr, gk_IODefaultIndex});
@@ -983,7 +983,7 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
                 for (std::size_t i = 0; i < oldOIn.size(); ++i) {
                     if (inputParents[i].first) {
                       inputParents[i].first -> addChild(outputChildren[i].first, inputParents[i].second, outputChildren[i].second);
-                    }  
+                    }
                 }
             }
             else if ((oldOIn.size() == 1) && (inputParents[0].first)) {
@@ -1259,7 +1259,7 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo
   if (deletedNode == mRootNode) {
     const std::pair<std::vector<NodePtr>, size_t> ranked_nodes = getRankedNodes();
     if(ranked_nodes.second== 0 || ranked_nodes.first.size() <= 1)
-    {      
+    {
       mRootNode = nullptr;
     } else {
       // The new root node will be the second node in the order of ranked nodes
-- 
GitLab