From 233fed88ecd9d2298f1ca6c69356f4a1d91d4443 Mon Sep 17 00:00:00 2001
From: Christophe Guillon <christophe.guillon@inria.fr>
Date: Fri, 13 Sep 2024 15:47:07 +0200
Subject: [PATCH] [GraphView]: Add getOrderedNodes() method for topological
 order

Add getOrderedNodes(reversed=false) method which produces
a topological order of the graph view (or of the reversed
graph view when reversed is true).
This order is also deterministic given the nodes set and
the ordered output nodes.
Cyclic graphs are handled by breaking back edges as declared
in the new setBackEdges() method of the Operator class.
The order is suitable for resolving constant propagation
of tensor types in one pass including for cyclic graphs
if back edges are declared correctly for operators.
It is also suitable for exporting to onnx graphs
which require a topological order of exported acyclic graphs.
---
 .../unit_tests/test_topological_order.py      |  67 +++++++++++
 include/aidge/graph/GraphView.hpp             |  18 +++
 include/aidge/graph/Node.hpp                  |  10 ++
 include/aidge/operator/Operator.hpp           |  19 +++-
 python_binding/graph/pybind_GraphView.cpp     |   4 +
 python_binding/operator/pybind_Operator.cpp   |   2 +
 src/graph/GraphView.cpp                       |  55 +++++++++
 src/operator/Memorize.cpp                     |   4 +-
 unit_tests/graph/Test_GraphView.cpp           | 104 ++++++++++++++++++
 9 files changed, 281 insertions(+), 2 deletions(-)
 create mode 100644 aidge_core/unit_tests/test_topological_order.py

diff --git a/aidge_core/unit_tests/test_topological_order.py b/aidge_core/unit_tests/test_topological_order.py
new file mode 100644
index 000000000..8e7f2e2d9
--- /dev/null
+++ b/aidge_core/unit_tests/test_topological_order.py
@@ -0,0 +1,67 @@
+"""
+Copyright (c) 2023 CEA-List
+
+This program and the accompanying materials are made available under the
+terms of the Eclipse Public License 2.0 which is available at
+http://www.eclipse.org/legal/epl-2.0.
+
+SPDX-License-Identifier: EPL-2.0
+"""
+
+import unittest
+import aidge_core
+
+class test_topological_order(unittest.TestCase):
+    """Test python binding for nodes ordering"""
+
+    def setUp(self):
+        pass
+
+    def tearDown(self):
+        pass
+
+    def test_generic_loop_order_0(self):
+        # Defines a Generic recurring loop header operator with
+        # inputs: (init, back) and outputs (loop, last)
+        # Note that one must specify the back edge as otherwise the
+        # generated order may not schedule the loop header before the add
+        loop0 = aidge_core.GenericOperator("Loop", 2, 0, 2, "Loop#0")
+        loop0.get_operator().set_back_edges({1})
+        assert not loop0.get_operator().is_back_edge(0)
+        assert loop0.get_operator().is_back_edge(1)
+        add0 = aidge_core.Add(2, "add0")
+
+        loop0.add_child(add0, 0, 1)
+        add0.add_child(loop0, 0, 1)
+        graph = aidge_core.GraphView()
+        graph.add(loop0)
+        graph.add(add0)
+
+        nodes = graph.get_ordered_nodes()
+        assert len(nodes) == 2
+        assert nodes == [loop0, add0]
+
+    def test_generic_loop_order_1(self):
+        # Defines a Generic recurring loop header operator with
+        # inputs: (back, init) and outputs (loop, last)
+        # Note that one must specify the back edge as otherwise the
+        # generated order may not schedule the loop header before the add
+        loop0 = aidge_core.GenericOperator("Loop", 2, 0, 2, "Loop#0")
+        loop0.get_operator().set_back_edges({0})
+        assert not loop0.get_operator().is_back_edge(1)
+        assert loop0.get_operator().is_back_edge(0)
+        add0 = aidge_core.Add(2, "add0")
+
+        loop0.add_child(add0, 0, 1)
+        add0.add_child(loop0, 0, 0)
+        graph = aidge_core.GraphView()
+        graph.add(loop0)
+        graph.add(add0)
+
+        nodes = graph.get_ordered_nodes()
+        assert len(nodes) == 2
+        assert nodes == [loop0, add0]
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 17bd3b1e9..efdb06c4a 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -150,6 +150,24 @@ public:
     void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
     void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);
 
+    /**
+     * @brief Get a topological node order for an acyclic walk of the graph
+     * Graph cycles are broken on operator back edges such that resolution on
+     * single level lattice can be done in a single pass as it is
+     * the case generally for static resolution of Tensor shapes/datatypes.
+     * When reversed is true, gets a topological order on the reversed graph
+     * which is equivalent to a post-dfs order of the graph.
+     * The returned order is deterministic given the graph node set and the
+     * graph ordered output nodes.
+     * The output nodes connectivity must cover all nodes of the graph,
+     * otherwise a runtime exception is thrown.
+     * The returned order is biased toward left-to-right child order both
+     * for topological and post-dfs order.
+     * @param reversed returns a topological order of the reversed graph
+     * @return the ordered list of nodes
+     */
+    std::vector<Aidge::NodePtr> getOrderedNodes(bool reversed = false) const;
+
     /**
      * @brief Get inputs of the current GraphView with their associated id.
      * The rank of the nodes are their rank in the vector.
diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp
index ecc47c745..32932fa6f 100644
--- a/include/aidge/graph/Node.hpp
+++ b/include/aidge/graph/Node.hpp
@@ -260,6 +260,16 @@ public:
     return getOperator()->inputCategory(idx);
   }
 
+  /**
+   * @brief Returns whether the given node parent index is a back edge
+   * A back edge is defined by the operator and node parent index
+   * correspond to operator input index.
+   * @return true if the operator defines it as a back edge
+   */
+  inline bool parentIsBackEdge(IOIndex_t idx) const {
+    return getOperator()->isBackEdge(idx);
+  }
+
   /**
    * @brief Number of inputs linked to a Parent's output.
    * @return IOIndex_t
diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index 93e9664e2..87aa4080e 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -57,6 +57,7 @@ private:
     const OperatorType mOperatorType;
     const std::vector<InputCategory> mInputsCategory;
     const IOIndex_t mNbOut;
+    std::set<IOIndex_t> mBackEdges;
 
 public:
     Operator() = delete;
@@ -73,7 +74,8 @@ public:
         std::enable_shared_from_this<Operator>(),
         mOperatorType(op.mOperatorType),
         mInputsCategory(op.mInputsCategory),
-        mNbOut(op.mNbOut)
+        mNbOut(op.mNbOut),
+        mBackEdges(op.mBackEdges)
     {
         mType = op.mType;
         mImpl = nullptr;
@@ -208,6 +210,21 @@ public:
     inline IOIndex_t nbInputs() const noexcept { return mInputsCategory.size(); };
     inline IOIndex_t nbOutputs() const noexcept { return mNbOut; };
 
+    /**
+     * @brief Set the back edge input indexes for recurring operators.
+     * Any recuring operators should specify it's back edges, otherwise
+     * the interpretation of the data flow graph may not be possible.
+     */
+    inline void setBackEdges(const std::set<IOIndex_t>& backEdges) { mBackEdges = backEdges; }
+
+    /**
+     * @brief Returns whether the given input index is a back edge.
+     * @return true if the input index is in the back edge set
+     */
+    inline bool isBackEdge(IOIndex_t inputIdx) const {
+        return mBackEdges.find(inputIdx) != mBackEdges.end();
+    }
+
     static const std::vector<std::string> getInputsName() {
         return {};
     }
diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp
index 293038381..c0ee183b0 100644
--- a/python_binding/graph/pybind_GraphView.cpp
+++ b/python_binding/graph/pybind_GraphView.cpp
@@ -130,6 +130,10 @@ void init_GraphView(py::module& m) {
           .def("__call__", &GraphView::operator(), py::arg("connectors"))
           .def("set_datatype", &GraphView::setDataType, py::arg("datatype"))
           .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0)
+          .def("get_ordered_nodes", &GraphView::getOrderedNodes, py::arg("reversed") = false,
+               R"mydelimiter(
+               Get ordered nodes for the graph view
+               )mydelimiter")
           //   .def("__getitem__", [](Tensor& b, size_t idx)-> py::object {
           //      // TODO : Should return error if backend not compatible with get
           //      if (idx >= b.size()) throw py::index_error();
diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp
index 81a62f4ed..6ffbdd007 100644
--- a/python_binding/operator/pybind_Operator.cpp
+++ b/python_binding/operator/pybind_Operator.cpp
@@ -63,6 +63,8 @@ void init_Operator(py::module& m){
     .def("get_hook", &Operator::getHook)
     .def("add_hook", &Operator::addHook)
     .def_property_readonly("attr", &Operator::attributes)
+    .def("set_back_edges", &Operator::setBackEdges, py::arg("input_indexes"))
+    .def("is_back_edge", &Operator::isBackEdge, py::arg("input_index"))
     ;
 }
 }
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index 5a3eb695d..ef322fe5b 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -685,6 +685,61 @@ std::pair<std::vector<Aidge::NodePtr>, size_t> Aidge::GraphView::getRankedNodes(
   return std::make_pair(rankedNodes, orderUnicityLimit);
 }
 
+std::vector<Aidge::NodePtr> Aidge::GraphView::getOrderedNodes(bool reversed) const {
+    // We compute the order from a post-dfs walk on the reverse graph starting from
+    // ordered output nodes.
+    // Also, we walk the graph upward left to right in order
+    // to get a topological left-right order when possible.
+    // For the case where reversed is true, we walk the graph upward right to left
+    // and reverse the final order to get a post-dfs left-right order when possible.
+    std::vector<std::pair<NodePtr,std::pair<size_t, std::vector<NodePtr>>>> stack;
+    std::vector<NodePtr> reversePostDfs;
+    std::set<NodePtr> visited;
+    std::vector<NodePtr> outNodes(mNodes.size());
+    auto reverse_if_dfs = [reversed](auto &parents) {
+        if (reversed) std::reverse(parents.begin(), parents.end());
+    };
+    for (const auto& output : mOutputNodes) {
+            outNodes.push_back(output.first);
+    }
+    reverse_if_dfs(outNodes);
+    stack.push_back(std::make_pair(nullptr, std::make_pair(0, std::move(outNodes))));
+    while (!stack.empty()) {
+        auto node = stack.back().first;
+        auto& parentIdx = stack.back().second.first;
+        auto& parents = stack.back().second.second;
+        if (parentIdx == parents.size()) {
+            stack.pop_back();
+            if (node) {
+                reversePostDfs.push_back(node);
+            }
+        } else {
+            auto backEdgeIdx = reversed ? parents.size() - 1 - parentIdx: parentIdx;
+            auto isBackEdge = node != nullptr ? node->parentIsBackEdge(backEdgeIdx): false;
+            auto parent = parents[parentIdx++];
+            if (parent != nullptr && inView(parent) &&
+                visited.find(parent) == visited.end()) {
+                if (isBackEdge) {
+                    stack[0].second.second.push_back(parent);
+                } else {
+                    visited.insert(parent);
+                    auto next_parents = parent->getParents();
+                    reverse_if_dfs(next_parents);
+                    stack.push_back(std::make_pair(parent, std::make_pair(0, std::move(next_parents))));
+                }
+            }
+        }
+    }
+
+    if (reversePostDfs.size() != mNodes.size()) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error,
+                             "Could not enumerate all nodes, set output nodes such that all graph nodes are connected.");
+    }
+
+    reverse_if_dfs(reversePostDfs);
+    return reversePostDfs;
+}
+
 std::map<Aidge::NodePtr, std::string> Aidge::GraphView::getRankedNodesName(const std::string& format, bool markNonUnicity) const {
   const auto rankedNodes = getRankedNodes();
   std::map<NodePtr, std::string> rankedNodesName;
diff --git a/src/operator/Memorize.cpp b/src/operator/Memorize.cpp
index c3d64d5bc..61239071a 100644
--- a/src/operator/Memorize.cpp
+++ b/src/operator/Memorize.cpp
@@ -82,6 +82,8 @@ Aidge::Memorize_Op::Memorize_Op(const std::uint32_t endStep)
                     attr<MemorizeAttr::ForwardStep>(0),
                     attr<MemorizeAttr::EndStep>(endStep)))
 {
+    // The input idx 0 is a back edge for Memorize where inputs are (back, init)
+    setBackEdges({0});
     mOutputs[1] = mOutputs[0];
 }
 
@@ -161,4 +163,4 @@ std::set<std::string> Aidge::Memorize_Op::getAvailableBackends() const {
 
 std::shared_ptr<Aidge::Node> Aidge::Memorize(const std::uint32_t endStep, const std::string& name) {
     return std::make_shared<Node>(std::make_shared<Memorize_Op>(endStep), name);
-}
\ No newline at end of file
+}
diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp
index 300712486..a08808ee5 100644
--- a/unit_tests/graph/Test_GraphView.cpp
+++ b/unit_tests/graph/Test_GraphView.cpp
@@ -23,6 +23,9 @@
 #include "aidge/data/Tensor.hpp"
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/graph/Testing.hpp"
+#include "aidge/operator/Add.hpp"
+#include "aidge/operator/Split.hpp"
+#include "aidge/operator/Memorize.hpp"
 #include "aidge/operator/Conv.hpp"
 #include "aidge/operator/ReLU.hpp"
 #include "aidge/graph/OpArgs.hpp"
@@ -440,6 +443,107 @@ TEST_CASE("[core/graph] GraphView(resetConnections)") {
     }
 }
 
+TEST_CASE("[core/graph] GraphView(getOrderedNodes)", "[GraphView][getOrderedNodes]") {
+    auto data1 = Producer({2}, "data1");
+    auto data2 = Producer({2}, "data2");
+    auto data3 = Producer({2}, "data3");
+    auto add1 = Add(2, "add1");
+    auto add2 = Add(2, "add2");
+    auto split1 = Split(2, 0, {1, 1}, "split1");
+    auto add3 = Add(3, "add3");
+    auto g = std::make_shared<GraphView>("TestGraph");
+    data1->addChild(add1);
+    data2->addChild(add1);
+    add1->addChild(add2);
+    data3->addChild(add2);
+    add1->addChild(add3);
+    add2->addChild(split1);
+    split1->addChild(add3);
+    g->add(data1);
+    g->add(data2);
+    g->add(data3);
+    g->add(add1);
+    g->add(add2);
+    g->add(split1);
+    g->add(add3);
+    REQUIRE(g->getNodes().size() == 7);
+
+    auto topo = g->getOrderedNodes();
+    SECTION("Topological order") {
+        REQUIRE(topo[0] == data1);
+        REQUIRE(topo[1] == data2);
+        REQUIRE(topo[2] == add1);
+        REQUIRE(topo[3] == data3);
+        REQUIRE(topo[4] == add2);
+        REQUIRE(topo[5] == split1);
+        REQUIRE(topo[6] == add3);
+    }
+
+    auto pdfs = g->getOrderedNodes(true);
+    SECTION("Post DFS order") {
+        REQUIRE(pdfs[0] == add3);
+        REQUIRE(pdfs[1] == split1);
+        REQUIRE(pdfs[2] == add2);
+        REQUIRE(pdfs[3] == add1);
+        REQUIRE(pdfs[4] == data1);
+        REQUIRE(pdfs[5] == data2);
+        REQUIRE(pdfs[6] == data3);
+    }
+
+    // Invert output order
+    g->setOrderedOutputs({{split1, 1}, {add3, 0}});
+    SECTION("Topological order output reversed") {
+        // As add3 depends upon split1, the order should not be changed
+        auto topo2 = g->getOrderedNodes();
+        REQUIRE(topo2 == topo);
+    }
+
+    SECTION("Post DFS order output reversed") {
+        // As add3 depends upon split1, the order should not be changed
+        auto pdfs2 = g->getOrderedNodes(true);
+        REQUIRE(pdfs2 == pdfs);
+    }
+}
+
+TEST_CASE("[core/graph] GraphView(getOrderedNodes) cyclic", "[GraphView][getOrderedNodes]") {
+    auto data1 = Producer({2}, "data1");
+    auto data2 = Producer({2}, "data2");
+    auto add1 = Add(2, "add1");
+    auto mem1 = Memorize(1, "mem1");
+    auto add2 = Add(2, "add2");
+    auto g = std::make_shared<GraphView>("TestGraph");
+    data1->addChild(add1);
+    data2->addChild(add1);
+    add1->addChild(mem1, 0, 1); // init
+    data1->addChild(add2);
+    mem1->addChild(add2);
+    add2->addChild(mem1); // back edge
+    g->add(data1);
+    g->add(data2);
+    g->add(add1);
+    g->add(mem1);
+    g->add(add2);
+    REQUIRE(g->getNodes().size() == 5);
+
+    auto topo = g->getOrderedNodes();
+    SECTION("Topological order") {
+        REQUIRE(topo[0] == data1);
+        REQUIRE(topo[1] == data2);
+        REQUIRE(topo[2] == add1);
+        REQUIRE(topo[3] == mem1);
+        REQUIRE(topo[4] == add2);
+    }
+
+    auto pdfs = g->getOrderedNodes(true);
+    SECTION("post DFS order") {
+        REQUIRE(pdfs[0] == add2);
+        REQUIRE(pdfs[1] == mem1);
+        REQUIRE(pdfs[2] == add1);
+        REQUIRE(pdfs[3] == data1);
+        REQUIRE(pdfs[4] == data2);
+    }
+}
+
 TEST_CASE("[core/graph] GraphView(forwardDims)", "[GraphView][forwardDims]") {
     auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
     auto conv1 = Conv(3, 32, {3, 3}, "conv1");
-- 
GitLab