From dc5ae734a8f3fbc91a3eafbd5fc5e2391be4400d Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Thu, 26 Oct 2023 14:55:24 +0000
Subject: [PATCH] [Add] Introduce compile() member function to set the
 GraphView ready for forward in one line

---
 include/aidge/graph/GraphView.hpp   | 25 ++++++++++++++++++++-----
 src/graph/GraphView.cpp             | 15 ++++++++++++++-
 unit_tests/graph/Test_GraphView.cpp |  2 +-
 3 files changed, 35 insertions(+), 7 deletions(-)

diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 89ba14849..e87f6a3e8 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -162,6 +162,21 @@ public:
     std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs(
             std::string nodeName) const;
 
+    /**
+     * @brief Assert Datatype, Backend, data format and dimensions along the GraphView are coherent.
+     * If not, apply the required transformations.
+     * @details Sets the GraphView ready for computation in four steps:
+     * 1 - Assert input Tensors' datatype is compatible with each Operator's datatype.
+     * If not, a conversion Operator is inserted.
+     * 2 - Assert input Tensors' backend is compatible with each Operator's backend.
+     * If not, add a Transmitter Operator.
+     * 3 - Assert data format (NCHW, NHWC, ...) of each Operator's input Tensor is
+     * compatible with the selected kernel.
+     * If not, add a Transpose Operator.
+     * 4 - Propagate Tensor dimensions through the consecutive Operators.
+     */
+    void compile(const std::string& backend, const Aidge::DataType datatype);
+
     /**
      * @brief Compute dimensions of input/output Tensors for each Operator of the
      * GraphView object's Nodes.
@@ -322,17 +337,17 @@ public:
 
     /**
      * @brief Insert a node (newParentNode) as a parent of the passed node (childNode).
-     * 
+     *
      * @param childNode Node that gets a new parent.
      * @param newParentNode Inserted Node.
      * @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output.
      * @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode.
      * @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor.
      */
-    void insertParent(NodePtr childNode, 
-                        NodePtr newParentNode, 
-                        IOIndex_t childInputTensorIdx, 
-                        IOIndex_t newParentInputTensorIdx, 
+    void insertParent(NodePtr childNode,
+                        NodePtr newParentNode,
+                        IOIndex_t childInputTensorIdx,
+                        IOIndex_t newParentInputTensorIdx,
                         IOIndex_t newParentOutputTensorIdx);
 
     /**
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index 8f8f51c89..1ca54c9c1 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -164,6 +164,19 @@ Aidge::GraphView::inputs(std::string name) const {
   return mNodeRegistry.at(name)->inputs();
 }
 
+void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype) {
+    // Backend
+    // TODO: add Backend attribute to Operator
+    setBackend(backend);
+    // Data type
+    // TODO: manage Datatype attribute in OperatorImpl
+    setDatatype(datatype);
+    // Data Format
+    // TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary
+    // Forward dimensions
+    forwardDims();
+}
+
 void Aidge::GraphView::forwardDims() {
     // setInputs
     // Link every tensor to the right pointer
@@ -225,7 +238,7 @@ void Aidge::GraphView::setBackend(const std::string &backend) {
   }
 }
 
-void Aidge::GraphView::setDatatype(const DataType &datatype) {
+void Aidge::GraphView::setDatatype(const Aidge::DataType &datatype) {
   for (auto node : getNodes()) {
     node->getOperator()->setDatatype(datatype);
   }
diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp
index 9f0143646..0811f4abf 100644
--- a/unit_tests/graph/Test_GraphView.cpp
+++ b/unit_tests/graph/Test_GraphView.cpp
@@ -244,7 +244,7 @@ TEST_CASE("[core/graph] GraphView(resetConnections)") {
     }
 }
 
-TEST_CASE("Graph Forward dims", "[GraphView]") {
+TEST_CASE("[core/graph] GraphView(forwardDims)", "[GraphView][forwardDims]") {
     auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
     auto conv1 = Conv(3, 32, {3, 3}, "conv1");
     auto conv2 = Conv(32, 64, {3, 3}, "conv2");
-- 
GitLab