Newer
Older
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <algorithm>
#include <cstddef>
#include <cstdio>
#include <memory>
#include <random>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include <catch2/catch_test_macros.hpp>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
// Checks that RandomGraph::gen is deterministic: generating twice from the
// same seed must give GraphViews with the same node names, the same ordered
// inputs/outputs and the same input/output nodes.
TEST_CASE("genRandomGraph", "[GraphView][randomGen]") {
    const size_t nbTests = 100;
    // Number of iterations where the generated graph could be added without
    // name conflict (GraphView::add returned true) and was fully compared.
    size_t nbUnicity = 0;

    for (size_t test = 0; test < nbTests; ++test) {
        std::random_device rd;
        const std::mt19937::result_type seed(rd());

        // Same generator, same seed: both graphs must be identical.
        RandomGraph randGraph;
        const auto g1 = std::make_shared<GraphView>("g1");
        const bool unicity1 = g1->add(randGraph.gen(seed, 10));
        const auto g2 = std::make_shared<GraphView>("g2");
        const bool unicity2 = g2->add(randGraph.gen(seed, 10));

        // g1->save("./genRandomGraph1");
        // g2->save("./genRandomGraph2");

        REQUIRE(unicity1 == unicity2);

        if (unicity1) {
            REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName));
            REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName));
            REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName));
            REQUIRE(nodePtrTo(g1->inputNodes(), nodePtrToName) == nodePtrTo(g2->inputNodes(), nodePtrToName));
            REQUIRE(nodePtrTo(g1->outputNodes(), nodePtrToName) == nodePtrTo(g2->outputNodes(), nodePtrToName));

            // Check that inputs/outputs are the same regardless of the order
            auto orderedInputs1 = nodePtrTo(g1->getOrderedInputs(), nodePtrToName);
            auto orderedInputs2 = nodePtrTo(g2->getOrderedInputs(), nodePtrToName);
            auto orderedOutputs1 = nodePtrTo(g1->getOrderedOutputs(), nodePtrToName);
            auto orderedOutputs2 = nodePtrTo(g2->getOrderedOutputs(), nodePtrToName);
            std::sort(orderedInputs1.begin(), orderedInputs1.end());
            std::sort(orderedInputs2.begin(), orderedInputs2.end());
            std::sort(orderedOutputs1.begin(), orderedOutputs1.end());
            std::sort(orderedOutputs2.begin(), orderedOutputs2.end());
            REQUIRE(orderedInputs1 == orderedInputs2);
            REQUIRE(orderedOutputs1 == orderedOutputs2);

            ++nbUnicity;
        }
    }

    printf("nbUnicity = %zu/%zu\n", nbUnicity, nbTests);
}
// Checks that GraphView::clone() reproduces a randomly generated graph:
// same node names and same ordered inputs/outputs.
TEST_CASE("clone", "[GraphView][clone]") {
    const size_t nbTests = 100;

    for (size_t test = 0; test < nbTests; ++test) {
        std::random_device rd;
        const std::mt19937::result_type seed(rd());

        // Populate g1 from the seed; otherwise an empty view would be
        // cloned and the comparisons below would test nothing.
        RandomGraph randGraph;
        const auto g1 = std::make_shared<GraphView>("g1");
        g1->add(randGraph.gen(seed, 10));
        // g1 -> save("GraphView_clone");

        const auto g2 = g1->clone();

        REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName));
        REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName));
        REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName));
    }
}
// Callback for GraphView::cloneCallback (see the clone_with_delete test):
// any node whose type is "DelFictive" is mapped to nullptr, every other
// node is mapped to its own clone.
NodePtr nodeDel(NodePtr node) {
    const bool keepNode = (node->type() != "DelFictive");
    if (keepNode) {
        return node->clone();
    }
    return nullptr;
}
// Clones random graphs through GraphView::cloneCallback(&nodeDel), which maps
// every "DelFictive" node to nullptr, and compares the clone with a second
// generated graph g2.
// NOTE(review): this view of the test looks truncated. The `if (unicity1)`,
// `for` and TEST_CASE scopes are never closed, `seed` is never advanced (so
// every iteration would regenerate the same graph), and g2 is generated with
// the same configuration and seed as g1. For the comparison with gCloned to
// be meaningful, g2 presumably has to be generated without "DelFictive"
// nodes. TODO confirm against the upstream test and restore the missing lines.
TEST_CASE("clone_with_delete", "[GraphView][cloneDelete]") {
const size_t nbTests = 100;
// Number of iterations where cloneCallback succeeded and all REQUIREs ran.
size_t nbClonedWithDelete = 0;
// Note: initial seed is chosen such that for nbTests=100, the generated
// graphs keep the same inputs/outputs despite the deleted nodes
// (meaning the deleted nodes are not input/output of the graph).
// Otherwise, the last two REQUIRE are not guaranteed to be true!
// Warning: distributions are not required to behave the same way by the standard,
// therefore the seed has to work for both GCC and MSVC...
// See https://stackoverflow.com/questions/38532927/why-gcc-and-msvc-stdnormal-distribution-are-different
std::mt19937::result_type seed(243);
for (int test = 0; test < nbTests; ++test) {
RandomGraph randGraph;
// Two node types; typesWeights presumably gives their relative sampling
// probability (about 10% "DelFictive"). TODO confirm in RandomGraph.
randGraph.types = {"Fictive", "DelFictive"};
randGraph.typesWeights = {0.9, 0.1};
const auto g1 = std::make_shared<GraphView>("g1");
const bool unicity1 = g1->add(randGraph.gen(seed, 10));
// Only graphs whose nodes were all accepted by add() are tested.
if (unicity1) {
const auto g2 = std::make_shared<GraphView>("g2");
// NOTE(review): unicity2 is never checked in this view.
const bool unicity2 = g2->add(randGraph.gen(seed, 10));
// g1->save("./clone_with_delete1");
// g2->save("./clone_with_delete2");
try {
const auto gCloned = g1->cloneCallback(&nodeDel);
REQUIRE(nodePtrTo(gCloned->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName));
REQUIRE(nodePtrTo(gCloned->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName));
REQUIRE(nodePtrTo(gCloned->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName));
++nbClonedWithDelete;
}
// A std::runtime_error is tolerated: the iteration is simply not counted
// in nbClonedWithDelete.
catch (const std::runtime_error& error) {
// pass
}
printf("nbClonedWithDelete = %zu/%zu\n", nbClonedWithDelete, nbTests);
// Removes every "DelFictive" node from a random graph g1, one at a time, and
// compares the result with a second graph g2.
// NOTE(review): this view of the test looks truncated. g2 is created but never
// populated, the numeric-only lines below are not C++ (they look like pasted
// line-number residue), and the braces are unbalanced: the final printf sits
// inside the outer `for` and the TEST_CASE is never closed. The comparisons
// also appear to run inside the per-node `if`.
// TODO restore the missing lines before relying on this test.
TEST_CASE("remove", "[GraphView][remove]") {
const size_t nbTests = 100;
// Number of times the comparison block below was reached.
size_t nbTested = 0;
for (int test = 0; test < nbTests; ++test) {
std::random_device rd;
const std::mt19937::result_type seed(rd());
RandomGraph randGraph;
// Two node types; typesWeights presumably gives their relative sampling
// probability (about 20% "DelFictive"). TODO confirm in RandomGraph.
randGraph.types = {"Fictive", "DelFictive"};
randGraph.typesWeights = {0.8, 0.2};
const auto g1 = std::make_shared<GraphView>("g1");
// NOTE(review): unicity1 is never used in this view.
const bool unicity1 = g1->add(randGraph.gen(seed, 10));
// g1->save("./remove1_before");
// `const auto` deduces a value type, so this is a copy of the node set.
// Removing nodes from g1 in the loop below therefore does not disturb
// the iteration.
const auto nodes = g1->getNodes();
// Only referenced by the commented-out save() call below.
int step = 1;
for (auto node : nodes) {
if (node->type() == "DelFictive") {
g1->remove(node, false);
// g1->save("./remove1_after" + std::to_string(step));
const auto g2 = std::make_shared<GraphView>("g2");
// g1->save("./remove1");
// g2->save("./remove2");
// NOTE(review): the following numeric lines are not valid C++ and must be
// replaced by the missing test code (they look like line-number residue).
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName));
// Order not guaranteed, because when a node is removed, it can create new GraphView inputs/outputs
// Their order thus depends on the deletion order!
//REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName));
//REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName));
// Check that inputs/outputs are the same regardless of the order
auto orderedInputs1 = nodePtrTo(g1->getOrderedInputs(), nodePtrToName);
auto orderedInputs2 = nodePtrTo(g2->getOrderedInputs(), nodePtrToName);
auto orderedOutputs1 = nodePtrTo(g1->getOrderedOutputs(), nodePtrToName);
auto orderedOutputs2 = nodePtrTo(g2->getOrderedOutputs(), nodePtrToName);
std::sort(orderedInputs1.begin(), orderedInputs1.end());
std::sort(orderedInputs2.begin(), orderedInputs2.end());
std::sort(orderedOutputs1.begin(), orderedOutputs1.end());
std::sort(orderedOutputs2.begin(), orderedOutputs2.end());
REQUIRE(orderedInputs1 == orderedInputs2);
REQUIRE(orderedOutputs1 == orderedOutputs2);
++nbTested;
}
}
printf("nbTested = %zu/%zu\n", nbTested, nbTests);
}
// Both GraphView constructors (default and named) must yield a valid object.
TEST_CASE("[core/graph] GraphView(Constructor)", "[GraphView][constructor()]") {
    const std::shared_ptr<GraphView> unnamedView = std::make_shared<GraphView>();
    const std::shared_ptr<GraphView> namedView = std::make_shared<GraphView>("G1");

    REQUIRE(unnamedView != nullptr);
    REQUIRE(namedView != nullptr);
}
// Checks GraphView::add for:
// - single nodes with various input/output counts;
// - a node whose learnable-parameter parent is pulled in automatically;
// - an initializer list containing a duplicate;
// - a whole GraphView.
TEST_CASE("[core/graph] GraphView(add)", "[GraphView][add]") {
    SECTION("Node alone") {
        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
        std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 0, 0, "Gop1");
        g->add(GOp1);
        std::shared_ptr<Node> GOp2 = GenericOperator("Fictive", 0, 0, 1, "Gop2");
        g->add(GOp2);
        std::shared_ptr<Node> GOp3 = GenericOperator("Fictive", 1, 0, 0, "Gop3");
        g->add(GOp3);
        std::shared_ptr<Node> GOp4 = GenericOperator("Fictive", 0, 1, 0, "Gop4");
        g->add(GOp4);
        // GOp5 and GOp6 must be part of the graph: the expectations below
        // list their inputs and outputs.
        std::shared_ptr<Node> GOp5 = GenericOperator("Fictive", 1, 0, 1, "Gop5");
        g->add(GOp5);
        std::shared_ptr<Node> GOp6 = GenericOperator("Fictive", 1, 1, 1, "Gop6");
        g->add(GOp6);
        // Every unconnected input/output becomes a graph input/output; the
        // expected lists follow the order in which the nodes were added.
        REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop3", 0}, {"Gop4", 0}, {"Gop5", 0}, {"Gop6", 0}, {"Gop6", 1}}));
        REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop2", 0}, {"Gop5", 0}, {"Gop6", 0}}));
    }

    SECTION("Several Nodes") {
        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
        // should automatically add parents for learnable parameters
        std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 1, 1, "Gop1");
        std::shared_ptr<Node> GOp1parent = GenericOperator("Fictive", 0, 0, 1, "Gop1parent");
        GOp1parent->addChild(GOp1, 0, 0);
        g->add(GOp1);
        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent}));
        REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({}));
        REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop1", 0}}));

        // there should be no duplicates
        g->add(GOp1);
        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent}));
        REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({}));
        REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop1", 0}}));
    }

    SECTION("Initializer list ofr Node") {
        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
        std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 0, 0, "Gop1");
        std::shared_ptr<Node> GOp2 = GenericOperator("Fictive", 0, 0, 0, "Gop2");
        // GOp1 appears twice; the resulting node set must contain it once.
        g->add({GOp1, GOp1, GOp2});
        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp2}));
    }

    SECTION("another GraphView") {
        std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph-1");
        std::shared_ptr<GraphView> g2 = std::make_shared<GraphView>("TestGraph-2");
        auto conv = GenericOperator("Conv", 1, 0, 1, "c");
        auto conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
        auto conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
        auto conv3 = GenericOperator("Conv", 1, 0, 1, "c3");
        auto conv4 = GenericOperator("Conv", 1, 0, 1, "c4");
        conv->addChild(conv1);
        conv1->addChild(conv2);
        conv2->addChild(conv3);
        conv3->addChild(conv4);
        g1->add({conv, conv1, conv2, conv3, conv4});
        g2->add(g1);
        REQUIRE(((g1->getNodes() == g2->getNodes()) && (g2->getNodes() == std::set<std::shared_ptr<Node>>({conv, conv1, conv2, conv3, conv4}))));
        REQUIRE(((g1->inputNodes() == g2->inputNodes()) &&
                 (g2->inputNodes() == std::set<std::shared_ptr<Node>>({conv}))));
        REQUIRE(((g1->outputNodes() == g2->outputNodes()) &&
                 (g2->outputNodes() == std::set<std::shared_ptr<Node>>({conv4}))));
    }
}
// Exercises the GraphView::addChild overloads by growing the chain
// c -> c1 -> c2 -> c3 -> c3.5 -> c4 -> c5 one node at a time. The parent is
// referenced by name or by node pointer, with optional extra indices.
// Each SECTION checks the graph state right after one addChild call.
TEST_CASE("[core/graph] GraphView(addChild)") {
    std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
    std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
    std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
    std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
    std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 0, 1, "c3");
    std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 0, 1, "c3.5");
    std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 0, 1, "c4");
    std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 0, 1, "c5");

    g1->add(conv);
    SECTION("add(node)") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv});
    }

    // Parent referenced by name.
    g1->addChild(conv1, "c");
    SECTION("add(node, outputNodeName)") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv1});
        REQUIRE(conv->getChildren() == std::set<std::shared_ptr<Node>>({conv1}));
        REQUIRE(conv1->getParents() == std::vector<std::shared_ptr<Node>>({conv}));
    }

    // Parent referenced by name, plus one index.
    g1->addChild(conv2, "c1", 0);
    SECTION("add(node, pair<outputNodeName, outID>)") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv2});
        REQUIRE(conv1->getChildren() == std::set<std::shared_ptr<Node>>({conv2}));
        REQUIRE(conv2->getParents() == std::vector<std::shared_ptr<Node>>({conv1}));
    }

    // Parent referenced by name, plus two indices.
    g1->addChild(conv3, "c2", 0, 0);
    SECTION("add(node, list(outputNodeName))") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv3});
        REQUIRE(conv2->getChildren() == std::set<std::shared_ptr<Node>>({conv3}));
        REQUIRE(conv3->getParents() == std::vector<std::shared_ptr<Node>>({conv2}));
    }

    // Parent referenced by node pointer.
    g1->addChild(conv3_5, conv3);
    SECTION("add(node, list(outputNodeName))") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv3_5});
        REQUIRE(conv3->getChildren() == std::set<std::shared_ptr<Node>>({conv3_5}));
        REQUIRE(conv3_5->getParents() ==
                std::vector<std::shared_ptr<Node>>({conv3}));
    }

    // Parent referenced by node pointer, plus one index.
    g1->addChild(conv4, conv3_5, 0);
    SECTION("add(node, vector<pair<outputNodeName, outID>>)") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv4});
        REQUIRE(conv3_5->getChildren() == std::set<std::shared_ptr<Node>>({conv4}));
        REQUIRE(conv4->getParents() ==
                std::vector<std::shared_ptr<Node>>({conv3_5}));
    }

    // Parent referenced by node pointer, plus two indices.
    g1->addChild(conv5, conv4, 0, 0);
    SECTION("add(node, vector<pair<outputNodeName, outID>>)") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv5});
        REQUIRE(conv4->getChildren() == std::set<std::shared_ptr<Node>>({conv5}));
        REQUIRE(conv5->getParents() == std::vector<std::shared_ptr<Node>>({conv4}));
    }

    std::set<std::shared_ptr<Node>> requiredNodes = {conv, conv1, conv2, conv3,
                                                     conv3_5, conv4, conv5};
    REQUIRE(g1->getNodes() == requiredNodes);
    REQUIRE(g1->getChildren(conv3) == std::set<std::shared_ptr<Node>>({conv3_5}));
}
// A GraphView holding a single node must expose exactly that node's inputs.
// The second add() argument is `false`, which presumably keeps the node's
// learnable-parameter parents out of the graph (TODO confirm), so none of
// conv's parents are inside the view.
TEST_CASE("[core/graph] GraphView(inputs)") {
    auto g1 = std::make_shared<GraphView>("TestGraph");
    std::shared_ptr<Node> conv = Conv(3, 32, {3, 3});
    g1->add(conv, false);

    REQUIRE(g1->inputs() == conv->inputs());
}
// A GraphView containing a single Conv node must report that node's outputs
// as its own outputs.
TEST_CASE("[core/graph] GraphView(outputs)") {
    std::shared_ptr<GraphView> view = std::make_shared<GraphView>("TestGraph");
    std::shared_ptr<Node> convNode = Conv(3, 32, {3, 3});
    view->add(convNode);

    const bool sameOutputs = (view->outputs() == convNode->outputs());
    REQUIRE(sameOutputs);
}
// Builds the chain c -> c1 -> c2 -> c3 -> c4 -> c5 with the name-based
// addChild overloads, then writes it out with GraphView::save.
TEST_CASE("[core/graph] GraphView(save)") {
    std::shared_ptr<GraphView> chain = std::make_shared<GraphView>("TestGraph");

    std::shared_ptr<Node> n0 = GenericOperator("Conv", 1, 0, 1, "c");
    std::shared_ptr<Node> n1 = GenericOperator("Conv", 1, 0, 1, "c1");
    std::shared_ptr<Node> n2 = GenericOperator("Conv", 1, 0, 1, "c2");
    std::shared_ptr<Node> n3 = GenericOperator("Conv", 1, 0, 1, "c3");
    std::shared_ptr<Node> n4 = GenericOperator("Conv", 1, 0, 1, "c4");
    std::shared_ptr<Node> n5 = GenericOperator("Conv", 1, 0, 1, "c5");

    chain->add(n0);
    chain->addChild(n1, "c");
    chain->addChild(n2, "c1", 0);
    chain->addChild(n3, "c2");
    chain->addChild(n4, "c3", 0);
    chain->addChild(n5, "c4", 0, 0);

    chain->save("./graphExample");
    printf("File saved in ./graphExample.md\n");
}
// Checks Node::resetConnections on the middle node c1 of
// c -> c1 -> c2, where c1 also has two parameter parents p1/p2:
// - resetConnections(false) must detach only the data input(s), leaving the
//   p1/p2 connections in place;
// - resetConnections(true) must detach every input.
// In both cases c1 must lose its children and c2 its parent.
TEST_CASE("[core/graph] GraphView(resetConnections)") {
    SECTION("disconnect data input") {
        std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
        std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 2, 1, "c1");
        std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
        std::shared_ptr<Node> prod1 = GenericOperator("Prod", 0, 0, 1, "p1");
        std::shared_ptr<Node> prod2 = GenericOperator("Prod", 0, 0, 1, "p2");
        conv->addChild(conv1);
        prod1->addChild(conv1, 0, 1);
        prod2->addChild(conv1, 0, 2);
        conv1->addChild(conv2);

        conv1->resetConnections(false);

        REQUIRE(conv->output(0).size() == 0);
        for (std::size_t i = 0; i < conv1->nbData(); ++i) {
            REQUIRE((conv1->input(i) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex)));
        }
        // Parameter inputs stay connected when learnable params are excluded.
        REQUIRE((conv1->input(1) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod1, 0)));
        REQUIRE((conv1->input(2) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod2, 0)));
        REQUIRE((conv2->input(0) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex)));
        // The reset node itself must no longer have children on any output.
        for (std::size_t i = 0; i < conv1->nbOutputs(); ++i) {
            REQUIRE(conv1->output(i).size() == 0U);
        }
    }

    SECTION("disconnect data input + learnable parameters") {
        std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
        std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 2, 1, "c1");
        std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
        std::shared_ptr<Node> prod1 = GenericOperator("Prod", 0, 0, 1, "p1");
        std::shared_ptr<Node> prod2 = GenericOperator("Prod", 0, 0, 1, "p2");
        conv->addChild(conv1);
        prod1->addChild(conv1, 0, 1);
        prod2->addChild(conv1, 0, 2);
        conv1->addChild(conv2);

        conv1->resetConnections(true);

        REQUIRE(conv->output(0).size() == 0);
        for (std::size_t i = 0; i < conv1->nbInputs(); ++i) {
            REQUIRE((conv1->input(i) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex)));
        }
        REQUIRE((conv2->input(0) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex)));
        // The reset node itself must no longer have children on any output.
        for (std::size_t i = 0; i < conv1->nbOutputs(); ++i) {
            REQUIRE(conv1->output(i).size() == 0U);
        }
    }
}

Maxence Naud
committed
// Builds dataProvider -> conv1 -> conv2 -> conv3 and runs forwardDims().
// Checks that connected nodes share the same tensors and that output
// dimensions are propagated.
TEST_CASE("[core/graph] GraphView(forwardDims)", "[GraphView][forwardDims]") {
    auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
    auto conv1 = Conv(3, 32, {3, 3}, "conv1");
    auto conv2 = Conv(32, 64, {3, 3}, "conv2");
    auto conv3 = Conv(64, 10, {1, 1}, "conv3");
    auto g = std::make_shared<GraphView>("TestGraph");
    dataProvider->addChild(conv1, 0);
    g->add(conv1);
    g->addChild(conv2, conv1, 0);
    g->addChild(conv3, conv2, 0);
    g->save("graphForwardDims");
    g->forwardDims();

    SECTION("Check input-output connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawInput(1) == g->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawInput(2) == g->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE(conv2->getOperator()->getRawInput(1) == g->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawInput(2) == g->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE(conv3->getOperator()->getRawInput(1) == g->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv3->getOperator()->getRawInput(2) == g->getNode("conv3_b")->getOperator()->getRawOutput(0));
        // Each 3x3 convolution shrinks H/W by 2 (224 -> 222 -> 220), as the
        // conv2 expectation implies; the 1x1 conv3 keeps 220x220.
        REQUIRE(std::static_pointer_cast<Tensor>(conv1->getOperator()->getRawOutput(0))
                    ->dims() == std::vector<DimSize_t>({16, 32, 222, 222}));
        REQUIRE(std::static_pointer_cast<Tensor>(conv2->getOperator()->getRawOutput(0))
                    ->dims() == std::vector<DimSize_t>({16, 64, 220, 220}));
        REQUIRE(std::static_pointer_cast<Tensor>(conv3->getOperator()->getRawOutput(0))
                    ->dims() == std::vector<DimSize_t>({16, 10, 220, 220}));
    }
}
// Checks GraphView::replace(oldNodes, newNodes) on several patterns:
// - fusing MatMul+Add into FC;
// - replacing with an empty set;
// - tiling a Conv and undoing it;
// - replacing every node of a GraphView.
TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
    SECTION("replace small pattern") {
        // create original graph
        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
        auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
        auto matmulWeight = GenericOperator("Producer", 0, 0, 1, "matmul_w");
        auto addBias = GenericOperator("Producer", 0, 0, 1, "add_b");
        auto other1 = GenericOperator("Other", 1, 0, 1, "other1");
        auto other2 = GenericOperator("Other", 1, 0, 1, "other2");
        auto matmul = GenericOperator("MatMul", 1, 1, 1, "matmul");
        auto add = GenericOperator("Add", 1, 1, 1, "add");
        otherInput->addChild(other1);
        other1->addChild(matmul);
        matmul->addChild(add);
        add->addChild(other2);
        matmulWeight->addChild(matmul, 0, 1);
        addBias->addChild(add, 0, 1);
        g->add({other1, matmul, add, other2});
        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add}));

        // create graph to replace
        std::set<std::shared_ptr<Node>> nodeToReplace = std::set<std::shared_ptr<Node>>({matmulWeight, addBias, matmul, add});

        // create replacing graph
        std::shared_ptr<Node> myFC = GenericOperator("FC", 1, 2, 1, "fc");
        auto newMatmulWeight = matmulWeight->cloneSharedOperators();
        newMatmulWeight->addChild(myFC, 0, 1);
        auto newAddBias = addBias->cloneSharedOperators();
        newAddBias->addChild(myFC, 0, 2);
        std::set<std::shared_ptr<Node>> newNodes = std::set<std::shared_ptr<Node>>({myFC, newMatmulWeight, newAddBias});

        // replace
        GraphView::replace(nodeToReplace, newNodes);

        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, myFC}));
        REQUIRE(((myFC->getParent(0) == other1) && (myFC->getParent(1) == newMatmulWeight) && (myFC->getParent(2) == newAddBias)));
    }

    SECTION("replace with nothing") {
        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
        auto r1 = GenericOperator("relu", 0, 0, 1);
        auto r2 = GenericOperator("relu", 1, 0, 1);
        auto r3 = GenericOperator("relu", 1, 0, 1);
        auto r4 = GenericOperator("relu", 1, 0, 0);
        r1->addChild(r2);
        r2->addChild(r3);
        r3->addChild(r4);
        g->add({r1, r2, r3, r4});
        auto nodesToReplace = std::set<std::shared_ptr<Node>>({r2, r3});
        auto newNodes = std::set<std::shared_ptr<Node>>({});
        GraphView::replace(nodesToReplace, newNodes);

        // Removing the middle of the chain reconnects r1 directly to r4.
        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4}));
        REQUIRE((r1->output(0))[0].first == r4);
    }

    SECTION("replace for tiling") {
        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph");
        auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
        auto other1 = GenericOperator("Other", 1, 0, 1, "other1");
        auto myConv = GenericOperator("Conv", 1, 0, 1, "myConv");
        auto other2 = GenericOperator("Other", 1, 0, 1, "other2");
        otherInput->addChild(other1);
        other1->addChild(myConv);
        myConv->addChild(other2);
        g->add({other1, myConv, other2});

        // create tiled Conv
        auto conv1 = GenericOperator("Conv", 1, 0, 1, "myConv1");
        auto conv2 = GenericOperator("Conv", 1, 0, 1, "myConv2");
        auto conv3 = GenericOperator("Conv", 1, 0, 1, "myConv3");
        auto conv4 = GenericOperator("Conv", 1, 0, 1, "myConv4");
        auto concat = GenericOperator("Concat", 4, 0, 1, "myConcat");
        conv1->addChild(concat);
        conv2->addChild(concat);
        conv3->addChild(concat);
        conv4->addChild(concat);

        GraphView::replace({myConv}, {conv1, conv2, conv3, conv4, concat});

        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, conv1, conv2, conv3, conv4, concat, other2}));

        // Undo the tiling: the original Conv must take its place back.
        GraphView::replace({conv1, conv2, conv3, conv4, concat}, {myConv});

        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2}));
    }

    SECTION("Change every Nodes in a GraphView") {
        auto matmulWeight0 = GenericOperator("Producer", 0, 0, 1, "matmul_w0");
        auto addBias0 = GenericOperator("Producer", 0, 0, 1, "add_b0");
        auto matmul0 = GenericOperator("MatMul", 1, 1, 1, "matmul0");
        auto add0 = GenericOperator("Add", 1, 1, 1, "add0");
        auto matmulWeight1 = GenericOperator("Producer", 0, 0, 1, "matmul_w1");
        auto addBias1 = GenericOperator("Producer", 0, 0, 1, "add_b1");
        auto matmul1 = GenericOperator("MatMul", 1, 1, 1, "matmul1");
        auto add1 = GenericOperator("Add", 1, 1, 1, "add1");

        matmulWeight0->addChild(matmul0, 0, 1);
        addBias0->addChild(add0, 0, 1);
        matmulWeight1->addChild(matmul1, 0, 1);
        addBias1->addChild(add1, 0, 1);
        matmul0->addChild(add0, 0, 0);
        add0->addChild(matmul1, 0, 0);
        matmul1->addChild(add1, 0, 0);

        auto g = std::make_shared<GraphView>("TestGraph");
        g->add({matmulWeight0, addBias0, matmulWeight1, addBias1, matmul0, add0, matmul1, add1});

        auto newMatmulWeight0 = matmulWeight0->cloneSharedOperators();
        auto newAddBias0 = addBias0->cloneSharedOperators();
        auto newMatmulWeight1 = matmulWeight1->cloneSharedOperators();
        auto newAddBias1 = addBias1->cloneSharedOperators();
        auto fc0 = GenericOperator("FC", 1, 2, 1, "fc0");
        auto fc1 = GenericOperator("FC", 1, 2, 1, "fc1");

        newMatmulWeight0->addChild(fc0, 0, 1);
        newAddBias0->addChild(fc0, 0, 2);
        newMatmulWeight1->addChild(fc1, 0, 1);
        newAddBias1->addChild(fc1, 0, 2);

        GraphView::replace({matmul0, add0, matmulWeight0, addBias0}, {newMatmulWeight0, newAddBias0, fc0});
        GraphView::replace({matmul1, add1, matmulWeight1, addBias1}, {newMatmulWeight1, newAddBias1, fc1});

        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight0, newAddBias0, newAddBias1, newMatmulWeight1, fc1, fc0}));
    }
}
// Checks GraphView::clone() on dataProvider -> conv1 -> conv2 -> conv3:
// - every node and every operator of the clone g2 must be a distinct object;
// - g2 must be wired only to its own nodes;
// - g2 must accept a fresh data producer and run forwardDims().
TEST_CASE("[GraphView] clone") {
    auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
    auto conv1 = Conv(3, 32, {3, 3}, "conv1");
    auto conv2 = Conv(32, 64, {3, 3}, "conv2");
    auto conv3 = Conv(64, 10, {1, 1}, "conv3");
    auto g1 = std::make_shared<GraphView>("TestGraph");
    dataProvider->addChild(conv1, 0);
    g1->add(conv1);
    g1->addChild(conv2, conv1, 0);
    g1->addChild(conv3, conv2, 0);
    g1->save("clone_g1");

    SECTION("Check input-output connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawInput(1) == g1->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawInput(2) == g1->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE(conv2->getOperator()->getRawInput(1) == g1->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawInput(2) == g1->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE(conv3->getOperator()->getRawInput(1) == g1->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv3->getOperator()->getRawInput(2) == g1->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    auto g2 = g1->clone();

    // Feed the clone with its own producer so that it is independent of g1.
    auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
    dataProvider2->addChild(g2->getNode("conv1"), 0);
    g2->forwardDims();
    g2->save("clone_g2");

    SECTION("Check node cloning") {
        REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
        REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
        REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
        REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
        REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
        REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
        REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
        REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
        REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
    }

    SECTION("Check operator cloning") {
        REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator());
        REQUIRE(g1->getNode("conv1_w")->getOperator() != g2->getNode("conv1_w")->getOperator());
        REQUIRE(g1->getNode("conv1_b")->getOperator() != g2->getNode("conv1_b")->getOperator());
        REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator());
        REQUIRE(g1->getNode("conv2_w")->getOperator() != g2->getNode("conv2_w")->getOperator());
        REQUIRE(g1->getNode("conv2_b")->getOperator() != g2->getNode("conv2_b")->getOperator());
        REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator());
        REQUIRE(g1->getNode("conv3_w")->getOperator() != g2->getNode("conv3_w")->getOperator());
        REQUIRE(g1->getNode("conv3_b")->getOperator() != g2->getNode("conv3_b")->getOperator());
    }

    SECTION("Check new connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) != g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(1) != g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(2) != g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawOutput(0) != g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(1) != g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(2) != g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawOutput(0) != g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(1) != g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(2) != g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    SECTION("Check input-output connections") {
        REQUIRE(dataProvider2->getOperator()->getRawOutput(0) == g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }
}
/**
 * GraphView::cloneSharedProducers()
 *
 * Builds dataProvider -> conv1 -> conv2 -> conv3 (each Conv also owning the
 * weight/bias Producer nodes looked up below as "<conv>_w" / "<conv>_b"),
 * clones the view, and checks that:
 *  - every node of the clone is a distinct Node object;
 *  - Conv nodes get their own operator, while the weight/bias Producer nodes
 *    share their operator with the original graph;
 *  - the clone is wired exactly like the original.
 *
 * Note: Catch2 re-executes the test-case body from the top for each SECTION,
 * so both graphs (and the .save() dumps) are rebuilt once per section.
 */
TEST_CASE("[GraphView] cloneSharedProducers") {
    auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
    auto conv1 = Conv(3, 32, {3, 3}, "conv1");
    auto conv2 = Conv(32, 64, {3, 3}, "conv2");
    auto conv3 = Conv(64, 10, {1, 1}, "conv3");
    auto g1 = std::make_shared<GraphView>("TestGraph");
    dataProvider->addChild(conv1, 0);
    g1->add(conv1);
    g1->addChild(conv2, conv1, 0);
    g1->addChild(conv3, conv2, 0);
    g1->save("cloneSharedProducers_g1");

    SECTION("Check input-output connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawInput(1) == g1->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawInput(2) == g1->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE(conv2->getOperator()->getRawInput(1) == g1->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawInput(2) == g1->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE(conv3->getOperator()->getRawInput(1) == g1->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv3->getOperator()->getRawInput(2) == g1->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    auto g2 = g1->cloneSharedProducers();
    // The clone's conv1 is not fed by the original dataProvider (asserted in
    // "Check new connections"), so a fresh producer is attached before
    // forwardDims().
    auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
    dataProvider2->addChild(g2->getNode("conv1"), 0);
    g2->forwardDims();
    g2->save("cloneSharedProducers_g2");

    SECTION("Check node cloning") {
        REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
        REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
        REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
        REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
        REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
        REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
        REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
        REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
        REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
    }

    SECTION("Check operator cloning") {
        // Conv operators are duplicated; Producer operators are shared.
        REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator());
        REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator());
        REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator());
        REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator());
        REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator());
        REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator());
        REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator());
        REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator());
        REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator());
    }

    SECTION("Check new connections") {
        // Data/activation tensors differ between the graphs, but the original
        // Convs and the cloned Producers refer to the same parameter tensors.
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) != g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawOutput(0) != g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawOutput(0) != g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    // Named differently from the pre-clone section above so the two checks
    // are distinguishable in reports and section filters.
    SECTION("Check input-output connections of the clone") {
        REQUIRE(dataProvider2->getOperator()->getRawOutput(0) == g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }
}
/**
 * GraphView::cloneSharedOperators()
 *
 * Builds dataProvider -> conv1 -> conv2 -> conv3 (each Conv also owning the
 * weight/bias Producer nodes looked up below as "<conv>_w" / "<conv>_b"),
 * clones the view, and checks that:
 *  - every node of the clone is a distinct Node object;
 *  - every cloned node shares its operator with the corresponding original;
 *  - the clone is wired like the original, including conv1 still seeing
 *    dataProvider's output as its input 0.
 *
 * Note: Catch2 re-executes the test-case body from the top for each SECTION,
 * so both graphs (and the .save() dumps) are rebuilt once per section.
 */
TEST_CASE("[GraphView] cloneSharedOperators") {
    auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
    auto conv1 = Conv(3, 32, {3, 3}, "conv1");
    auto conv2 = Conv(32, 64, {3, 3}, "conv2");
    auto conv3 = Conv(64, 10, {1, 1}, "conv3");
    auto g1 = std::make_shared<GraphView>("TestGraph");
    dataProvider->addChild(conv1, 0);
    g1->add(conv1);
    g1->addChild(conv2, conv1, 0);
    g1->addChild(conv3, conv2, 0);
    g1->save("cloneSharedOperators_g1");

    SECTION("Check input-output connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawInput(1) == g1->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawInput(2) == g1->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE(conv2->getOperator()->getRawInput(1) == g1->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawInput(2) == g1->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE(conv3->getOperator()->getRawInput(1) == g1->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv3->getOperator()->getRawInput(2) == g1->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    auto g2 = g1->cloneSharedOperators();
    g2->forwardDims();
    g2->save("cloneSharedOperators_g2");

    SECTION("Check node cloning") {
        REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
        REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
        REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
        REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
        REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
        REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
        REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
        REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
        REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
    }

    SECTION("Check operator cloning") {
        // Every operator, Conv and Producer alike, is shared with the original.
        REQUIRE(g1->getNode("conv1")->getOperator() == g2->getNode("conv1")->getOperator());
        REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator());
        REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator());
        REQUIRE(g1->getNode("conv2")->getOperator() == g2->getNode("conv2")->getOperator());
        REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator());
        REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator());
        REQUIRE(g1->getNode("conv3")->getOperator() == g2->getNode("conv3")->getOperator());
        REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator());
        REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator());
    }

    // Named differently from the pre-clone section above so the two checks
    // are distinguishable in reports and section filters.
    SECTION("Check input-output connections of the clone") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }
}
TEST_CASE("[core/graph] GraphView(insertParent)") {
auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
auto conv1 = Conv(3, 32, {3, 3}, "conv1");
auto conv2 = Conv(32, 64, {3, 3}, "conv2");
auto conv3 = Conv(32, 64, {1, 1}, "conv3");
auto g = std::make_shared<GraphView>("TestGraph");
dataProvider->addChild(conv1, 0);
g->add(conv1);
g->addChild(conv2, conv1, 0);
g->addChild(conv3, conv1, 0);
g->save("graphForwardDims");
g->forwardDims();
auto newConv = Conv(32, 32, {1, 1}, "newConv");
SECTION("Check insertParent conv2 then insertParent conv3") {
g->insertParent(conv2, newConv, 0, 0, 0);
std::set<NodePtr> expectedConv1Children = {conv3, newConv};
std::set<NodePtr> expectedNewConvChildren = {conv2};
REQUIRE(conv1->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
REQUIRE(conv1->getOperator()->getRawOutput(0) == newConv->getOperator()->getRawInput(0));
REQUIRE(conv1->getOperator()->getRawOutput(0) != conv2->getOperator()->getRawInput(0));
REQUIRE(newConv->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
REQUIRE((newConv->getChildren()) == expectedNewConvChildren);
REQUIRE((conv1->getChildren()) == expectedConv1Children);
g->insertParent(conv3, newConv, 0, 0, 0);
std::set<NodePtr> expectedConv1Children2 = {newConv};
std::set<NodePtr> expectedNewConvChildren2 = {conv2, conv3};
REQUIRE(conv1->getOperator()->getRawOutput(0) != conv3->getOperator()->getRawInput(0));
REQUIRE(conv1->getOperator()->getRawOutput(0) == newConv->getOperator()->getRawInput(0));
REQUIRE(conv1->getOperator()->getRawOutput(0) != conv2->getOperator()->getRawInput(0));
REQUIRE(newConv->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
REQUIRE(newConv->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
REQUIRE((newConv->getChildren()) == expectedNewConvChildren2);
REQUIRE((conv1->getChildren()) == expectedConv1Children2);
}