diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index d6fcb60c1815ee00643ca4bbc176a19d63aedefd..fbfeafce2a0aa3672449dcccce4316eb7dc1e85c 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -945,7 +945,8 @@ TEST_CASE("[core/graph] AIDGE_TEST_0108000: Replacing a set of nodes, same input myOld21->addChild(myOld22); myOld22->addChild(other22); graphTest->add({other11, myOld11, myOld12, other12, other21, myOld21, myOld22, other22}); - + // graphTest->setOrderedOutputs(std::vector<std::pair<other11, 0>, std::pair<other21>>) + // graphTest->setOrderedInputs(std::vector<std::pair<other1, 0>, std::pair<other21>>) // Create and link new graph auto myNew11 = GenericOperator("New", {InputCategory::Data}, 1, "new11"); auto myNew12 = GenericOperator("New", {InputCategory::Data}, 1, "new12"); @@ -979,7 +980,217 @@ TEST_CASE("[core/graph] AIDGE_TEST_0108000: Replacing a set of nodes, same input CHECK(graphTest->getNode("new22") == myNew22); } } +TEST_CASE("[core/graph] AIDGE_TEST_0108010: Replacing a set of nodes, one old input, same number of outputs", "[GraphView][replace]") +{ + SECTION("test case 2 one old input several (2) new inputs, same number of outputs (1)") { + removeFile("my_log.txt"); + Log::setFileName("my_log.txt"); + Log::setFileLevel(Log::Level::Warn); + std::vector<std::string> logs; + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("old_graph"); + std::shared_ptr<GraphView> graphNew = std::make_shared<GraphView>("new_graph"); + // Create old graph + auto otherInput = GenericOperator("Producer", {}, 1, "other_input"); + auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1"); + auto myOld1 = GenericOperator("Old", {InputCategory::Data}, 1, "old1"); + auto myOld2 = GenericOperator("Old", {InputCategory::Data}, 1, "old2"); + auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2"); + // Link old graph + otherInput->addChild(other1); + other1->addChild(myOld1); + myOld1->addChild(myOld2); + myOld2->addChild(other2); + graphTest->add({other1, myOld1, myOld2, other2}); + graphOld->add({myOld1, myOld2}); + + // Create and link new graph + auto myNew1 = GenericOperator("New", {InputCategory::Data}, 1, "new1"); + auto myNew2 = GenericOperator("New", {InputCategory::Data}, 1, "new2"); + auto myNew3 = GenericOperator("New", {InputCategory::Data, InputCategory::Data}, 1, "new3"); + myNew1->addChild(myNew3); + myNew2->addChild(myNew3, 0, 1); + graphNew->add({myNew, myNew2, myNew3}); + + // Replace + bool retValue = GraphView::replace(graphOld, graphNew); + + // Check outputs + CHECK(LoadTextFile("my_log.txt", logs)); + CHECK(logs.size() == 0); + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew1, myNew2, myNew3, other2})); + graphTest->save("myGraph",true,true); + CHECK(retValue); + // Check links + CHECK(myNew1->input(0).first == other1); + CHECK(myNew2->input(0).first == other1); + CHECK(myNew3->output(0).at(0).first == other2); + // Check graph Nodes + CHECK(graphTest->getNode("old1") == nullptr); + CHECK(graphTest->getNode("old2") == nullptr); + CHECK(graphTest->getNode("old2") == myNew1); + CHECK(graphTest->getNode("new1") == myNew1); + CHECK(graphTest->getNode("new2") == myNew2); + CHECK(graphTest->getNode("new3") == myNew3); + } + SECTION("test case 2 one old input, several (2) new inputs, same number of outputs (0)") { + removeFile("my_log.txt"); + Log::setFileName("my_log.txt"); + Log::setFileLevel(Log::Level::Warn); + std::vector<std::string> logs; + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("old_graph"); + std::shared_ptr<GraphView> graphNew = std::make_shared<GraphView>("new_graph"); + // Create old graph + auto otherInput = GenericOperator("Producer", {}, 1, "other_input"); + auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1"); + auto myOld1 = GenericOperator("Old", {InputCategory::Data}, 1, "old1"); + auto myOld2 = GenericOperator("Old", {InputCategory::Data}, 1, "old2"); + // Link old graph + otherInput->addChild(other1); + other1->addChild(myOld1); + myOld1->addChild(myOld2); + graphTest->add({other1, myOld1, myOld2}); + graphOld->add({myOld1, myOld2}); + + // Create and link new graph + auto myNew1 = GenericOperator("New", {InputCategory::Data}, 1, "new1"); + auto myNew2 = GenericOperator("New", {InputCategory::Data}, 1, "new2"); + auto myNew3 = GenericOperator("New", {InputCategory::Data, InputCategory::Data}, 1, "new3"); + myNew1->addChild(myNew3); + myNew2->addChild(myNew3, 0, 1); + graphNew->add({myNew, myNew2, myNew3}); + + // Replace + bool retValue = GraphView::replace(graphOld, graphNew); + + // Check outputs + CHECK(LoadTextFile("my_log.txt", logs)); + CHECK(logs.size() == 0); + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew1, myNew2, myNew3})); + graphTest->save("myGraph",true,true); + CHECK(retValue); + // Check links + CHECK(myNew1->input(0).first == other1); + CHECK(myNew2->input(0).first == other1); + // Check graph Nodes + CHECK(graphTest->getNode("old1") == nullptr); + CHECK(graphTest->getNode("old2") == nullptr); + CHECK(graphTest->getNode("new1") == myNew1); + CHECK(graphTest->getNode("new2") == myNew2); + CHECK(graphTest->getNode("new3") == myNew3); + } + SECTION("test case 2 one old input, several (2) new input, same number of outputs (4)") { + removeFile("my_log.txt"); + Log::setFileName("my_log.txt"); + Log::setFileLevel(Log::Level::Warn); + std::vector<std::string> logs; + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("old_graph"); + std::shared_ptr<GraphView> graphNew = std::make_shared<GraphView>("new_graph"); + // Create old graph + auto otherInput = GenericOperator("Producer", {}, 1, "other_input"); + auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1"); + auto myOld1 = GenericOperator("Old", {InputCategory::Data}, 1, "old1"); + auto myOld2 = GenericOperator("Old", {InputCategory::Data}, 2, "old2"); + auto other2 = GenericOperator("Other", {InputCategory::Data}, {InputCategory::Data}, 1, "other2"); + auto other3 = GenericOperator("Other", {InputCategory::Data}, 1, "other3"); + auto other4 = GenericOperator("Other", {InputCategory::Data}, 1, "other4"); + // Link old graph + otherInput->addChild(other1); + other1->addChild(myOld1); + myOld1->addChild(myOld2); + myOld1->addChild(other2); + myOld1->addChild(other3); + myOld2->addChild(other2, 0, 1); + myOld2->addChild(other4, 1); + graphTest->add({other1, myOld1, myOld2, other2}); + graphOld->add({myOld1, myOld2}); + std::vector<std::pair<NodePtr, IOIndex_t>> oldOutputs; + graphOld->setOrderedOutputs(std::vector<std::pair>> {myOld1, 0}, {myOld1, 0}, {myOld2, 0}, {myOld2, 1}) + // Create and link new graph + auto myNew1 = GenericOperator("New", {InputCategory::Data}, 1, "new1"); + auto myNew2 = GenericOperator("New", {InputCategory::Data}, 1, "new2"); + auto myNew3 = GenericOperator("New", {InputCategory::Data, InputCategory::Data}, 2, "new3"); + myNew1->addChild(myNew3); + myNew2->addChild(myNew3, 0, 1); + graphNew->add({myNew, myNew2, myNew3}); + graphNew->setOrderedOutputs(std::vector<std::pair>> {myNew1, 0}, {myNew2, 0}, {myNew3, 0}, {myNew3, 1}) + + // Replace + bool retValue = GraphView::replace(graphOld, graphNew); + + // Check outputs + CHECK(LoadTextFile("my_log.txt", logs)); + CHECK(logs.size() == 0); + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew1, myNew2, myNew3, other2})); + graphTest->save("myGraph",true,true); + CHECK(retValue); + // Check links + CHECK(myNew1->input(0).first == other1); + CHECK(myNew2->input(0).first == other1); + CHECK(myNew3->output(0).at(0).first == other2); + CHECK(myNew3->output(0).at(1).first == other3); + CHECK(myNew3->output(0).at(2).first == other2); + CHECK(myNew3->output(0).at(3).first == other4); + // Check graph Nodes + REQUIRE(false); + CHECK(graphTest->getNode("old1") == nullptr); + CHECK(graphTest->getNode("old2") == nullptr); + CHECK(graphTest->getNode("new1") == myNew1); + CHECK(graphTest->getNode("new2") == myNew2); + CHECK(graphTest->getNode("new3") == myNew3); + } + SECTION("test case 2 one old input no (0) new inputs, same number of outputs (1)") { + removeFile("my_log.txt"); + Log::setFileName("my_log.txt"); + Log::setFileLevel(Log::Level::Warn); + std::vector<std::string> logs; + std::shared_ptr<GraphView> graphTest = std::make_shared<GraphView>("test_graph"); + std::shared_ptr<GraphView> graphOld = std::make_shared<GraphView>("old_graph"); + std::shared_ptr<GraphView> graphNew = std::make_shared<GraphView>("new_graph"); + // Create old graph + auto otherInput = GenericOperator("Producer", {}, 1, "other_input"); + auto other1 = GenericOperator("Other", {InputCategory::Data}, 1, "other1"); + auto myOld1 = GenericOperator("Old", {InputCategory::Data}, 1, "old1"); + auto myOld2 = GenericOperator("Old", {InputCategory::Data}, 1, "old2"); + auto other2 = GenericOperator("Other", {InputCategory::Data}, 1, "other2"); + // Link old graph + otherInput->addChild(other1); + other1->addChild(myOld1); + myOld1->addChild(myOld2); + myOld2->addChild(other2); + graphTest->add({other1, myOld1, myOld2, other2}); + graphOld->add({myOld1, myOld2}); + // Create and link new graph + auto myNew1 = GenericOperator("Producer", {}, 1, "new1"); + auto myNew2 = GenericOperator("Producer", {}, 1, "new2"); + auto myNew3 = GenericOperator("New", {InputCategory::Data, InputCategory::Data}, 1, "new3"); + myNew1->addChild(myNew3); + myNew2->addChild(myNew3, 0, 1); + graphNew->add({myNew, myNew2, myNew3}); + + // Replace + bool retValue = GraphView::replace(graphOld, graphNew); + + // Check outputs + CHECK(LoadTextFile("my_log.txt", logs)); + CHECK(logs.size() == 0); + CHECK(graphTest->getNodes() == std::set<std::shared_ptr<Node>>({other1, myNew1, myNew2, myNew3, other2})); + graphTest->save("myGraph",true,true); + CHECK(retValue); + // Check links + CHECK(myNew3->output(0).at(0).first == other2); + // Check graph Nodes + CHECK(graphTest->getNode("old1") == nullptr); + CHECK(graphTest->getNode("old2") == nullptr); + CHECK(graphTest->getNode("new1") == myNew1); + CHECK(graphTest->getNode("new2") == myNew2); + CHECK(graphTest->getNode("new3") == myNew3); + } + +} TEST_CASE("[GraphView] clone") { auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto conv1 = Conv(3, 32, {3, 3}, "conv1");