From 080743a94fb11f6331751176d7c59aa1fcf38e9f Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Fri, 27 Oct 2023 16:06:40 +0000 Subject: [PATCH] [Upd] replace() instead of replaceWith() in GraphView --- include/aidge/graph/GraphView.hpp | 36 +++++++--- src/graph/GraphView.cpp | 101 ++++++++++++++++++++++++++++ unit_tests/graph/Test_GraphView.cpp | 64 ++++++++++++++++++ 3 files changed, 190 insertions(+), 11 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 89ba14849..404b0fd02 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -322,17 +322,17 @@ public: /** * @brief Insert a node (newParentNode) as a parent of the passed node (childNode). - * + * * @param childNode Node that gets a new parent. * @param newParentNode Inserted Node. * @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output. * @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode. * @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor. */ - void insertParent(NodePtr childNode, - NodePtr newParentNode, - IOIndex_t childInputTensorIdx, - IOIndex_t newParentInputTensorIdx, + void insertParent(NodePtr childNode, + NodePtr newParentNode, + IOIndex_t childInputTensorIdx, + IOIndex_t newParentInputTensorIdx, IOIndex_t newParentOutputTensorIdx); /** @@ -342,6 +342,20 @@ public: * @return false */ bool replaceWith(std::set<NodePtr> newNodes); + + /** + * @brief Replace a set of Nodes in the current GraphView with a new set of Nodes if possible. + * Both sets should include all the necessary Producers. + * @details Replaced Nodes are only removed from the current GraphView. Other GraphView containing + * them will not be affected by the replacement. The oldNodes set should have only one input/output + * Node for automatic connections of newNodes set. + * @param oldNodes actual set of shared_ptr<Node> to replace. + * @param newNodes new set of shared_ptr<Node>. + * @return true + * @return false + */ + bool replace(std::set<NodePtr>& oldNodes, std::set<NodePtr>& newNodes); + void updateInputNodes(); /** * @brief Process from zero the set of output Nodes. @@ -379,6 +393,12 @@ public: */ std::shared_ptr<GraphView> cloneCallback(NodePtr(*cloneNode)(NodePtr)) const; + /** + * @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object. + * @return IOIndex_t + */ + IOIndex_t getNbFreeDataInputs() const; + private: /////////////////////////////////////////////////////// // TENSOR MANAGEMENT @@ -390,12 +410,6 @@ private: */ IOIndex_t getNbDataInputs() const; - /** - * @brief Get the sum of the number of free dataInput connection for all inputNodes of the GraphView object. - * @return IOIndex_t - */ - IOIndex_t getNbFreeDataInputs() const; - /** * @brief Update the set of inputNodes with a new Node, checking if it can be * added and removing any Node not part of mInputNode anymore. diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 8f8f51c89..9b048b126 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -17,6 +17,7 @@ #include "aidge/utils/Types.h" #include "aidge/graph/GraphView.hpp" #include "aidge/data/Tensor.hpp" +#include "aidge/utils/ErrorHandling.hpp" /////////////////////////////////////////////////////// // FUNCTIONAL DESCRIPTION @@ -594,6 +595,106 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) { return replacable; } +bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidge::NodePtr>& newNodes) { + for (const auto& node : oldNodes) { + if (mNodes.find(node) == mNodes.end()) { + AIDGE_INTERNAL_ASSERT("GraphView asked to replace a Node it does not contain."); + } + } + // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) + // How to distinguish it from data input? + // TODO: Parameter Tensors could be identified with their dimensions + // TODO: Take GraphView as input parameters since new Nodes should be connected whatever. + // It also avoids specifying each producer since they are automatically included + + auto oldG = std::make_shared<GraphView>(); + oldG->add(oldNodes, false); + auto newG = std::make_shared<GraphView>(); + newG->add(newNodes, false); + + if ((oldG->inputNodes().size() != 1) || (oldG->outputNodes().size() != 1)) { + return false; + } + if (!(newNodes.empty()) && ((newG->inputNodes().size() != 1) || + (newG->outputNodes().size() != 1))) { + return false; + } + + std::shared_ptr<Node> previousInputNode = (*(oldG->inputNodes()).begin()); + std::shared_ptr<Node> previousOutputNode = (*(oldG->outputNodes()).begin()); + + // find Node to link to new input Node + //compute number of input for previousInputNode not in oldNodes set + std::size_t nbExternalInputs = 0; + std::shared_ptr<Node> externalInput = nullptr; + IOIndex_t externalInputId = gk_IODefaultIndex; + for (const auto& input : previousInputNode->inputs()) { + if (oldNodes.find(input.first) == oldNodes.end()) { + nbExternalInputs++; + externalInput = input.first; + externalInputId = input.second; + } + } + if (nbExternalInputs > 1) { + AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set"); + } + if (previousOutputNode->nbOutputs() != 1) { + return false; + } + + // find Node to replicate output connections + std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin()); + + auto copyOutputs = previousOutputNode->outputs(); + // manage Views for newNodes + // only keep common views to each node for the new set + std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views(); + for (const auto& nodePtr : oldNodes) { + const auto nodeView = nodePtr->views(); + std::set<std::shared_ptr<GraphView>> intersection; + std::set_intersection(commonGraphViews.begin(), commonGraphViews.end(), + nodeView.begin(), nodeView.end(), + std::inserter(intersection, intersection.begin())); + commonGraphViews = intersection; + } + + // clean Nodes to replace + // Do not include common Nodes to avoid cleaning Producers linked to newNodes + std::set<std::shared_ptr<Node>> nodesToClean; + std::set_difference(oldNodes.begin(), oldNodes.end(), + newNodes.begin(), newNodes.end(), + std::inserter(nodesToClean, nodesToClean.begin())); + for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); } + + // copy output connections + for (IOIndex_t o = 0; o < previousOutputNode->nbOutputs(); ++o) { + auto outputPairs = copyOutputs[o]; + for (const auto& onePair : outputPairs) { + newOutputNode->addChild(onePair.first, o, onePair.second); + } + } + // copy input connections + if (!newNodes.empty()) { + std::shared_ptr<Node> newInputNode = (*(newG->inputNodes()).begin()); + if (newInputNode->getNbFreeDataInputs() > 1) { + return false; + } + // one non-connected input in newNodes set + externalInput->addChild(newInputNode, externalInputId, newInputNode->getFirstFreeDataInput()); + } + + // insert new Nodes in the right GraphViews + for (auto& graphPtr : commonGraphViews) { + graphPtr->add(newNodes, false); + if (newNodes.empty()) { + graphPtr->updateInputNodes(); + graphPtr->updateOutputNodes(); + } + } + return true; +} + + void Aidge::GraphView::updateInputNodes() { mInputNodes.clear(); for (const std::shared_ptr<Node>& go_ptr : mNodes) { diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 9f0143646..4390dbe11 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -332,6 +332,70 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") { } } +TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { + SECTION("replace small pattern") { + // create original graph + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); + auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input"); + auto matmulWeight = GenericOperator("Producer", 0, 0, 1, "matmul_w"); + auto addBias = GenericOperator("Producer", 0, 0, 1, "add_b"); + auto other1 = GenericOperator("Other", 1, 1, 1, "other1"); + auto other2 = GenericOperator("Other", 1, 1, 1, "other2"); + auto matmul = GenericOperator("MatMul", 1, 2, 1, "matmul"); + auto add = GenericOperator("Add", 1, 2, 1, "add"); + otherInput->addChild(other1); + other1->addChild(matmul); + matmul->addChild(add); + add->addChild(other2); + matmulWeight->addChild(matmul, 0, 1); + addBias->addChild(add, 0, 1); + g->add({other1, matmul, add, other2}); + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add})); + + // create graph to replace + std::set<std::shared_ptr<Node>> nodeToReplace = std::set<std::shared_ptr<Node>>({matmulWeight, addBias, matmul, add}); + + // create replacing graph + std::shared_ptr<Node> myFC = GenericOperator("FC", 1, 3, 1, "fc"); + auto newMatmulWeight = matmulWeight->cloneSharedOperators(); + newMatmulWeight->addChild(myFC, 0, 1); + auto newAddBias = addBias->cloneSharedOperators(); + newAddBias->addChild(myFC, 0, 2); + std::set<std::shared_ptr<Node>> newNodes = std::set<std::shared_ptr<Node>>({myFC, newMatmulWeight, newAddBias}); + + // replace + g->replace(nodeToReplace, newNodes); + + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, myFC})); + REQUIRE(((myFC->getParent(0) == other1) && (myFC->getParent(1) == newMatmulWeight) && (myFC->getParent(2) == newAddBias))); + } + SECTION("replace with nothing") { + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); + auto r1 = GenericOperator("relu", 0, 0, 1); + auto r2 = GenericOperator("relu", 1, 1, 1); + auto r3 = GenericOperator("relu", 1, 1, 1); + auto r4 = GenericOperator("relu", 1, 1, 0); + r1->addChild(r2); + r2->addChild(r3); + r3->addChild(r4); + g->add({r1, r2, r3, r4}); + auto nodesToReplace = std::set<std::shared_ptr<Node>>({r2, r3}); + auto newNodes = std::set<std::shared_ptr<Node>>({}); + g->replace(nodesToReplace, newNodes); + + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4})); + REQUIRE((r1->output(0))[0].first == r4); + } + // SECTION("replace for tiling") { + // std::shared_ptr<GraphView> g = std::make_shared<GraphView>(); + // auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input"); + // auto other1 = GenericOperator("Other", 1, 1, 1, "other1"); + // auto myConv = GenericOperator("Conv", 1, 1, 1, "myConv"); + // auto other2 = GenericOperator("Other", 1, 1, 1, "other2"); + // otherInput->addChild() + // } +} + TEST_CASE("[GraphView] clone") { auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto conv1 = Conv(3, 32, {3, 3}, "conv1"); -- GitLab