diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 404b0fd02038af1156269d6fcd0c973b24a28351..18d865b4132b30561444b666fd530c0904b0b041 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -344,17 +344,17 @@ public: bool replaceWith(std::set<NodePtr> newNodes); /** - * @brief Replace a set of Nodes in the current GraphView with a new set of Nodes if possible. + * @brief Replace a set of Nodes in every available GraphView with a new set of Nodes if possible. * Both sets should include all the necessary Producers. - * @details Replaced Nodes are only removed from the current GraphView. Other GraphView containing - * them will not be affected by the replacement. The oldNodes set should have only one input/output - * Node for automatic connections of newNodes set. + * @details Replaced Nodes are removed from any GraphView pointing at them all. + * The oldNodes set should have only one input/output + * Tensor for automatic connections of newNodes set. * @param oldNodes actual set of shared_ptr<Node> to replace. * @param newNodes new set of shared_ptr<Node>. * @return true * @return false */ - bool replace(std::set<NodePtr>& oldNodes, std::set<NodePtr>& newNodes); + static bool replace(const std::set<NodePtr>& oldNodes, const std::set<NodePtr>& newNodes); void updateInputNodes(); /** diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 9b048b126b4308009a512b8a921b0bca71172fc5..72a499d18907610f07d193f4fb31707a7e01946f 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -595,41 +595,38 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) { return replacable; } -bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidge::NodePtr>& newNodes) { - for (const auto& node : oldNodes) { - if (mNodes.find(node) == mNodes.end()) { - AIDGE_INTERNAL_ASSERT("GraphView asked to replace a Node it does not contain."); - } - } +bool Aidge::GraphView::replace(const std::set<Aidge::NodePtr>& oldNodes, const std::set<Aidge::NodePtr>& newNodes) { + // TODO: handle case where an oldNodes parameter does not come from a Producer but another Node (not included in oldNodes) // How to distinguish it from data input? // TODO: Parameter Tensors could be identified with their dimensions // TODO: Take GraphView as input parameters since new Nodes should be connected whatever. // It also avoids specifying each producer since they are automatically included - auto oldG = std::make_shared<GraphView>(); + auto oldG = std::make_shared<GraphView>("oldG"); oldG->add(oldNodes, false); - auto newG = std::make_shared<GraphView>(); + auto newG = std::make_shared<GraphView>("newG"); newG->add(newNodes, false); - if ((oldG->inputNodes().size() != 1) || (oldG->outputNodes().size() != 1)) { + if ((oldG->inputNodes().size() == 0) || (oldG->outputNodes().size() != 1)) { return false; } - if (!(newNodes.empty()) && ((newG->inputNodes().size() != 1) || + if (!(newNodes.empty()) && ((newG->inputNodes().size() == 0) || (newG->outputNodes().size() != 1))) { return false; } - std::shared_ptr<Node> previousInputNode = (*(oldG->inputNodes()).begin()); - std::shared_ptr<Node> previousOutputNode = (*(oldG->outputNodes()).begin()); + // there is at least one inputNode in the old/new GraphView + std::shared_ptr<Node> firstPreviousInputNode = (*(oldG->inputNodes()).begin()); + std::shared_ptr<Node> firstPreviousOutputNode = (*(oldG->outputNodes()).begin()); // find Node to link to new input Node - //compute number of input for previousInputNode not in oldNodes set + //compute number of input for firstPreviousInputNode not in oldNodes set std::size_t nbExternalInputs = 0; std::shared_ptr<Node> externalInput = nullptr; IOIndex_t externalInputId = gk_IODefaultIndex; - for (const auto& input : previousInputNode->inputs()) { - if (oldNodes.find(input.first) == oldNodes.end()) { + for (const auto& input : firstPreviousInputNode->inputs()) { + if (oldNodes.find(input.first) == oldNodes.end()) { // Node connected to another Node outside of oldG nbExternalInputs++; externalInput = input.first; externalInputId = input.second; @@ -638,14 +635,28 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg if (nbExternalInputs > 1) { AIDGE_INTERNAL_ASSERT("To many input to link for oldNodes set"); } - if (previousOutputNode->nbOutputs() != 1) { + + if (oldG->inputNodes().size() > 1){ + // one or no input has been identified. Checking every input points to the same source + for (const auto& previousInputNode : oldG->inputNodes()) { + for (const auto& input : previousInputNode->inputs()) { + if (oldNodes.find(input.first) == oldNodes.end()) { + if ( (externalInput != input.first) || (externalInputId != input.second) ) { + return false; // an inputNode points to an external Node different from the registered one + } + } + } + } + } + + if (firstPreviousOutputNode->nbOutputs() != 1) { return false; } // find Node to replicate output connections std::shared_ptr<Node> newOutputNode = newNodes.empty() ? externalInput : *(newG->outputNodes().begin()); - auto copyOutputs = previousOutputNode->outputs(); + auto copyOutputs = firstPreviousOutputNode->outputs(); // manage Views for newNodes // only keep common views to each node for the new set std::set<std::shared_ptr<GraphView>> commonGraphViews = (*oldNodes.begin())->views(); @@ -657,6 +668,8 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg std::inserter(intersection, intersection.begin())); commonGraphViews = intersection; } + commonGraphViews.erase(oldG); + commonGraphViews.erase(newG); // clean Nodes to replace // Do not include common Nodes to avoid cleaning Producers linked to newNodes @@ -667,7 +680,7 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg for (auto& nodePtr : nodesToClean) { nodePtr->resetConnections(true); } // copy output connections - for (IOIndex_t o = 0; o < previousOutputNode->nbOutputs(); ++o) { + for (IOIndex_t o = 0; o < firstPreviousOutputNode->nbOutputs(); ++o) { auto outputPairs = copyOutputs[o]; for (const auto& onePair : outputPairs) { newOutputNode->addChild(onePair.first, o, onePair.second); @@ -675,22 +688,32 @@ bool Aidge::GraphView::replace(std::set<Aidge::NodePtr>& oldNodes, std::set<Aidg } // copy input connections if (!newNodes.empty()) { - std::shared_ptr<Node> newInputNode = (*(newG->inputNodes()).begin()); - if (newInputNode->getNbFreeDataInputs() > 1) { - return false; + for (const auto& newInputNode : newG->inputNodes()) { + IOIndex_t inputId = 0; + for (const auto& input : newInputNode->inputs()) { + if (newNodes.find(input.first) == newNodes.end()) { + externalInput->addChild(newInputNode, externalInputId, inputId); + } + inputId++; + } } - // one non-connected input in newNodes set - externalInput->addChild(newInputNode, externalInputId, newInputNode->getFirstFreeDataInput()); } // insert new Nodes in the right GraphViews - for (auto& graphPtr : commonGraphViews) { + for (const auto& graphPtr : commonGraphViews) { graphPtr->add(newNodes, false); if (newNodes.empty()) { graphPtr->updateInputNodes(); graphPtr->updateOutputNodes(); } } + + for (const auto& node : oldNodes) { + node->removeView(oldG); + } + for (const auto& node : newNodes) { + node->removeView(newG); + } return true; } diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 4390dbe1100050d677b1e14f2d50af0b5624b475..387d34ad82efebc2855c51e07b4999dd7bdfd9e2 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -364,7 +364,7 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { std::set<std::shared_ptr<Node>> newNodes = std::set<std::shared_ptr<Node>>({myFC, newMatmulWeight, newAddBias}); // replace - g->replace(nodeToReplace, newNodes); + GraphView::replace(nodeToReplace, newNodes); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, myFC})); REQUIRE(((myFC->getParent(0) == other1) && (myFC->getParent(1) == newMatmulWeight) && (myFC->getParent(2) == newAddBias))); @@ -381,19 +381,42 @@ TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { g->add({r1, r2, r3, r4}); auto nodesToReplace = std::set<std::shared_ptr<Node>>({r2, r3}); auto newNodes = std::set<std::shared_ptr<Node>>({}); - g->replace(nodesToReplace, newNodes); + GraphView::replace(nodesToReplace, newNodes); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4})); REQUIRE((r1->output(0))[0].first == r4); } - // SECTION("replace for tiling") { - // std::shared_ptr<GraphView> g = std::make_shared<GraphView>(); - // auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input"); - // auto other1 = GenericOperator("Other", 1, 1, 1, "other1"); - // auto myConv = GenericOperator("Conv", 1, 1, 1, "myConv"); - // auto other2 = GenericOperator("Other", 1, 1, 1, "other2"); - // otherInput->addChild() - // } + + SECTION("replace for tiling") { + std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph"); + auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input"); + auto other1 = GenericOperator("Other", 1, 1, 1, "other1"); + auto myConv = GenericOperator("Conv", 1, 1, 1, "myConv"); + auto other2 = GenericOperator("Other", 1, 1, 1, "other2"); + otherInput->addChild(other1); + other1->addChild(myConv); + myConv->addChild(other2); + g->add({other1, myConv, other2}); + + // create tiled Conv + auto conv1 = GenericOperator("Conv", 1, 1, 1, "myConv1"); + auto conv2 = GenericOperator("Conv", 1, 1, 1, "myConv2"); + auto conv3 = GenericOperator("Conv", 1, 1, 1, "myConv3"); + auto conv4 = GenericOperator("Conv", 1, 1, 1, "myConv4"); + auto concat = GenericOperator("Concat", 4, 4, 1, "myConcat"); + conv1->addChild(concat); + conv2->addChild(concat); + conv3->addChild(concat); + conv4->addChild(concat); + + GraphView::replace({myConv}, {conv1, conv2, conv3, conv4, concat}); + + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, conv1, conv2, conv3, conv4, concat, other2})); + + GraphView::replace({conv1, conv2, conv3, conv4, concat}, {myConv}); + + REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2})); + } } TEST_CASE("[GraphView] clone") {