/******************************************************************************** * Copyright (c) 2023 CEA-List * * This program and the accompanying materials are made available under the * terms of the Eclipse Public License 2.0 which is available at * http://www.eclipse.org/legal/epl-2.0. * * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ #include <algorithm> // std::sort #include <cassert> #include <map> #include <memory> #include <set> #include <string> #include <catch2/catch_test_macros.hpp> #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Testing.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/Producer.hpp" using namespace Aidge; TEST_CASE("genRandomGraph", "[GraphView][randomGen]") { const size_t nbTests = 100; size_t nbUnicity = 0; for (int test = 0; test < nbTests; ++test) { std::random_device rd; const std::mt19937::result_type seed(rd()); RandomGraph randGraph; const auto g1 = std::make_shared<GraphView>("g1"); const bool unicity1 = g1->add(randGraph.gen(seed, 10)); const auto g2 = std::make_shared<GraphView>("g2"); const bool unicity2 = g2->add(randGraph.gen(seed, 10)); // g1->save("./genRandomGraph1"); // g2->save("./genRandomGraph2"); REQUIRE(unicity1 == unicity2); if (unicity1) { REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); REQUIRE(nodePtrTo(g1->inputNodes(), nodePtrToName) == nodePtrTo(g2->inputNodes(), nodePtrToName)); REQUIRE(nodePtrTo(g1->outputNodes(), nodePtrToName) == nodePtrTo(g2->outputNodes(), nodePtrToName)); ++nbUnicity; // Check that inputs/outputs are 
the same regardless of the order auto orderedInputs1 = nodePtrTo(g1->getOrderedInputs(), nodePtrToName); auto orderedInputs2 = nodePtrTo(g2->getOrderedInputs(), nodePtrToName); auto orderedOutputs1 = nodePtrTo(g1->getOrderedOutputs(), nodePtrToName); auto orderedOutputs2 = nodePtrTo(g2->getOrderedOutputs(), nodePtrToName); std::sort(orderedInputs1.begin(), orderedInputs1.end()); std::sort(orderedInputs2.begin(), orderedInputs2.end()); std::sort(orderedOutputs1.begin(), orderedOutputs1.end()); std::sort(orderedOutputs2.begin(), orderedOutputs2.end()); REQUIRE(orderedInputs1 == orderedInputs2); REQUIRE(orderedOutputs1 == orderedOutputs2); REQUIRE(nodePtrTo(g1->inputNodes(), nodePtrToName) == nodePtrTo(g2->inputNodes(), nodePtrToName)); REQUIRE(nodePtrTo(g1->outputNodes(), nodePtrToName) == nodePtrTo(g2->outputNodes(), nodePtrToName)); } } printf("nbUnicity = %zu/%zu\n", nbUnicity, nbTests); } TEST_CASE("clone", "[GraphView][clone]") { const size_t nbTests = 100; for (int test = 0; test < nbTests; ++test) { std::random_device rd; const std::mt19937::result_type seed(rd()); RandomGraph randGraph; const auto g1 = std::make_shared<GraphView>("g1"); g1->add(randGraph.gen(seed, 10)); // g1 -> save("GraphView_clone"); const auto g2 = g1->clone(); REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); } } NodePtr nodeDel(NodePtr node) { if (node->type() == "DelFictive") { return nullptr; } return node->clone(); } TEST_CASE("clone_with_delete", "[GraphView][cloneDelete]") { const size_t nbTests = 100; size_t nbClonedWithDelete = 0; // Note: initial seed is chosen such that for nbTests=100, the generated // graphs keep the same inputs/outputs despites the deleted nodes // (meaning the deleted nodes are not 
input/output of the graph). // Otherwise, the last two REQUIRE are not garanteed to be true! // Warning: distributions are not required to behave the same way by the standard, // therefore the seed has to work for both GCC and MSVC... // See https://stackoverflow.com/questions/38532927/why-gcc-and-msvc-stdnormal-distribution-are-different std::mt19937::result_type seed(243); for (int test = 0; test < nbTests; ++test) { RandomGraph randGraph; randGraph.types = {"Fictive", "DelFictive"}; randGraph.typesWeights = {0.9, 0.1}; const auto g1 = std::make_shared<GraphView>("g1"); const bool unicity1 = g1->add(randGraph.gen(seed, 10)); if (unicity1) { randGraph.omitType = "DelFictive"; const auto g2 = std::make_shared<GraphView>("g2"); const bool unicity2 = g2->add(randGraph.gen(seed, 10)); // g1->save("./clone_with_delete1"); // g2->save("./clone_with_delete2"); try { const auto gCloned = g1->cloneCallback(&nodeDel); REQUIRE(nodePtrTo(gCloned->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); REQUIRE(nodePtrTo(gCloned->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); REQUIRE(nodePtrTo(gCloned->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); ++nbClonedWithDelete; } catch (const std::runtime_error& error) { // pass } } ++seed; } printf("nbClonedWithDelete = %zu/%zu\n", nbClonedWithDelete, nbTests); } TEST_CASE("remove", "[GraphView][remove]") { const size_t nbTests = 100; size_t nbTested = 0; for (int test = 0; test < nbTests; ++test) { std::random_device rd; const std::mt19937::result_type seed(rd()); RandomGraph randGraph; randGraph.types = {"Fictive", "DelFictive"}; randGraph.typesWeights = {0.8, 0.2}; const auto g1 = std::make_shared<GraphView>("g1"); const bool unicity1 = g1->add(randGraph.gen(seed, 10)); if (unicity1) { // g1->save("./remove1_before"); const auto nodes = g1->getNodes(); int step = 1; for (auto node : nodes) { if (node->type() == "DelFictive") { 
g1->remove(node, false); // g1->save("./remove1_after" + std::to_string(step)); step++; } } randGraph.omitType = "DelFictive"; const auto g2 = std::make_shared<GraphView>("g2"); g2->add(randGraph.gen(seed, 10)); // g1->save("./remove1"); // g2->save("./remove2"); REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName)); // Order not garanteed, because when a node is removed, it can create new GraphView inputs/outputs // Their order thus depends on the deletion order! //REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName)); //REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName)); // Check that inputs/outputs are the same regardless of the order auto orderedInputs1 = nodePtrTo(g1->getOrderedInputs(), nodePtrToName); auto orderedInputs2 = nodePtrTo(g2->getOrderedInputs(), nodePtrToName); auto orderedOutputs1 = nodePtrTo(g1->getOrderedOutputs(), nodePtrToName); auto orderedOutputs2 = nodePtrTo(g2->getOrderedOutputs(), nodePtrToName); std::sort(orderedInputs1.begin(), orderedInputs1.end()); std::sort(orderedInputs2.begin(), orderedInputs2.end()); std::sort(orderedOutputs1.begin(), orderedOutputs1.end()); std::sort(orderedOutputs2.begin(), orderedOutputs2.end()); REQUIRE(orderedInputs1 == orderedInputs2); REQUIRE(orderedOutputs1 == orderedOutputs2); ++nbTested; } } printf("nbTested = %zu/%zu\n", nbTested, nbTests); } TEST_CASE("[core/graph] GraphView(Constructor)", "[GraphView][constructor()]") { std::shared_ptr<GraphView> g0 = std::make_shared<GraphView>(); std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("G1"); REQUIRE(g0 != nullptr); REQUIRE(g1 != nullptr); } TEST_CASE("[core/graph] GraphView(add)", "[GraphView][add]") { SECTION("Node alone") { std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 0, 0, "Gop1"); g->add(GOp1); 
std::shared_ptr<Node> GOp2 = GenericOperator("Fictive", 0, 0, 1, "Gop2"); g->add(GOp2); std::shared_ptr<Node> GOp3 = GenericOperator("Fictive", 1, 0, 0, "Gop3"); g->add(GOp3); std::shared_ptr<Node> GOp4 = GenericOperator("Fictive", 0, 1, 0, "Gop4"); g->add(GOp4); std::shared_ptr<Node> GOp5 = GenericOperator("Fictive", 1, 0, 1, "Gop5"); g->add(GOp5); std::shared_ptr<Node> GOp6 = GenericOperator("Fictive", 1, 1, 1, "Gop6"); g->add(GOp6); // g->save("node_alone"); REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop3", 0}, {"Gop4", 0}, {"Gop5", 0}, {"Gop6", 0}, {"Gop6", 1}})); REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop2", 0}, {"Gop5", 0}, {"Gop6", 0}})); } SECTION("Several Nodes") { std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); // should automaticaly add parents for learnable parameters std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 1, 1, "Gop1"); std::shared_ptr<Node> GOp1parent = GenericOperator("Fictive", 0, 0, 1, "Gop1parent"); GOp1parent->addChild(GOp1, 0, 0); g->add(GOp1); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent})); REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({})); REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop1", 0}})); // there should be no deplicates g->add(GOp1); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent})); REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({})); REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop1", 0}})); } SECTION("Initializer list ofr Node") { std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> GOp1 = 
GenericOperator("Fictive", 0, 0, 0, "Gop1"); std::shared_ptr<Node> GOp2 = GenericOperator("Fictive", 0, 0, 0, "Gop2"); g->add({GOp1, GOp1, GOp2}); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp2})); } SECTION("another GraphView") { std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph-1"); std::shared_ptr<GraphView> g2 = std::make_shared<GraphView>("TestGraph-2"); auto conv = GenericOperator("Conv", 1, 0, 1, "c"); auto conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); auto conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); auto conv3 = GenericOperator("Conv", 1, 0, 1, "c3"); auto conv4 = GenericOperator("Conv", 1, 0, 1, "c4"); conv->addChild(conv1); conv1->addChild(conv2); conv2->addChild(conv3); conv3->addChild(conv4); g1->add({conv, conv1, conv2, conv3, conv4}); g2->add(g1); REQUIRE(((g1->getNodes() == g2->getNodes()) && (g2->getNodes() == std::set<std::shared_ptr<Node>>({conv, conv1, conv2, conv3, conv4})))); REQUIRE(((g1->inputNodes() == g2->inputNodes()) && (g2->inputNodes() == std::set<std::shared_ptr<Node>>({conv})))); REQUIRE(((g1->outputNodes() == g2->outputNodes()) && (g2->outputNodes() == std::set<std::shared_ptr<Node>>({conv4})))); } } TEST_CASE("[core/graph] GraphView(addChild)") { std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 0, 1, "c3"); std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 0, 1, "c3.5"); std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 0, 1, "c4"); std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 0, 1, "c5"); g1->add(conv); SECTION("add(node)") { REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv}); } 
g1->addChild(conv1, "c"); SECTION("add(node, outputNodeName)") { REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv1}); REQUIRE(conv->getChildren() == std::set<std::shared_ptr<Node>>({conv1})); REQUIRE(conv1->getParents() == std::vector<std::shared_ptr<Node>>({conv})); } g1->addChild(conv2, "c1", 0); SECTION("add(node, pair<outputNodeName, outID>)") { REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv2}); REQUIRE(conv1->getChildren() == std::set<std::shared_ptr<Node>>({conv2})); REQUIRE(conv2->getParents() == std::vector<std::shared_ptr<Node>>({conv1})); } g1->addChild(conv3, "c2", 0, 0); SECTION("add(node, list(outputNodeName))") { REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv3}); REQUIRE(conv2->getChildren() == std::set<std::shared_ptr<Node>>({conv3})); REQUIRE(conv3->getParents() == std::vector<std::shared_ptr<Node>>({conv2})); } g1->addChild(conv3_5, conv3); SECTION("add(node, list(outputNodeName))") { REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv3_5}); REQUIRE(conv3->getChildren() == std::set<std::shared_ptr<Node>>({conv3_5})); REQUIRE(conv3_5->getParents() == std::vector<std::shared_ptr<Node>>({conv3})); } g1->addChild(conv4, conv3_5, 0); SECTION("add(node, vector<pair<outputNodeName, outID>>)") { REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv}); REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv4}); REQUIRE(conv3_5->getChildren() == std::set<std::shared_ptr<Node>>({conv4})); REQUIRE(conv4->getParents() == std::vector<std::shared_ptr<Node>>({conv3_5})); } g1->addChild(conv5, conv4, 0, 0); SECTION("add(node, vector<pair<outputNodeName, outID>>)") { REQUIRE(g1->inputNodes() == 
std::set<std::shared_ptr<Node>>{conv}); REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv5}); REQUIRE(conv4->getChildren() == std::set<std::shared_ptr<Node>>({conv5})); REQUIRE(conv5->getParents() == std::vector<std::shared_ptr<Node>>({conv4})); } std::set<std::shared_ptr<Node>> requiredNodes = {conv, conv1, conv2, conv3, conv3_5, conv4, conv5}; REQUIRE(g1->getNodes() == requiredNodes); REQUIRE(g1->getChildren(conv3) == std::set<std::shared_ptr<Node>>({conv3_5})); } TEST_CASE("[core/graph] GraphView(inputs)") { auto g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = Conv(3, 32, {3, 3}); g1->add(conv, false); REQUIRE(g1->inputs() == conv->inputs()); } TEST_CASE("[core/graph] GraphView(outputs)") { std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = Conv(3, 32, {3, 3}); g1->add(conv); REQUIRE(g1->outputs() == conv->outputs()); } TEST_CASE("[core/graph] GraphView(save)") { std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1"); std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 0, 1, "c3"); std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 0, 1, "c4"); std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 0, 1, "c5"); g1->add(conv); g1->addChild(conv1, "c"); g1->addChild(conv2, "c1", 0); g1->addChild(conv3, "c2"); g1->addChild(conv4, "c3", 0); g1->addChild(conv5, "c4", 0, 0); g1->save("./graphExample"); printf("File saved in ./graphExample.md\n"); } TEST_CASE("[core/graph] GraphView(resetConnections)") { SECTION("disconnect data iput") { std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 2, 1, "c1"); std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, 
"c2"); std::shared_ptr<Node> prod1 = GenericOperator("Prod", 0, 0, 1, "p1"); std::shared_ptr<Node> prod2 = GenericOperator("Prod", 0, 0, 1, "p2"); conv->addChild(conv1); prod1->addChild(conv1,0,1); prod2->addChild(conv1,0,2); conv1->addChild(conv2); conv1->resetConnections(false); REQUIRE(conv->output(0).size() == 0); for (std::size_t i = 0; i < conv1->nbData(); ++i) { REQUIRE((conv1->input(i) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); } REQUIRE((conv1->input(1) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod1, 0))); REQUIRE((conv1->input(2) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod2, 0))); REQUIRE((conv2->input(0) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); for (std::size_t i = 0; i < conv1->nbOutputs(); ++i) { REQUIRE(conv->output(i).size() == 0U); } } SECTION("disconnect data input + learnable parameters") { std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c"); std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 2, 1, "c1"); std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2"); std::shared_ptr<Node> prod1 = GenericOperator("Prod", 0, 0, 1, "p1"); std::shared_ptr<Node> prod2 = GenericOperator("Prod", 0, 0, 1, "p2"); conv->addChild(conv1); prod1->addChild(conv1,0,1); prod2->addChild(conv1,0,2); conv1->addChild(conv2); conv1->resetConnections(true); REQUIRE(conv->output(0).size() == 0); for (std::size_t i = 0; i < conv1->nbInputs(); ++i) { REQUIRE((conv1->input(i) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); } REQUIRE((conv2->input(0) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex))); for (std::size_t i = 0; i < conv1->nbOutputs(); ++i) { REQUIRE(conv->output(i).size() == 0U); } } } TEST_CASE("[core/graph] GraphView(forwardDims)", "[GraphView][forwardDims]") { auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto conv1 = Conv(3, 32, {3, 3}, "conv1"); auto conv2 = Conv(32, 
64, {3, 3}, "conv2"); auto conv3 = Conv(64, 10, {1, 1}, "conv3"); auto g = std::make_shared<GraphView>("TestGraph"); dataProvider->addChild(conv1, 0); g->add(conv1); g->addChild(conv2, conv1, 0); g->addChild(conv3, conv2, 0); g->save("graphForwardDims"); g->forwardDims(); SECTION("Check input-output connections") { REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0)); REQUIRE(conv1->getOperator()->getRawInput(1) == g->getNode("conv1_w")->getOperator()->getRawOutput(0)); REQUIRE(conv1->getOperator()->getRawInput(2) == g->getNode("conv1_b")->getOperator()->getRawOutput(0)); REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0)); REQUIRE(conv2->getOperator()->getRawInput(1) == g->getNode("conv2_w")->getOperator()->getRawOutput(0)); REQUIRE(conv2->getOperator()->getRawInput(2) == g->getNode("conv2_b")->getOperator()->getRawOutput(0)); REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0)); REQUIRE(conv3->getOperator()->getRawInput(1) == g->getNode("conv3_w")->getOperator()->getRawOutput(0)); REQUIRE(conv3->getOperator()->getRawInput(2) == g->getNode("conv3_b")->getOperator()->getRawOutput(0)); } SECTION("Check forwarded dims") { REQUIRE(std::static_pointer_cast<Tensor>(conv1->getOperator()->getRawOutput(0)) ->dims() == std::vector<DimSize_t>({16, 32, 222, 222})); REQUIRE(std::static_pointer_cast<Tensor>(conv2->getOperator()->getRawOutput(0)) ->dims() == std::vector<DimSize_t>({16, 64, 220, 220})); } } TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") { SECTION("replace small pattern") { // create original graph std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input"); auto matmulWeight = GenericOperator("Producer", 0, 0, 1, "matmul_w"); auto addBias = GenericOperator("Producer", 0, 0, 1, "add_b"); auto other1 = GenericOperator("Other", 1, 0, 1, "other1"); auto 
other2 = GenericOperator("Other", 1, 0, 1, "other2"); auto matmul = GenericOperator("MatMul", 1, 1, 1, "matmul"); auto add = GenericOperator("Add", 1, 1, 1, "add"); otherInput->addChild(other1); other1->addChild(matmul); matmul->addChild(add); add->addChild(other2); matmulWeight->addChild(matmul, 0, 1); addBias->addChild(add, 0, 1); g->add({other1, matmul, add, other2}); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add})); // create graph to replace std::set<std::shared_ptr<Node>> nodeToReplace = std::set<std::shared_ptr<Node>>({matmulWeight, addBias, matmul, add}); // create replacing graph std::shared_ptr<Node> myFC = GenericOperator("FC", 1, 2, 1, "fc"); auto newMatmulWeight = matmulWeight->cloneSharedOperators(); newMatmulWeight->addChild(myFC, 0, 1); auto newAddBias = addBias->cloneSharedOperators(); newAddBias->addChild(myFC, 0, 2); std::set<std::shared_ptr<Node>> newNodes = std::set<std::shared_ptr<Node>>({myFC, newMatmulWeight, newAddBias}); // replace GraphView::replace(nodeToReplace, newNodes); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, myFC})); REQUIRE(((myFC->getParent(0) == other1) && (myFC->getParent(1) == newMatmulWeight) && (myFC->getParent(2) == newAddBias))); } SECTION("replace with nothing") { std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph"); auto r1 = GenericOperator("relu", 0, 0, 1); auto r2 = GenericOperator("relu", 1, 0, 1); auto r3 = GenericOperator("relu", 1, 0, 1); auto r4 = GenericOperator("relu", 1, 0, 0); r1->addChild(r2); r2->addChild(r3); r3->addChild(r4); g->add({r1, r2, r3, r4}); auto nodesToReplace = std::set<std::shared_ptr<Node>>({r2, r3}); auto newNodes = std::set<std::shared_ptr<Node>>({}); GraphView::replace(nodesToReplace, newNodes); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4})); REQUIRE((r1->output(0))[0].first == r4); } SECTION("replace for tiling") { 
std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph"); auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input"); auto other1 = GenericOperator("Other", 1, 0, 1, "other1"); auto myConv = GenericOperator("Conv", 1, 0, 1, "myConv"); auto other2 = GenericOperator("Other", 1, 0, 1, "other2"); otherInput->addChild(other1); other1->addChild(myConv); myConv->addChild(other2); g->add({other1, myConv, other2}); // create tiled Conv auto conv1 = GenericOperator("Conv", 1, 0, 1, "myConv1"); auto conv2 = GenericOperator("Conv", 1, 0, 1, "myConv2"); auto conv3 = GenericOperator("Conv", 1, 0, 1, "myConv3"); auto conv4 = GenericOperator("Conv", 1, 0, 1, "myConv4"); auto concat = GenericOperator("Concat", 4, 0, 1, "myConcat"); conv1->addChild(concat); conv2->addChild(concat); conv3->addChild(concat); conv4->addChild(concat); GraphView::replace({myConv}, {conv1, conv2, conv3, conv4, concat}); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, conv1, conv2, conv3, conv4, concat, other2})); GraphView::replace({conv1, conv2, conv3, conv4, concat}, {myConv}); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2})); } SECTION("Change every Nodes in a GraphView") { auto matmulWeight0 = GenericOperator("Producer", 0, 0, 1, "matmul_w0"); auto addBias0 = GenericOperator("Producer", 0, 0, 1, "add_b0"); auto matmul0 = GenericOperator("MatMul", 1, 1, 1, "matmul0"); auto add0 = GenericOperator("Add", 1, 1, 1, "add0"); auto matmulWeight1 = GenericOperator("Producer", 0, 0, 1, "matmul_w1"); auto addBias1 = GenericOperator("Producer", 0, 0, 1, "add_b1"); auto matmul1 = GenericOperator("MatMul", 1, 1, 1, "matmul1"); auto add1 = GenericOperator("Add", 1, 1, 1, "add1"); matmulWeight0 -> addChild(matmul0, 0, 1); addBias0 -> addChild(add0, 0, 1); matmulWeight1 -> addChild(matmul1, 0, 1); addBias1 -> addChild(add1, 0, 1); matmul0 -> addChild(add0, 0, 0); add0 -> addChild(matmul1, 0, 0); matmul1 -> addChild(add1, 0, 0); auto g = 
std::make_shared<GraphView>("TestGraph"); g -> add({matmulWeight0, addBias0, matmulWeight1, addBias1, matmul0, add0, matmul1, add1}); auto newMatmulWeight0 = matmulWeight0->cloneSharedOperators(); auto newAddBias0 = addBias0->cloneSharedOperators(); auto newMatmulWeight1 = matmulWeight1->cloneSharedOperators(); auto newAddBias1 = addBias1->cloneSharedOperators(); auto fc0 = GenericOperator("FC", 1, 2, 1, "fc0"); auto fc1 = GenericOperator("FC", 1, 2, 1, "fc1"); newMatmulWeight0 -> addChild(fc0, 0, 1); newAddBias0 -> addChild(fc0, 0, 2); newMatmulWeight1 -> addChild(fc1, 0, 1); newAddBias1 -> addChild(fc1, 0, 2); GraphView::replace({matmul0, add0, matmulWeight0, addBias0}, {newMatmulWeight0, newAddBias0, fc0}); GraphView::replace({matmul1, add1, matmulWeight1, addBias1}, {newMatmulWeight1, newAddBias1, fc1}); REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight0, newAddBias0, newAddBias1, newMatmulWeight1, fc1, fc0})); } } TEST_CASE("[GraphView] clone") { auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto conv1 = Conv(3, 32, {3, 3}, "conv1"); auto conv2 = Conv(32, 64, {3, 3}, "conv2"); auto conv3 = Conv(64, 10, {1, 1}, "conv3"); auto g1 = std::make_shared<GraphView>("TestGraph"); dataProvider->addChild(conv1, 0); g1->add(conv1); g1->addChild(conv2, conv1, 0); g1->addChild(conv3, conv2, 0); g1->save("clone_g1"); SECTION("Check input-output connections") { REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0)); REQUIRE(conv1->getOperator()->getRawInput(1) == g1->getNode("conv1_w")->getOperator()->getRawOutput(0)); REQUIRE(conv1->getOperator()->getRawInput(2) == g1->getNode("conv1_b")->getOperator()->getRawOutput(0)); REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0)); REQUIRE(conv2->getOperator()->getRawInput(1) == g1->getNode("conv2_w")->getOperator()->getRawOutput(0)); REQUIRE(conv2->getOperator()->getRawInput(2) == 
g1->getNode("conv2_b")->getOperator()->getRawOutput(0)); REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0)); REQUIRE(conv3->getOperator()->getRawInput(1) == g1->getNode("conv3_w")->getOperator()->getRawOutput(0)); REQUIRE(conv3->getOperator()->getRawInput(2) == g1->getNode("conv3_b")->getOperator()->getRawOutput(0)); } auto g2 = g1->clone(); auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider"); dataProvider2->addChild(g2->getNode("conv1"), 0); g2->forwardDims(); g2->save("clone_g2"); SECTION("Check node cloning") { REQUIRE(g1->getNode("conv1") != g2->getNode("conv1")); REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w")); REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b")); REQUIRE(g1->getNode("conv2") != g2->getNode("conv2")); REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w")); REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b")); REQUIRE(g1->getNode("conv3") != g2->getNode("conv3")); REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w")); REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b")); } SECTION("Check operator cloning") { REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator()); REQUIRE(g1->getNode("conv1_w")->getOperator() != g2->getNode("conv1_w")->getOperator()); REQUIRE(g1->getNode("conv1_b")->getOperator() != g2->getNode("conv1_b")->getOperator()); REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator()); REQUIRE(g1->getNode("conv2_w")->getOperator() != g2->getNode("conv2_w")->getOperator()); REQUIRE(g1->getNode("conv2_b")->getOperator() != g2->getNode("conv2_b")->getOperator()); REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator()); REQUIRE(g1->getNode("conv3_w")->getOperator() != g2->getNode("conv3_w")->getOperator()); REQUIRE(g1->getNode("conv3_b")->getOperator() != g2->getNode("conv3_b")->getOperator()); } SECTION("Check new connections") { 
REQUIRE(dataProvider->getOperator()->getRawOutput(0) != g2->getNode("conv1")->getOperator()->getRawInput(0)); REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(1) != g2->getNode("conv1_w")->getOperator()->getRawOutput(0)); REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(2) != g2->getNode("conv1_b")->getOperator()->getRawOutput(0)); REQUIRE(g1->getNode("conv1")->getOperator()->getRawOutput(0) != g2->getNode("conv2")->getOperator()->getRawInput(0)); REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(1) != g2->getNode("conv2_w")->getOperator()->getRawOutput(0)); REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(2) != g2->getNode("conv2_b")->getOperator()->getRawOutput(0)); REQUIRE(g1->getNode("conv2")->getOperator()->getRawOutput(0) != g2->getNode("conv3")->getOperator()->getRawInput(0)); REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(1) != g2->getNode("conv3_w")->getOperator()->getRawOutput(0)); REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(2) != g2->getNode("conv3_b")->getOperator()->getRawOutput(0)); } SECTION("Check input-output connections") { REQUIRE(dataProvider2->getOperator()->getRawOutput(0) == g2->getNode("conv1")->getOperator()->getRawInput(0)); REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0)); REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0)); REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0)); REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0)); REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0)); REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0)); REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(1) 
== g2->getNode("conv3_w")->getOperator()->getRawOutput(0)); REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0)); } }

/**
 * cloneSharedProducers(): checks that the clone g2
 *  - gets new Node objects for every node (conv* and the conv*_w / conv*_b
 *    parameter nodes),
 *  - gets new operators for conv1/conv2/conv3,
 *  - shares the operators of the conv*_w / conv*_b nodes with g1,
 *  - is rewired internally (conv1 -> conv2 -> conv3) and fed by a new data
 *    producer instead of the original dataProvider.
 *
 * Note: Catch2 re-runs the whole TEST_CASE body once per leaf SECTION, so the
 * graph construction and cloning below are redone for every SECTION.
 */
TEST_CASE("[GraphView] cloneSharedProducers") {
    auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
    auto conv1 = Conv(3, 32, {3, 3}, "conv1");
    auto conv2 = Conv(32, 64, {3, 3}, "conv2");
    auto conv3 = Conv(64, 10, {1, 1}, "conv3");
    auto g1 = std::make_shared<GraphView>("TestGraph");
    dataProvider->addChild(conv1, 0);
    g1->add(conv1);
    g1->addChild(conv2, conv1, 0);
    g1->addChild(conv3, conv2, 0);
    // Debug dump, disabled so the test does not write files (uncomment to inspect):
    // g1->save("cloneSharedProducers_g1");

    SECTION("Check input-output connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawInput(1) == g1->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawInput(2) == g1->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE(conv2->getOperator()->getRawInput(1) == g1->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawInput(2) == g1->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE(conv3->getOperator()->getRawInput(1) == g1->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv3->getOperator()->getRawInput(2) == g1->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    auto g2 = g1->cloneSharedProducers();
    // Feed the clone from its own data producer; the original dataProvider
    // must stay connected to g1 only (checked in "Check new connections").
    auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
    dataProvider2->addChild(g2->getNode("conv1"), 0);
    g2->forwardDims();
    // g2->save("cloneSharedProducers_g2");

    SECTION("Check node cloning") {
        REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
        REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
        REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
        REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
        REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
        REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
        REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
        REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
        REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
    }

    // Conv operators are cloned; the operators of the *_w / *_b nodes are shared.
    SECTION("Check operator cloning") {
        REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator());
        REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator());
        REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator());
        REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator());
        REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator());
        REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator());
        REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator());
        REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator());
        REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator());
    }

    // Parameter tensors are the same objects in g1 and g2, while data
    // tensors between the conv nodes are distinct.
    SECTION("Check new connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) != g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawOutput(0) != g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawOutput(0) != g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    // Name kept distinct from the pre-clone SECTION so Catch2 reports and
    // `-c` section filters are unambiguous.
    SECTION("Check cloned graph input-output connections") {
        REQUIRE(dataProvider2->getOperator()->getRawOutput(0) == g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }
}

/**
 * cloneSharedOperators(): checks that the clone g2 gets new Node objects
 * for every node, but every node's operator (conv* and conv*_w / conv*_b)
 * is the very same object as in g1. As a consequence, g2's conv1 still reads
 * the original dataProvider's output tensor.
 *
 * Note: Catch2 re-runs the whole TEST_CASE body once per leaf SECTION.
 */
TEST_CASE("[GraphView] cloneSharedOperators") {
    auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
    auto conv1 = Conv(3, 32, {3, 3}, "conv1");
    auto conv2 = Conv(32, 64, {3, 3}, "conv2");
    auto conv3 = Conv(64, 10, {1, 1}, "conv3");
    auto g1 = std::make_shared<GraphView>("TestGraph");
    dataProvider->addChild(conv1, 0);
    g1->add(conv1);
    g1->addChild(conv2, conv1, 0);
    g1->addChild(conv3, conv2, 0);
    // Debug dump, disabled so the test does not write files (uncomment to inspect):
    // g1->save("cloneSharedOperators_g1");

    SECTION("Check input-output connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawInput(1) == g1->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawInput(2) == g1->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE(conv2->getOperator()->getRawInput(1) == g1->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawInput(2) == g1->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE(conv3->getOperator()->getRawInput(1) == g1->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv3->getOperator()->getRawInput(2) == g1->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    auto g2 = g1->cloneSharedOperators();
    g2->forwardDims();
    // g2->save("cloneSharedOperators_g2");

    SECTION("Check node cloning") {
        REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
        REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
        REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
        REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
        REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
        REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
        REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
        REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
        REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
    }

    // Every operator is shared between g1 and g2.
    SECTION("Check operator cloning") {
        REQUIRE(g1->getNode("conv1")->getOperator() == g2->getNode("conv1")->getOperator());
        REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator());
        REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator());
        REQUIRE(g1->getNode("conv2")->getOperator() == g2->getNode("conv2")->getOperator());
        REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator());
        REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator());
        REQUIRE(g1->getNode("conv3")->getOperator() == g2->getNode("conv3")->getOperator());
        REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator());
        REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator());
    }

    // Name kept distinct from the pre-clone SECTION so Catch2 reports and
    // `-c` section filters are unambiguous.
    SECTION("Check cloned graph input-output connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }
}

/**
 * insertParent(): starting from dataProvider -> conv1 -> {conv2, conv3},
 * newConv is first inserted between conv1 and conv2, then (the same node)
 * between conv1 and conv3. After each insertion both the tensor links
 * (raw input/output pointers) and getChildren() of conv1/newConv are checked.
 */
TEST_CASE("[core/graph] GraphView(insertParent)") {
    auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
    auto conv1 = Conv(3, 32, {3, 3}, "conv1");
    auto conv2 = Conv(32, 64, {3, 3}, "conv2");
    auto conv3 = Conv(32, 64, {1, 1}, "conv3");
    auto g = std::make_shared<GraphView>("TestGraph");
    dataProvider->addChild(conv1, 0);
    g->add(conv1);
    g->addChild(conv2, conv1, 0);
    g->addChild(conv3, conv1, 0);
    // Debug dump, disabled so the test does not write files (uncomment to inspect):
    // g->save("graphInsertParent");
    g->forwardDims();

    auto newConv = Conv(32, 32, {1, 1}, "newConv");

    SECTION("Check insertParent conv2 then insertParent conv3") {
        // Step 1: conv1 -> newConv -> conv2, conv3 still fed directly by conv1.
        g->insertParent(conv2, newConv, 0, 0, 0);

        const std::set<NodePtr> expectedConv1Children = {conv3, newConv};
        const std::set<NodePtr> expectedNewConvChildren = {conv2};

        REQUIRE(conv1->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == newConv->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) != conv2->getOperator()->getRawInput(0));
        REQUIRE(newConv->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE((newConv->getChildren()) == expectedNewConvChildren);
        REQUIRE((conv1->getChildren()) == expectedConv1Children);

        // Step 2: the same newConv also becomes conv3's parent, so conv1's
        // only remaining child is newConv.
        g->insertParent(conv3, newConv, 0, 0, 0);

        const std::set<NodePtr> expectedConv1Children2 = {newConv};
        const std::set<NodePtr> expectedNewConvChildren2 = {conv2, conv3};

        REQUIRE(conv1->getOperator()->getRawOutput(0) != conv3->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == newConv->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) != conv2->getOperator()->getRawInput(0));
        REQUIRE(newConv->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE(newConv->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE((newConv->getChildren()) == expectedNewConvChildren2);
        REQUIRE((conv1->getChildren()) == expectedConv1Children2);
    }
}