diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index f11136adaaa3d23fa9d3dc5749dd5d6771cbc42c..7a2b4bac008a82d0454a6dd057d8bf78c7605926 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -320,8 +320,20 @@ public: void link(std::string name1_inID, std::string name2_outID); - void insert(Node &newNode, Node &inNode, std::initializer_list<Node> outNodes, - IOIndex_t tensorIdx); + /** + * @brief Insert a node (newParentNode) as a parent of the passed node (childNode). + * + * @param childNode Node that gets a new parent. + * @param newParentNode Inserted Node. + * @param childInputTensorIdx Index of the input Tensor for the childNode linked to the inserted Node output. + * @param newParentInputTensorIdx Index of the input Tensor for the newParentNode linked to the former parent of childNode. + * @param newParentOutputTensorIdx Index of the output Tensor for the newParentNode linked to the childNode's input Tensor. + */ + void insertParent(NodePtr childNode, + NodePtr newParentNode, + IOIndex_t childInputTensorIdx, + IOIndex_t newParentInputTensorIdx, + IOIndex_t newParentOutputTensorIdx); /** * @brief Replace the current GraphView with the set of given Nodes if possible diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 1a377fb1ce49ccac5d065cf2e994a0f876713410..e50ea20e680e6ab874b14c14b23f77b27286f367 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -519,12 +519,24 @@ void Aidge::GraphView::link(std::string /*name1_inID*/, printf("Not implemented yet.\n"); } -void Aidge::GraphView::insert(Node & /*newNode*/, Node & /*inNode*/, - std::initializer_list<Node> /*outNodes*/, - IOIndex_t /*tensorIdx*/) { - printf("Not implemented yet.\n"); +void Aidge::GraphView::insertParent(NodePtr childNode, + NodePtr newParentNode, + IOIndex_t childInputTensorIdx, + IOIndex_t newParentInputTensorIdx, + IOIndex_t newParentOutputTensorIdx){ + NodePtr currentParentNode = childNode->getParent(childInputTensorIdx); + const IOIndex_t currentParentOutputTensorIdx = childNode->input(childInputTensorIdx).second; + // Remove child from current parent & current Parent from child + currentParentNode->removeChild(childNode, currentParentOutputTensorIdx); + + // Add child + currentParentNode->addChild(newParentNode,currentParentOutputTensorIdx, newParentInputTensorIdx); + newParentNode->addChild(childNode, newParentOutputTensorIdx, childInputTensorIdx); + + add(newParentNode); } + bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) { // TODO : only supports one input/output node for now assert(mNodes.size()>0 && "There must be at least one Node to replace"); diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index dc693193c6606c99b1628d23ad253015f8f8dbe6..319370ebad95869efd450eade58a2ecd36075090 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -330,4 +330,48 @@ TEST_CASE("[core/graph] GraphView(replaceWith)") { REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4})); REQUIRE((r1->output(0))[0].first == r4); } +} + +TEST_CASE("[core/graph] GraphView(insertParent)") { + auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); + auto conv1 = Conv(3, 32, {3, 3}, "conv1"); + auto conv2 = Conv(32, 64, {3, 3}, "conv2"); + auto conv3 = Conv(32, 64, {1, 1}, "conv3"); + auto g = std::make_shared<GraphView>("TestGraph"); + dataProvider->addChild(conv1, 0); + g->add(conv1); + g->addChild(conv2, conv1, 0); + g->addChild(conv3, conv1, 0); + g->save("graphForwardDims"); + g->forwardDims(); + + auto newConv = Conv(32, 32, {1, 1}, "newConv"); + + SECTION("Check insertParent conv2 then insertParent conv3") { + g->insertParent(conv2, newConv, 0, 0, 0); + + std::set<NodePtr> expectedConv1Children = {conv3, newConv}; + std::set<NodePtr> expectedNewConvChildren = {conv2}; + + REQUIRE(conv1->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0)); + REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE((newConv->getChildren()) == expectedNewConvChildren); + REQUIRE((conv1->getChildren()) == expectedConv1Children); + + g->insertParent(conv3, newConv, 0, 0, 0); + + std::set<NodePtr> expectedConv1Children2 = {newConv}; + std::set<NodePtr> expectedNewConvChildren2 = {conv2, conv3}; + + REQUIRE(conv1->getOperator()->getOutput(0) != conv3->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) == newConv->getOperator()->getInput(0)); + REQUIRE(conv1->getOperator()->getOutput(0) != conv2->getOperator()->getInput(0)); + REQUIRE(newConv->getOperator()->getOutput(0) == conv2->getOperator()->getInput(0)); + REQUIRE(newConv->getOperator()->getOutput(0) == conv3->getOperator()->getInput(0)); + REQUIRE((newConv->getChildren()) == expectedNewConvChildren2); + REQUIRE((conv1->getChildren()) == expectedConv1Children2); + + } } \ No newline at end of file