diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 0e55db15abf80fc1ab7e416549cc9625a2785f7b..52729000ece50aa0c9872ccba1cb078714cc2d98 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -111,6 +111,8 @@ public: return mRootNode; } + void setRootNode(NodePtr node); + /////////////////////////////////////////////////////// // TENSOR MANAGEMENT /////////////////////////////////////////////////////// @@ -448,6 +450,7 @@ public: * @return true replacement has been performed * @return false no replacement has been performed */ + static bool replace(const std::shared_ptr<GraphView>& oldG, const std::shared_ptr<GraphView>& newG); static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes); /** diff --git a/include/aidge/recipies/Recipies.hpp b/include/aidge/recipies/Recipies.hpp index fb4bc22c69ec2b4e8dcc6178c9fcda0a85190f78..da917030e5341f15de575400490d95be83f123bf 100644 --- a/include/aidge/recipies/Recipies.hpp +++ b/include/aidge/recipies/Recipies.hpp @@ -114,6 +114,12 @@ std::set<std::shared_ptr<Node>> getConvHorizontalTiling(const std::shared_ptr<No */ void explicitCastMove(std::shared_ptr<GraphView> graphView); +/** + * Flatten the graph by replacing the meta operators by their micro graph. + * @param recursive If true, recursively replace meta operators until there is + * no more meta operator in the graph. +*/ +void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false); } // namespace Aidge diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 0be4104afcf68d4282637eec714ce4e4cfcd37ab..10d83c6e6ca401d9a9be98dcadb4f41404d1fc5a 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -193,6 +193,11 @@ void Aidge::GraphView::save(std::string path, bool verbose, bool showProducers) fmt::print(fp.get(), "\n"); } +void Aidge::GraphView::setRootNode(NodePtr node) { + AIDGE_ASSERT(mNodes.find(node) != mNodes.end(), "Root node is not in the GraphView!"); + mRootNode = node; +} + /////////////////////////////////////////////////////// // TENSOR MANAGEMENT /////////////////////////////////////////////////////// @@ -839,18 +844,24 @@ void Aidge::GraphView::insertParent(NodePtr childNode, } bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) { - // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) - // How to distinguish it from data input? - // TODO: Parameter Tensors could be identified with their dimensions - // TODO: Take GraphView as input parameters since new Nodes should be connected whatever. - // It also avoids specifying each producer since they are automatically included - // (1) create GraphViews from both sets of Nodes auto oldG = std::make_shared<GraphView>("oldG"); oldG->add(oldNodes, false); auto newG = std::make_shared<GraphView>("newG"); newG->add(newNodes, false); + return GraphView::replace(oldG, newG); +} + +bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldG, const std::shared_ptr<GraphView>& newG) { + // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) + // How to distinguish it from data input? + // TODO: Parameter Tensors could be identified with their dimensions + // TODO: Take GraphView as input parameters since new Nodes should be connected whatever. + // It also avoids specifying each producer since they are automatically included + const auto& oldNodes = oldG->getNodes(); + const auto& newNodes = newG->getNodes(); + const auto oldOI = oldG->getOrderedInputs(); const auto oldOO = oldG->getOrderedOutputs(); const auto newOI = newG->getOrderedInputs(); @@ -1198,6 +1209,10 @@ void Aidge::GraphView::updateInputsOutputsDelete(std::shared_ptr<Node> deletedNo } } } + + if (deletedNode == mRootNode) { + mRootNode = nullptr; + } } diff --git a/src/recipies/ExpandMetaOps.cpp b/src/recipies/ExpandMetaOps.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3a72fe04315b2e746ebeb75f2456973bbf142440 --- /dev/null +++ b/src/recipies/ExpandMetaOps.cpp @@ -0,0 +1,36 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> + +#include "aidge/recipies/Recipies.hpp" +#include "aidge/operator/MetaOperator.hpp" + +void Aidge::expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive) { + bool found = false; + const auto nodes = graph->getNodes(); + for (auto node : nodes) { + auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator()); + + if (metaOp != nullptr) { + // Replace meta op by its micro-graph + // graph will be updated accordingly in GraphView::replace() + auto g = std::make_shared<GraphView>(); + g->add(node, false); + GraphView::replace(g, metaOp->getMicroGraph()); + found = true; + } + } + + if (found && recursive) { + expandMetaOps(graph, true); + } +} diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp index 421dec584c6edbbf11f229741ed85c1605474c8c..8245d0044b6358866b06115c33a9bb7ef2862185 100644 --- a/unit_tests/operator/Test_MetaOperator.cpp +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -11,10 +11,12 @@ #include <catch2/catch_test_macros.hpp> +#include "aidge/operator/Pop.hpp" #include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/MetaOperatorDefs.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Testing.hpp" +#include "aidge/recipies/Recipies.hpp" #include <cstddef> using namespace Aidge; @@ -80,4 +82,49 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator]") { //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op)->getMicroGraphScheduler(); //microGraphScheduler->saveSchedulingDiagram("lstm_scheduling"); } + + SECTION("LSTM(expanded)") { + auto pop = Pop(); + auto myLSTM = LSTM(2, 3, 2, true, "ltsm"); + auto myGraph = Sequential({pop, myLSTM}); + auto op = std::static_pointer_cast<OperatorTensor>(myLSTM->getOperator()); + + REQUIRE(myLSTM->nbInputs() == 3 + 8 + 8); + REQUIRE(myLSTM->nbData() == 1); + REQUIRE(myLSTM->nbOutputs() == 2); + + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>( + Array3D<float, 2, 3, 2>{{{{1.0, 2.0}, {3.0, 4.0}, {5.0, 6.0}}, {{2.0, 3.0}, {4.0, 5.0}, {6.0, 7.0}}}}); + std::shared_ptr<Tensor> myInit = std::make_shared<Tensor>( + Array2D<float, 3, 3>{{{0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}, {0.0, 0.0, 0.0}}}); + std::shared_ptr<Tensor> myInitW = std::make_shared<Tensor>( + Array2D<float, 3, 2>{{{0.1, 0.1}, {0.1, 0.1}, {0.1, 0.1}}}); + std::shared_ptr<Tensor> myInitR = std::make_shared<Tensor>( + Array2D<float, 3, 3>{{{0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}, {0.1, 0.1, 0.1}}}); + + pop->getOperator()->associateInput(0, myInput); + op->associateInput(17, myInit); + op->associateInput(18, myInit); + + // Weights X + myLSTM->input(1).first->getOperator()->setOutput(0, myInitW); + myLSTM->input(2).first->getOperator()->setOutput(0, myInitW); + myLSTM->input(3).first->getOperator()->setOutput(0, myInitW); + myLSTM->input(4).first->getOperator()->setOutput(0, myInitW); + // Weights H + myLSTM->input(5).first->getOperator()->setOutput(0, myInitR); + myLSTM->input(6).first->getOperator()->setOutput(0, myInitR); + myLSTM->input(7).first->getOperator()->setOutput(0, myInitR); + myLSTM->input(8).first->getOperator()->setOutput(0, myInitR); + + auto g = getConnectedGraphView(myLSTM); + g->save("lstm_before_expand", true, true); + + expandMetaOps(g); + g->setRootNode(pop); + REQUIRE(g->getRootNode() == pop); + g->save("lstm_expanded", true, true); + + REQUIRE(g->getNodes().size() == 41); + } }