diff --git a/aidge_core/unit_tests/test_topological_order.py b/aidge_core/unit_tests/test_topological_order.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7f2e2d9b9770c2fae1e5c2812ba33113589134 --- /dev/null +++ b/aidge_core/unit_tests/test_topological_order.py @@ -0,0 +1,67 @@ +""" +Copyright (c) 2023 CEA-List + +This program and the accompanying materials are made available under the +terms of the Eclipse Public License 2.0 which is available at +http://www.eclipse.org/legal/epl-2.0. + +SPDX-License-Identifier: EPL-2.0 +""" + +import unittest +import aidge_core + +class test_topological_order(unittest.TestCase): + """Test python binding for nodes ordering""" + + def setUp(self): + pass + + def tearDown(self): + pass + + def test_generic_loop_order_0(self): + # Defines a Generic recurring loop header operator with + # inputs: (init, back) and outputs (loop, last) + # Note that one must specify the back edge as otherwise the + # generated order may not schedule the loop header before the add + loop0 = aidge_core.GenericOperator("Loop", 2, 0, 2, "Loop#0") + loop0.get_operator().set_back_edges({1}) + assert not loop0.get_operator().is_back_edge(0) + assert loop0.get_operator().is_back_edge(1) + add0 = aidge_core.Add(2, "add0") + + loop0.add_child(add0, 0, 1) + add0.add_child(loop0, 0, 1) + graph = aidge_core.GraphView() + graph.add(loop0) + graph.add(add0) + + nodes = graph.get_ordered_nodes() + assert len(nodes) == 2 + assert nodes == [loop0, add0] + + def test_generic_loop_order_1(self): + # Defines a Generic recurring loop header operator with + # inputs: (back, init) and outputs (loop, last) + # Note that one must specify the back edge as otherwise the + # generated order may not schedule the loop header before the add + loop0 = aidge_core.GenericOperator("Loop", 2, 0, 2, "Loop#0") + loop0.get_operator().set_back_edges({0}) + assert not loop0.get_operator().is_back_edge(1) + assert loop0.get_operator().is_back_edge(0) + add0 = aidge_core.Add(2, "add0") + + loop0.add_child(add0, 0, 1) + add0.add_child(loop0, 0, 0) + graph = aidge_core.GraphView() + graph.add(loop0) + graph.add(add0) + + nodes = graph.get_ordered_nodes() + assert len(nodes) == 2 + assert nodes == [loop0, add0] + + +if __name__ == '__main__': + unittest.main() diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 17bd3b1e9aeece2c80dab8c1aa1cba6498cc730f..efdb06c4ac6d0e6898d899cc639a88d1da301000 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -150,6 +150,24 @@ public: void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs); void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs); + /** + * @brief Get a topological node order for an acyclic walk of the graph + * Graph cycles are broken on operator back edges such that resolution on + * single level lattice can be done in a single pass as it is + * the case generally for static resolution of Tensor shapes/datatypes. + * When reversed is true, gets a topological order on the reversed graph + * which is equivalent to a post-dfs order of the graph. + * The returned order is deterministic given the graph node set and the + * graph ordered output nodes. + * The output nodes connectivity must cover all nodes of the graph, + * otherwise a runtime exception is thrown. + * The returned order is biased toward left-to-right child order both + * for topological and post-dfs order. + * @param reversed returns a topological order of the reversed graph + * @return the ordered list of nodes + */ + std::vector<Aidge::NodePtr> getOrderedNodes(bool reversed = false) const; + /** * @brief Get inputs of the current GraphView with their associated id. * The rank of the nodes are their rank in the vector. diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index ecc47c74578a6ec8bba6c47c07df3f2be6d43078..32932fa6f598737644f74d4e2ce5da89557b5d3d 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -260,6 +260,16 @@ public: return getOperator()->inputCategory(idx); } + /** + * @brief Returns whether the given node parent index is a back edge + * A back edge is defined by the operator and node parent index + * correspond to operator input index. + * @return true if the operator defines it as a back edge + */ + inline bool parentIsBackEdge(IOIndex_t idx) const { + return getOperator()->isBackEdge(idx); + } + /** * @brief Number of inputs linked to a Parent's output. * @return IOIndex_t diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 93e9664e266db6a14947170d960d52f198dcdce0..87aa4080e57d14d0d8a738afed2e976521b42048 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -57,6 +57,7 @@ private: const OperatorType mOperatorType; const std::vector<InputCategory> mInputsCategory; const IOIndex_t mNbOut; + std::set<IOIndex_t> mBackEdges; public: Operator() = delete; @@ -73,7 +74,8 @@ public: std::enable_shared_from_this<Operator>(), mOperatorType(op.mOperatorType), mInputsCategory(op.mInputsCategory), - mNbOut(op.mNbOut) + mNbOut(op.mNbOut), + mBackEdges(op.mBackEdges) { mType = op.mType; mImpl = nullptr; @@ -208,6 +210,21 @@ public: inline IOIndex_t nbInputs() const noexcept { return mInputsCategory.size(); }; inline IOIndex_t nbOutputs() const noexcept { return mNbOut; }; + /** + * @brief Set the back edge input indexes for recurring operators. + * Any recuring operators should specify it's back edges, otherwise + * the interpretation of the data flow graph may not be possible. + */ + inline void setBackEdges(const std::set<IOIndex_t>& backEdges) { mBackEdges = backEdges; } + + /** + * @brief Returns whether the given input index is a back edge. + * @return true if the input index is in the back edge set + */ + inline bool isBackEdge(IOIndex_t inputIdx) const { + return mBackEdges.find(inputIdx) != mBackEdges.end(); + } + static const std::vector<std::string> getInputsName() { return {}; } diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 2930383817d1555d51b8bddd8eff6402240e905a..c0ee183b072398e2e393bdbd7446de0155519169 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -130,6 +130,10 @@ void init_GraphView(py::module& m) { .def("__call__", &GraphView::operator(), py::arg("connectors")) .def("set_datatype", &GraphView::setDataType, py::arg("datatype")) .def("set_backend", &GraphView::setBackend, py::arg("backend"), py::arg("device") = 0) + .def("get_ordered_nodes", &GraphView::getOrderedNodes, py::arg("reversed") = false, + R"mydelimiter( + Get ordered nodes for the graph view + )mydelimiter") // .def("__getitem__", [](Tensor& b, size_t idx)-> py::object { // // TODO : Should return error if backend not compatible with get // if (idx >= b.size()) throw py::index_error(); diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index 81a62f4ed0eb12844453581f68165a282fff9817..6ffbdd007b9f929ccac18de12f2319dcd68b1eda 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -63,6 +63,8 @@ void init_Operator(py::module& m){ .def("get_hook", &Operator::getHook) .def("add_hook", &Operator::addHook) .def_property_readonly("attr", &Operator::attributes) + .def("set_back_edges", &Operator::setBackEdges, py::arg("input_indexes")) + .def("is_back_edge", &Operator::isBackEdge, py::arg("input_index")) ; } } diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 5a3eb695d7288c6414c01a82b36638f8b93d6b5f..ef322fe5b795b9cb9c62c3593abdd330fd471575 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -685,6 +685,61 @@ std::pair<std::vector<Aidge::NodePtr>, size_t> Aidge::GraphView::getRankedNodes( return std::make_pair(rankedNodes, orderUnicityLimit); } +std::vector<Aidge::NodePtr> Aidge::GraphView::getOrderedNodes(bool reversed) const { + // We compute the order from a post-dfs walk on the reverse graph starting from + // ordered output nodes. + // Also, we walk the graph upward left to right in order + // to get a topological left-right order when possible. + // For the case where reversed is true, we walk the graph upward right to left + // and reverse the final order to get a post-dfs left-right order when possible. + std::vector<std::pair<NodePtr,std::pair<size_t, std::vector<NodePtr>>>> stack; + std::vector<NodePtr> reversePostDfs; + std::set<NodePtr> visited; + std::vector<NodePtr> outNodes(mNodes.size()); + auto reverse_if_dfs = [reversed](auto &parents) { + if (reversed) std::reverse(parents.begin(), parents.end()); + }; + for (const auto& output : mOutputNodes) { + outNodes.push_back(output.first); + } + reverse_if_dfs(outNodes); + stack.push_back(std::make_pair(nullptr, std::make_pair(0, std::move(outNodes)))); + while (!stack.empty()) { + auto node = stack.back().first; + auto& parentIdx = stack.back().second.first; + auto& parents = stack.back().second.second; + if (parentIdx == parents.size()) { + stack.pop_back(); + if (node) { + reversePostDfs.push_back(node); + } + } else { + auto backEdgeIdx = reversed ? parents.size() - 1 - parentIdx: parentIdx; + auto isBackEdge = node != nullptr ? node->parentIsBackEdge(backEdgeIdx): false; + auto parent = parents[parentIdx++]; + if (parent != nullptr && inView(parent) && + visited.find(parent) == visited.end()) { + if (isBackEdge) { + stack[0].second.second.push_back(parent); + } else { + visited.insert(parent); + auto next_parents = parent->getParents(); + reverse_if_dfs(next_parents); + stack.push_back(std::make_pair(parent, std::make_pair(0, std::move(next_parents)))); + } + } + } + } + + if (reversePostDfs.size() != mNodes.size()) { + AIDGE_THROW_OR_ABORT(std::runtime_error, + "Could not enumerate all nodes, set output nodes such that all graph nodes are connected."); + } + + reverse_if_dfs(reversePostDfs); + return reversePostDfs; +} + std::map<Aidge::NodePtr, std::string> Aidge::GraphView::getRankedNodesName(const std::string& format, bool markNonUnicity) const { const auto rankedNodes = getRankedNodes(); std::map<NodePtr, std::string> rankedNodesName; diff --git a/src/operator/Memorize.cpp b/src/operator/Memorize.cpp index c3d64d5bc66bc00e1ed67fc6158f656c75fb2b82..61239071a99a9dfca8613ef78eba17757c4276b7 100644 --- a/src/operator/Memorize.cpp +++ b/src/operator/Memorize.cpp @@ -82,6 +82,8 @@ Aidge::Memorize_Op::Memorize_Op(const std::uint32_t endStep) attr<MemorizeAttr::ForwardStep>(0), attr<MemorizeAttr::EndStep>(endStep))) { + // The input idx 0 is a back edge for Memorize where inputs are (back, init) + setBackEdges({0}); mOutputs[1] = mOutputs[0]; } @@ -161,4 +163,4 @@ std::set<std::string> Aidge::Memorize_Op::getAvailableBackends() const { std::shared_ptr<Aidge::Node> Aidge::Memorize(const std::uint32_t endStep, const std::string& name) { return std::make_shared<Node>(std::make_shared<Memorize_Op>(endStep), name); -} \ No newline at end of file +} diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 30071248699dbc2dd697d1d1f09c47ebcb217967..a08808ee5e6c2657a76213dcff80cec53b23e7ee 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -23,6 +23,9 @@ #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Testing.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/Split.hpp" +#include "aidge/operator/Memorize.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/ReLU.hpp" #include "aidge/graph/OpArgs.hpp" @@ -440,6 +443,107 @@ TEST_CASE("[core/graph] GraphView(resetConnections)") { } } +TEST_CASE("[core/graph] GraphView(getOrderedNodes)", "[GraphView][getOrderedNodes]") { + auto data1 = Producer({2}, "data1"); + auto data2 = Producer({2}, "data2"); + auto data3 = Producer({2}, "data3"); + auto add1 = Add(2, "add1"); + auto add2 = Add(2, "add2"); + auto split1 = Split(2, 0, {1, 1}, "split1"); + auto add3 = Add(3, "add3"); + auto g = std::make_shared<GraphView>("TestGraph"); + data1->addChild(add1); + data2->addChild(add1); + add1->addChild(add2); + data3->addChild(add2); + add1->addChild(add3); + add2->addChild(split1); + split1->addChild(add3); + g->add(data1); + g->add(data2); + g->add(data3); + g->add(add1); + g->add(add2); + g->add(split1); + g->add(add3); + REQUIRE(g->getNodes().size() == 7); + + auto topo = g->getOrderedNodes(); + SECTION("Topological order") { + REQUIRE(topo[0] == data1); + REQUIRE(topo[1] == data2); + REQUIRE(topo[2] == add1); + REQUIRE(topo[3] == data3); + REQUIRE(topo[4] == add2); + REQUIRE(topo[5] == split1); + REQUIRE(topo[6] == add3); + } + + auto pdfs = g->getOrderedNodes(true); + SECTION("Post DFS order") { + REQUIRE(pdfs[0] == add3); + REQUIRE(pdfs[1] == split1); + REQUIRE(pdfs[2] == add2); + REQUIRE(pdfs[3] == add1); + REQUIRE(pdfs[4] == data1); + REQUIRE(pdfs[5] == data2); + REQUIRE(pdfs[6] == data3); + } + + // Invert output order + g->setOrderedOutputs({{split1, 1}, {add3, 0}}); + SECTION("Topological order output reversed") { + // As add3 depends upon split1, the order should not be changed + auto topo2 = g->getOrderedNodes(); + REQUIRE(topo2 == topo); + } + + SECTION("Post DFS order output reversed") { + // As add3 depends upon split1, the order should not be changed + auto pdfs2 = g->getOrderedNodes(true); + REQUIRE(pdfs2 == pdfs); + } +} + +TEST_CASE("[core/graph] GraphView(getOrderedNodes) cyclic", "[GraphView][getOrderedNodes]") { + auto data1 = Producer({2}, "data1"); + auto data2 = Producer({2}, "data2"); + auto add1 = Add(2, "add1"); + auto mem1 = Memorize(1, "mem1"); + auto add2 = Add(2, "add2"); + auto g = std::make_shared<GraphView>("TestGraph"); + data1->addChild(add1); + data2->addChild(add1); + add1->addChild(mem1, 0, 1); // init + data1->addChild(add2); + mem1->addChild(add2); + add2->addChild(mem1); // back edge + g->add(data1); + g->add(data2); + g->add(add1); + g->add(mem1); + g->add(add2); + REQUIRE(g->getNodes().size() == 5); + + auto topo = g->getOrderedNodes(); + SECTION("Topological order") { + REQUIRE(topo[0] == data1); + REQUIRE(topo[1] == data2); + REQUIRE(topo[2] == add1); + REQUIRE(topo[3] == mem1); + REQUIRE(topo[4] == add2); + } + + auto pdfs = g->getOrderedNodes(true); + SECTION("post DFS order") { + REQUIRE(pdfs[0] == add2); + REQUIRE(pdfs[1] == mem1); + REQUIRE(pdfs[2] == add1); + REQUIRE(pdfs[3] == data1); + REQUIRE(pdfs[4] == data2); + } +} + TEST_CASE("[core/graph] GraphView(forwardDims)", "[GraphView][forwardDims]") { auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto conv1 = Conv(3, 32, {3, 3}, "conv1");