Skip to content
Snippets Groups Projects
Test_GraphView.cpp 46.1 KiB
Newer Older
Cyril Moineau's avatar
Cyril Moineau committed
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <algorithm> // std::sort
Cyril Moineau's avatar
Cyril Moineau committed
#include <cassert>
#include <map>
#include <memory>
Cyril Moineau's avatar
Cyril Moineau committed
#include <string>

#include <catch2/catch_test_macros.hpp>

#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
Cyril Moineau's avatar
Cyril Moineau committed

using namespace Aidge;

// Generating a random graph twice from the same seed must be deterministic:
// both GraphViews must end up with the same nodes, the same ordered
// inputs/outputs and the same input/output nodes.
TEST_CASE("genRandomGraph", "[GraphView][randomGen]") {
    const size_t nbTests = 100;
    // Number of iterations where add() returned true, i.e. where the two
    // graphs were actually compared (reported at the end).
    size_t nbUnicity = 0;

    for (size_t test = 0; test < nbTests; ++test) {
        std::random_device rd;
        const std::mt19937::result_type seed(rd());

        RandomGraph randGraph;
        const auto g1 = std::make_shared<GraphView>("g1");
        const bool unicity1 = g1->add(randGraph.gen(seed, 10));
        const auto g2 = std::make_shared<GraphView>("g2");
        const bool unicity2 = g2->add(randGraph.gen(seed, 10));
        // g1->save("./genRandomGraph1");
        // g2->save("./genRandomGraph2");

        // Same seed => add() must report the same result for both graphs.
        REQUIRE(unicity1 == unicity2);

        if (unicity1) {
            REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName));
            // Order-sensitive comparisons: they also imply equality of the
            // sorted lists, so no separate order-insensitive check is needed.
            REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName));
            REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName));
            REQUIRE(nodePtrTo(g1->inputNodes(), nodePtrToName) == nodePtrTo(g2->inputNodes(), nodePtrToName));
            REQUIRE(nodePtrTo(g1->outputNodes(), nodePtrToName) == nodePtrTo(g2->outputNodes(), nodePtrToName));
            ++nbUnicity;
        }
    }

    printf("nbUnicity = %zu/%zu\n", nbUnicity, nbTests);
}
// Cloning a random GraphView must preserve node names and the ordered
// inputs/outputs of the graph.
TEST_CASE("clone", "[GraphView][clone]") {
    const size_t nbTests = 100;

    for (size_t test = 0; test < nbTests; ++test) {
        std::random_device rd;
        const std::mt19937::result_type seed(rd());

        RandomGraph randGraph;
        const auto g1 = std::make_shared<GraphView>("g1");
        g1->add(randGraph.gen(seed, 10));
        // g1 -> save("GraphView_clone");
        const auto g2 = g1->clone();

        REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName));
        REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName));
        REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName));
    }
}
/**
 * Clone callback used with GraphView::cloneCallback() in the tests below:
 * nodes of type "DelFictive" are mapped to nullptr (left out of the clone),
 * every other node is cloned with Node::clone().
 * Declared static: it is a helper local to this test file.
 */
static NodePtr nodeDel(NodePtr node) {
    if (node->type() == "DelFictive") {
        return nullptr;
    }
    return node->clone();
}
// Cloning a random graph while dropping every "DelFictive" node (through the
// nodeDel callback) must give the same graph as generating it directly with
// "DelFictive" omitted.
TEST_CASE("clone_with_delete", "[GraphView][cloneDelete]") {
    const size_t nbTests = 100;
    size_t nbClonedWithDelete = 0;

    // Note: initial seed is chosen such that for nbTests=100, the generated
    // graphs keep the same inputs/outputs despite the deleted nodes
    // (meaning the deleted nodes are not input/output of the graph).
    // Otherwise, the last two REQUIRE are not guaranteed to be true!
    // Warning: distributions are not required to behave the same way by the standard,
    // therefore the seed has to work for both GCC and MSVC...
    // See https://stackoverflow.com/questions/38532927/why-gcc-and-msvc-stdnormal-distribution-are-different
    std::mt19937::result_type seed(243);

    for (size_t test = 0; test < nbTests; ++test) {
        RandomGraph randGraph;
        randGraph.types = {"Fictive", "DelFictive"};
        randGraph.typesWeights = {0.9, 0.1};
        const auto g1 = std::make_shared<GraphView>("g1");
        const bool unicity1 = g1->add(randGraph.gen(seed, 10));

        if (unicity1) {
            // Regenerate the same graph from the same seed, without the
            // "DelFictive" nodes: this is the expected result of the clone.
            randGraph.omitType = "DelFictive";
            const auto g2 = std::make_shared<GraphView>("g2");
            g2->add(randGraph.gen(seed, 10));
            // g1->save("./clone_with_delete1");
            // g2->save("./clone_with_delete2");

            try {
                const auto gCloned = g1->cloneCallback(&nodeDel);
                REQUIRE(nodePtrTo(gCloned->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName));
                REQUIRE(nodePtrTo(gCloned->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName));
                REQUIRE(nodePtrTo(gCloned->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName));
                ++nbClonedWithDelete;
            }
            catch (const std::runtime_error&) {
                // Deliberately best-effort: graphs for which cloneCallback()
                // throws are skipped and simply not counted in
                // nbClonedWithDelete.
            }
        }

        // Advance the seed so each iteration tests a different graph (the
        // sequence starting at 243 is the one the note above refers to).
        ++seed;
    }

    printf("nbClonedWithDelete = %zu/%zu\n", nbClonedWithDelete, nbTests);
}
// Removing every "DelFictive" node from a random graph must give the same
// nodes and the same (order-insensitive) inputs/outputs as generating the
// graph directly with "DelFictive" omitted.
TEST_CASE("remove", "[GraphView][remove]") {
    const size_t nbTests = 100;
    size_t nbTested = 0;

    for (size_t test = 0; test < nbTests; ++test) {
        std::random_device rd;
        const std::mt19937::result_type seed(rd());

        RandomGraph randGraph;
        randGraph.types = {"Fictive", "DelFictive"};
        randGraph.typesWeights = {0.8, 0.2};
        const auto g1 = std::make_shared<GraphView>("g1");
        const bool unicity1 = g1->add(randGraph.gen(seed, 10));

        if (unicity1) {
            // g1->save("./remove1_before");
            // Iterate over a copy of the node set, since g1->remove()
            // modifies g1 during the loop.
            const auto nodes = g1->getNodes();
            int step = 1;
            for (const auto& node : nodes) {
                if (node->type() == "DelFictive") {
                    g1->remove(node, false);
                    // g1->save("./remove1_after" + std::to_string(step));
                    ++step;
                }
            }

            randGraph.omitType = "DelFictive";
            const auto g2 = std::make_shared<GraphView>("g2");
            g2->add(randGraph.gen(seed, 10));
            // g1->save("./remove1");
            // g2->save("./remove2");

            REQUIRE(nodePtrTo(g1->getNodes(), nodePtrToName) == nodePtrTo(g2->getNodes(), nodePtrToName));
            // Order not guaranteed, because when a node is removed, it can create new GraphView inputs/outputs
            // Their order thus depends on the deletion order!
            //REQUIRE(nodePtrTo(g1->getOrderedInputs(), nodePtrToName) == nodePtrTo(g2->getOrderedInputs(), nodePtrToName));
            //REQUIRE(nodePtrTo(g1->getOrderedOutputs(), nodePtrToName) == nodePtrTo(g2->getOrderedOutputs(), nodePtrToName));

            // Check that inputs/outputs are the same regardless of the order
            auto orderedInputs1 = nodePtrTo(g1->getOrderedInputs(), nodePtrToName);
            auto orderedInputs2 = nodePtrTo(g2->getOrderedInputs(), nodePtrToName);
            auto orderedOutputs1 = nodePtrTo(g1->getOrderedOutputs(), nodePtrToName);
            auto orderedOutputs2 = nodePtrTo(g2->getOrderedOutputs(), nodePtrToName);
            std::sort(orderedInputs1.begin(), orderedInputs1.end());
            std::sort(orderedInputs2.begin(), orderedInputs2.end());
            std::sort(orderedOutputs1.begin(), orderedOutputs1.end());
            std::sort(orderedOutputs2.begin(), orderedOutputs2.end());

            REQUIRE(orderedInputs1 == orderedInputs2);
            REQUIRE(orderedOutputs1 == orderedOutputs2);
            ++nbTested;
        }
    }

    printf("nbTested = %zu/%zu\n", nbTested, nbTests);
}

// Both GraphView constructors (unnamed and named) must yield a valid object.
TEST_CASE("[core/graph] GraphView(Constructor)", "[GraphView][constructor()]") {
    std::shared_ptr<GraphView> g0 = std::make_shared<GraphView>();
    std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("G1");
    REQUIRE(g0 != nullptr);
    REQUIRE(g1 != nullptr);
}

// Checks GraphView::add() with isolated nodes, a node with a parent on its
// learnable-parameter input, an initializer list of nodes and another
// GraphView.
TEST_CASE("[core/graph] GraphView(add)", "[GraphView][add]") {
    SECTION("Node alone") {
        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
        std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 0, 0, "Gop1");
        g->add(GOp1);
        std::shared_ptr<Node> GOp2 = GenericOperator("Fictive", 0, 0, 1, "Gop2");
        g->add(GOp2);
        std::shared_ptr<Node> GOp3 = GenericOperator("Fictive", 1, 0, 0, "Gop3");
        g->add(GOp3);
        std::shared_ptr<Node> GOp4 = GenericOperator("Fictive", 0, 1, 0, "Gop4");
        g->add(GOp4);
        std::shared_ptr<Node> GOp5 = GenericOperator("Fictive", 1, 0, 1, "Gop5");
        g->add(GOp5);
        std::shared_ptr<Node> GOp6 = GenericOperator("Fictive", 1, 1, 1, "Gop6");
        g->add(GOp6);
        // g->save("node_alone");
        // Unconnected inputs/outputs become graph inputs/outputs, listed in
        // the order the nodes were added.
        REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop3", 0}, {"Gop4", 0}, {"Gop5", 0}, {"Gop6", 0}, {"Gop6", 1}}));
        REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop2", 0}, {"Gop5", 0}, {"Gop6", 0}}));
    }

    SECTION("Several Nodes") {
        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
        // should automatically add parents for learnable parameters
        std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 1, 1, "Gop1");
        std::shared_ptr<Node> GOp1parent = GenericOperator("Fictive", 0, 0, 1, "Gop1parent");
        GOp1parent->addChild(GOp1, 0, 0);
        g->add(GOp1);
        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent}));
        REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({}));
        REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop1", 0}}));

        // there should be no duplicates
        g->add(GOp1);
        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp1parent}));
        REQUIRE(nodePtrTo(g->getOrderedInputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({}));
        REQUIRE(nodePtrTo(g->getOrderedOutputs(), nodePtrToName) == std::vector<std::pair<std::string, IOIndex_t>>({{"Gop1", 0}}));
    }

    SECTION("Initializer list ofr Node") {
        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
        std::shared_ptr<Node> GOp1 = GenericOperator("Fictive", 0, 0, 0, "Gop1");
        std::shared_ptr<Node> GOp2 = GenericOperator("Fictive", 0, 0, 0, "Gop2");
        // GOp1 listed twice on purpose: it must appear only once in the graph.
        g->add({GOp1, GOp1, GOp2});
        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({GOp1, GOp2}));
    }

    SECTION("another GraphView") {
        std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph-1");
        std::shared_ptr<GraphView> g2 = std::make_shared<GraphView>("TestGraph-2");
        auto conv = GenericOperator("Conv", 1, 0, 1, "c");
        auto conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
        auto conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
        auto conv3 = GenericOperator("Conv", 1, 0, 1, "c3");
        auto conv4 = GenericOperator("Conv", 1, 0, 1, "c4");
        conv->addChild(conv1);
        conv1->addChild(conv2);
        conv2->addChild(conv3);
        conv3->addChild(conv4);
        g1->add({conv, conv1, conv2, conv3, conv4});
        g2->add(g1);
        REQUIRE(((g1->getNodes() == g2->getNodes()) && (g2->getNodes() == std::set<std::shared_ptr<Node>>({conv, conv1, conv2, conv3, conv4}))));
        REQUIRE(((g1->inputNodes() == g2->inputNodes()) &&
            (g2->inputNodes() == std::set<std::shared_ptr<Node>>({conv}))));
        REQUIRE(((g1->outputNodes() == g2->outputNodes()) &&
            (g2->outputNodes() == std::set<std::shared_ptr<Node>>({conv4}))));
    }
}

// Builds a chain c -> c1 -> c2 -> c3 -> c3.5 -> c4 -> c5 using the various
// GraphView::addChild() overloads (by parent name, by parent node, with or
// without explicit output/input indices).
// Note: Catch2 re-runs the whole TEST_CASE once per SECTION, so each SECTION
// observes the graph built by all the addChild() calls preceding it.
TEST_CASE("[core/graph] GraphView(addChild)") {
    std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
    std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
    std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
    std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
    std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 0, 1, "c3");
    std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 0, 1, "c3.5");
    std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 0, 1, "c4");
    std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 0, 1, "c5");

    g1->add(conv);
    SECTION("add(node)") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv});
    }
    g1->addChild(conv1, "c");
    SECTION("add(node, outputNodeName)") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv1});
        REQUIRE(conv->getChildren() == std::set<std::shared_ptr<Node>>({conv1}));
        REQUIRE(conv1->getParents() == std::vector<std::shared_ptr<Node>>({conv}));
    }
    g1->addChild(conv2, "c1", 0);
    SECTION("add(node, pair<outputNodeName, outID>)") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv2});
        REQUIRE(conv1->getChildren() == std::set<std::shared_ptr<Node>>({conv2}));
        REQUIRE(conv2->getParents() == std::vector<std::shared_ptr<Node>>({conv1}));
    }
    g1->addChild(conv3, "c2", 0, 0);
    SECTION("add(node, list(outputNodeName))") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv3});
        REQUIRE(conv2->getChildren() == std::set<std::shared_ptr<Node>>({conv3}));
        REQUIRE(conv3->getParents() == std::vector<std::shared_ptr<Node>>({conv2}));
    }
    g1->addChild(conv3_5, conv3);
    SECTION("add(node, list(outputNodeName))") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv3_5});
        REQUIRE(conv3->getChildren() == std::set<std::shared_ptr<Node>>({conv3_5}));
        REQUIRE(conv3_5->getParents() ==
                std::vector<std::shared_ptr<Node>>({conv3}));
    }
    g1->addChild(conv4, conv3_5, 0);
    SECTION("add(node, vector<pair<outputNodeName, outID>>)") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv4});
        REQUIRE(conv3_5->getChildren() == std::set<std::shared_ptr<Node>>({conv4}));
        REQUIRE(conv4->getParents() ==
                std::vector<std::shared_ptr<Node>>({conv3_5}));
    }
    g1->addChild(conv5, conv4, 0, 0);
    SECTION("add(node, vector<pair<outputNodeName, outID>>)") {
        REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>{conv});
        REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>{conv5});
        REQUIRE(conv4->getChildren() == std::set<std::shared_ptr<Node>>({conv5}));
        REQUIRE(conv5->getParents() == std::vector<std::shared_ptr<Node>>({conv4}));
    }
    std::set<std::shared_ptr<Node>> requiredNodes = {conv,    conv1, conv2, conv3,
                                                    conv3_5, conv4, conv5};
    REQUIRE(g1->getNodes() == requiredNodes);
    REQUIRE(g1->getChildren(conv3) == std::set<std::shared_ptr<Node>>({conv3_5}));
}

// A GraphView holding a single Conv node (without its learnable-parameter
// producers) must expose exactly that node's inputs.
TEST_CASE("[core/graph] GraphView(inputs)") {
    auto g1 = std::make_shared<GraphView>("TestGraph");
    std::shared_ptr<Node> conv = Conv(3, 32, {3, 3});
    // The node must be in the graph for its inputs to be graph inputs.
    // Learnable-parameter parents are excluded (second argument false);
    // otherwise add() pulls them into the graph and their connections stop
    // being graph inputs.
    g1->add(conv, false);

    REQUIRE(g1->inputs() == conv->inputs());
}

// A GraphView holding a single Conv node must expose that node's outputs.
TEST_CASE("[core/graph] GraphView(outputs)") {
    std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
    std::shared_ptr<Node> conv = Conv(3, 32, {3, 3});
    g1->add(conv);

    REQUIRE(g1->outputs() == conv->outputs());
}

// Builds a small chain of GenericOperator nodes and saves it to
// ./graphExample (smoke test: no assertion, only checks nothing throws).
TEST_CASE("[core/graph] GraphView(save)") {
    std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
    std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
    std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
    std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
    std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 0, 1, "c3");
    std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 0, 1, "c4");
    std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 0, 1, "c5");

    g1->add(conv);
    g1->addChild(conv1, "c");
    g1->addChild(conv2, "c1", 0);
    g1->addChild(conv3, "c2");
    g1->addChild(conv4, "c3", 0);
    g1->addChild(conv5, "c4", 0, 0);

    g1->save("./graphExample");
    printf("File saved in ./graphExample.md\n");
}

// Checks Node::resetConnections(): with `false` the data inputs and the
// outputs are disconnected while parameter inputs are kept; with `true`
// every input is disconnected as well.
TEST_CASE("[core/graph] GraphView(resetConnections)") {
    SECTION("disconnect data iput") {
        std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
        std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 2, 1, "c1");
        std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
        std::shared_ptr<Node> prod1 = GenericOperator("Prod", 0, 0, 1, "p1");
        std::shared_ptr<Node> prod2 = GenericOperator("Prod", 0, 0, 1, "p2");
        conv->addChild(conv1);
        prod1->addChild(conv1,0,1);
        prod2->addChild(conv1,0,2);
        conv1->addChild(conv2);

        conv1->resetConnections(false);

        REQUIRE(conv->output(0).size() == 0);
        for (std::size_t i = 0; i < conv1->nbData(); ++i) {
            REQUIRE((conv1->input(i) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex)));
        }
        // Parameter inputs stay connected to their producers.
        REQUIRE((conv1->input(1) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod1, 0)));
        REQUIRE((conv1->input(2) == std::pair<std::shared_ptr<Node>, IOIndex_t>(prod2, 0)));
        REQUIRE((conv2->input(0) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex)));
        // Every output of conv1 (the reset node) must have lost its children.
        for (std::size_t i = 0; i < conv1->nbOutputs(); ++i) {
            REQUIRE(conv1->output(i).size() == 0U);
        }
    }

    SECTION("disconnect data input + learnable parameters") {
        std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
        std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 2, 1, "c1");
        std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
        std::shared_ptr<Node> prod1 = GenericOperator("Prod", 0, 0, 1, "p1");
        std::shared_ptr<Node> prod2 = GenericOperator("Prod", 0, 0, 1, "p2");
        conv->addChild(conv1);
        prod1->addChild(conv1,0,1);
        prod2->addChild(conv1,0,2);
        conv1->addChild(conv2);

        conv1->resetConnections(true);

        REQUIRE(conv->output(0).size() == 0);
        for (std::size_t i = 0; i < conv1->nbInputs(); ++i) {
            REQUIRE((conv1->input(i) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex)));
        }
        REQUIRE((conv2->input(0) == std::pair<std::shared_ptr<Node>, IOIndex_t>(nullptr, gk_IODefaultIndex)));
        // Every output of conv1 (the reset node) must have lost its children.
        for (std::size_t i = 0; i < conv1->nbOutputs(); ++i) {
            REQUIRE(conv1->output(i).size() == 0U);
        }
    }
}

// Chains three Conv nodes fed by a {16, 3, 224, 224} Producer, runs
// GraphView::forwardDims(), then checks tensor sharing between connected
// nodes and the propagated output dimensions.
TEST_CASE("[core/graph] GraphView(forwardDims)", "[GraphView][forwardDims]") {
    auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
    auto conv1 = Conv(3, 32, {3, 3}, "conv1");
    auto conv2 = Conv(32, 64, {3, 3}, "conv2");
    auto conv3 = Conv(64, 10, {1, 1}, "conv3");
    auto g = std::make_shared<GraphView>("TestGraph");
    dataProvider->addChild(conv1, 0);
    g->add(conv1);
    g->addChild(conv2, conv1, 0);
    g->addChild(conv3, conv2, 0);
    g->save("graphForwardDims");
    g->forwardDims();

    SECTION("Check input-output connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawInput(1) == g->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawInput(2) == g->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE(conv2->getOperator()->getRawInput(1) == g->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawInput(2) == g->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE(conv3->getOperator()->getRawInput(1) == g->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv3->getOperator()->getRawInput(2) == g->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    SECTION("Check forwarded dims") {
        REQUIRE(std::static_pointer_cast<Tensor>(conv1->getOperator()->getRawOutput(0))
                    ->dims() == std::vector<DimSize_t>({16, 32, 222, 222}));
        REQUIRE(std::static_pointer_cast<Tensor>(conv2->getOperator()->getRawOutput(0))
                    ->dims() == std::vector<DimSize_t>({16, 64, 220, 220}));
    }
}

// Checks the static GraphView::replace(oldNodes, newNodes): fusing
// MatMul+Add into FC, removing nodes (empty replacement), tiling a Conv into
// several Convs + Concat and back, and replacing every node of a GraphView.
TEST_CASE("[core/graph] GraphView(replace)", "[GraphView][replace]") {
    SECTION("replace small pattern") {
        // create original graph
        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
        auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
        auto matmulWeight = GenericOperator("Producer", 0, 0, 1, "matmul_w");
        auto addBias = GenericOperator("Producer", 0, 0, 1, "add_b");
        auto other1 = GenericOperator("Other", 1, 0, 1, "other1");
        auto other2 = GenericOperator("Other", 1, 0, 1, "other2");
        auto matmul = GenericOperator("MatMul", 1, 1, 1, "matmul");
        auto add = GenericOperator("Add", 1, 1, 1, "add");
        otherInput->addChild(other1);
        other1->addChild(matmul);
        matmul->addChild(add);
        add->addChild(other2);
        matmulWeight->addChild(matmul, 0, 1);
        addBias->addChild(add, 0, 1);
        g->add({other1, matmul, add, other2});
        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({matmulWeight, addBias, other1, other2, matmul, add}));

        // create graph to replace
        std::set<std::shared_ptr<Node>> nodeToReplace = std::set<std::shared_ptr<Node>>({matmulWeight, addBias, matmul, add});

        // create replacing graph
        std::shared_ptr<Node> myFC = GenericOperator("FC", 1, 2, 1, "fc");
        auto newMatmulWeight = matmulWeight->cloneSharedOperators();
        newMatmulWeight->addChild(myFC, 0, 1);
        auto newAddBias = addBias->cloneSharedOperators();
        newAddBias->addChild(myFC, 0, 2);
        std::set<std::shared_ptr<Node>> newNodes = std::set<std::shared_ptr<Node>>({myFC, newMatmulWeight, newAddBias});

        // replace
        GraphView::replace(nodeToReplace, newNodes);

        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight, newAddBias, other1, other2, myFC}));
        REQUIRE(((myFC->getParent(0) == other1) && (myFC->getParent(1) == newMatmulWeight) && (myFC->getParent(2) == newAddBias)));
    }
    SECTION("replace with nothing") {
        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("TestGraph");
        auto r1 = GenericOperator("relu", 0, 0, 1);
        auto r2 = GenericOperator("relu", 1, 0, 1);
        auto r3 = GenericOperator("relu", 1, 0, 1);
        auto r4 = GenericOperator("relu", 1, 0, 0);
        r1->addChild(r2);
        r2->addChild(r3);
        r3->addChild(r4);
        g->add({r1, r2, r3, r4});
        auto nodesToReplace = std::set<std::shared_ptr<Node>>({r2, r3});
        auto newNodes = std::set<std::shared_ptr<Node>>({});
        GraphView::replace(nodesToReplace, newNodes);

        // The removed chain is bypassed: r1 now feeds r4 directly.
        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({r1, r4}));
        REQUIRE((r1->output(0))[0].first == r4);
    }

    SECTION("replace for tiling") {
        std::shared_ptr<GraphView> g = std::make_shared<GraphView>("test_graph");
        auto otherInput = GenericOperator("Producer", 0, 0, 1, "other_input");
        auto other1 = GenericOperator("Other", 1, 0, 1, "other1");
        auto myConv = GenericOperator("Conv", 1, 0, 1, "myConv");
        auto other2 = GenericOperator("Other", 1, 0, 1, "other2");
        otherInput->addChild(other1);
        other1->addChild(myConv);
        myConv->addChild(other2);
        g->add({other1, myConv, other2});

        // create tiled Conv
        auto conv1 =  GenericOperator("Conv", 1, 0, 1, "myConv1");
        auto conv2 =  GenericOperator("Conv", 1, 0, 1, "myConv2");
        auto conv3 =  GenericOperator("Conv", 1, 0, 1, "myConv3");
        auto conv4 =  GenericOperator("Conv", 1, 0, 1, "myConv4");
        auto concat = GenericOperator("Concat", 4, 0, 1, "myConcat");
        conv1->addChild(concat);
        conv2->addChild(concat);
        conv3->addChild(concat);
        conv4->addChild(concat);

        GraphView::replace({myConv}, {conv1, conv2, conv3, conv4, concat});

        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, conv1, conv2, conv3, conv4, concat, other2}));

        GraphView::replace({conv1, conv2, conv3, conv4, concat}, {myConv});

        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({other1, myConv, other2}));
    }

    SECTION("Change every Nodes in a GraphView") {
        auto matmulWeight0 = GenericOperator("Producer", 0, 0, 1, "matmul_w0");
        auto addBias0 = GenericOperator("Producer", 0, 0, 1, "add_b0");
        auto matmul0 = GenericOperator("MatMul", 1, 1, 1, "matmul0");
        auto add0 = GenericOperator("Add", 1, 1, 1, "add0");
        auto matmulWeight1 = GenericOperator("Producer", 0, 0, 1, "matmul_w1");
        auto addBias1 = GenericOperator("Producer", 0, 0, 1, "add_b1");
        auto matmul1 = GenericOperator("MatMul", 1, 1, 1, "matmul1");
        auto add1 = GenericOperator("Add", 1, 1, 1, "add1");

        matmulWeight0 -> addChild(matmul0, 0, 1);
        addBias0 -> addChild(add0, 0, 1);
        matmulWeight1 -> addChild(matmul1, 0, 1);
        addBias1 -> addChild(add1, 0, 1);
        matmul0 -> addChild(add0, 0, 0);
        add0 -> addChild(matmul1, 0, 0);
        matmul1 -> addChild(add1, 0, 0);

        auto g = std::make_shared<GraphView>("TestGraph");
        g -> add({matmulWeight0, addBias0, matmulWeight1, addBias1, matmul0, add0, matmul1, add1});
        auto newMatmulWeight0 = matmulWeight0->cloneSharedOperators();
        auto newAddBias0 = addBias0->cloneSharedOperators();
        auto newMatmulWeight1 = matmulWeight1->cloneSharedOperators();
        auto newAddBias1 = addBias1->cloneSharedOperators();
        auto fc0 = GenericOperator("FC", 1, 2, 1, "fc0");
        auto fc1 = GenericOperator("FC", 1, 2, 1, "fc1");

        newMatmulWeight0 -> addChild(fc0, 0, 1);
        newAddBias0 -> addChild(fc0, 0, 2);
        newMatmulWeight1 -> addChild(fc1, 0, 1);
        newAddBias1 -> addChild(fc1, 0, 2);

        GraphView::replace({matmul0, add0, matmulWeight0, addBias0}, {newMatmulWeight0, newAddBias0, fc0});
        GraphView::replace({matmul1, add1, matmulWeight1, addBias1}, {newMatmulWeight1, newAddBias1, fc1});

        REQUIRE(g->getNodes() == std::set<std::shared_ptr<Node>>({newMatmulWeight0, newAddBias0, newAddBias1, newMatmulWeight1, fc1, fc0}));
    }
}
TEST_CASE("[GraphView] clone") {
    auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
    auto conv1 = Conv(3, 32, {3, 3}, "conv1");
    auto conv2 = Conv(32, 64, {3, 3}, "conv2");
    auto conv3 = Conv(64, 10, {1, 1}, "conv3");
    auto g1 = std::make_shared<GraphView>("TestGraph");
    dataProvider->addChild(conv1, 0);
    g1->add(conv1);
    g1->addChild(conv2, conv1, 0);
    g1->addChild(conv3, conv2, 0);
    g1->save("clone_g1");

    SECTION("Check input-output connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawInput(1) == g1->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawInput(2) == g1->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE(conv2->getOperator()->getRawInput(1) == g1->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawInput(2) == g1->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE(conv3->getOperator()->getRawInput(1) == g1->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv3->getOperator()->getRawInput(2) == g1->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    auto g2 = g1->clone();

    auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
    dataProvider2->addChild(g2->getNode("conv1"), 0);

    g2->forwardDims();
    g2->save("clone_g2");

    SECTION("Check node cloning") {
        REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
        REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
        REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
        REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
        REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
        REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
        REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
        REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
        REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
    }

    SECTION("Check operator cloning") {
        REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator());
        REQUIRE(g1->getNode("conv1_w")->getOperator() != g2->getNode("conv1_w")->getOperator());
        REQUIRE(g1->getNode("conv1_b")->getOperator() != g2->getNode("conv1_b")->getOperator());
        REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator());
        REQUIRE(g1->getNode("conv2_w")->getOperator() != g2->getNode("conv2_w")->getOperator());
        REQUIRE(g1->getNode("conv2_b")->getOperator() != g2->getNode("conv2_b")->getOperator());
        REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator());
        REQUIRE(g1->getNode("conv3_w")->getOperator() != g2->getNode("conv3_w")->getOperator());
        REQUIRE(g1->getNode("conv3_b")->getOperator() != g2->getNode("conv3_b")->getOperator());
    }

        REQUIRE(dataProvider->getOperator()->getRawOutput(0) != g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(1) != g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(2) != g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawOutput(0) != g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(1) != g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(2) != g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawOutput(0) != g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(1) != g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(2) != g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    SECTION("Check input-output connections") {
        REQUIRE(dataProvider2->getOperator()->getRawOutput(0) == g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }
}

// Checks GraphView::cloneSharedProducers(): every Node of the clone is a new
// Node object and the Conv operators are duplicated, while the parameter
// Producer operators ("<conv>_w" / "<conv>_b", connected to Conv inputs 1 and 2)
// are the same objects in g1 and g2.
TEST_CASE("[GraphView] cloneSharedProducers") {
    auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
    auto conv1 = Conv(3, 32, {3, 3}, "conv1");
    auto conv2 = Conv(32, 64, {3, 3}, "conv2");
    auto conv3 = Conv(64, 10, {1, 1}, "conv3");
    auto g1 = std::make_shared<GraphView>("TestGraph");
    dataProvider->addChild(conv1, 0);
    g1->add(conv1);
    g1->addChild(conv2, conv1, 0);
    g1->addChild(conv3, conv2, 0);
    g1->save("cloneSharedProducers_g1");

    SECTION("Check input-output connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawInput(1) == g1->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawInput(2) == g1->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE(conv2->getOperator()->getRawInput(1) == g1->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawInput(2) == g1->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE(conv3->getOperator()->getRawInput(1) == g1->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv3->getOperator()->getRawInput(2) == g1->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    auto g2 = g1->cloneSharedProducers();

    // Feed the clone from its own Producer, so its data input is distinct from
    // g1's (checked in "Check connection cloning" below).
    auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider");
    dataProvider2->addChild(g2->getNode("conv1"), 0);

    g2->forwardDims();
    g2->save("cloneSharedProducers_g2");

    SECTION("Check node cloning") {
        REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
        REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
        REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
        REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
        REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
        REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
        REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
        REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
        REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
    }

    SECTION("Check operator cloning") {
        // Conv operators are duplicated; Producer operators are shared.
        REQUIRE(g1->getNode("conv1")->getOperator() != g2->getNode("conv1")->getOperator());
        REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator());
        REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator());
        REQUIRE(g1->getNode("conv2")->getOperator() != g2->getNode("conv2")->getOperator());
        REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator());
        REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator());
        REQUIRE(g1->getNode("conv3")->getOperator() != g2->getNode("conv3")->getOperator());
        REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator());
        REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator());
    }

    // These assertions were previously left at test-case scope (indented as a
    // section body but without a SECTION header), so they ran on every pass.
    SECTION("Check connection cloning") {
        // Data tensors between the Convs differ from g1's, whereas the parameter
        // tensors are the same objects since the Producer operators are shared.
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) != g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv1")->getOperator()->getRawOutput(0) != g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv2")->getOperator()->getRawOutput(0) != g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g1->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    SECTION("Check input-output connections") {
        REQUIRE(dataProvider2->getOperator()->getRawOutput(0) == g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }
}

// Checks GraphView::cloneSharedOperators(): every Node of the clone is a new
// Node object, but each one (Conv and parameter Producer alike) shares its
// Operator with the corresponding Node of g1.
TEST_CASE("[GraphView] cloneSharedOperators") {
    auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
    auto conv1 = Conv(3, 32, {3, 3}, "conv1");
    auto conv2 = Conv(32, 64, {3, 3}, "conv2");
    auto conv3 = Conv(64, 10, {1, 1}, "conv3");
    auto g1 = std::make_shared<GraphView>("TestGraph");
    dataProvider->addChild(conv1, 0);
    g1->add(conv1);
    g1->addChild(conv2, conv1, 0);
    g1->addChild(conv3, conv2, 0);
    g1->save("cloneSharedOperators_g1");

    SECTION("Check input-output connections") {
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == conv1->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawInput(1) == g1->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawInput(2) == g1->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE(conv2->getOperator()->getRawInput(1) == g1->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawInput(2) == g1->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(conv2->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE(conv3->getOperator()->getRawInput(1) == g1->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(conv3->getOperator()->getRawInput(2) == g1->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }

    auto g2 = g1->cloneSharedOperators();
    g2->save("cloneSharedOperators_g2");

    SECTION("Check node cloning") {
        REQUIRE(g1->getNode("conv1") != g2->getNode("conv1"));
        REQUIRE(g1->getNode("conv1_w") != g2->getNode("conv1_w"));
        REQUIRE(g1->getNode("conv1_b") != g2->getNode("conv1_b"));
        REQUIRE(g1->getNode("conv2") != g2->getNode("conv2"));
        REQUIRE(g1->getNode("conv2_w") != g2->getNode("conv2_w"));
        REQUIRE(g1->getNode("conv2_b") != g2->getNode("conv2_b"));
        REQUIRE(g1->getNode("conv3") != g2->getNode("conv3"));
        REQUIRE(g1->getNode("conv3_w") != g2->getNode("conv3_w"));
        REQUIRE(g1->getNode("conv3_b") != g2->getNode("conv3_b"));
    }

    SECTION("Check operator cloning") {
        // Every operator, Conv or Producer, is shared with the original graph.
        REQUIRE(g1->getNode("conv1")->getOperator() == g2->getNode("conv1")->getOperator());
        REQUIRE(g1->getNode("conv1_w")->getOperator() == g2->getNode("conv1_w")->getOperator());
        REQUIRE(g1->getNode("conv1_b")->getOperator() == g2->getNode("conv1_b")->getOperator());
        REQUIRE(g1->getNode("conv2")->getOperator() == g2->getNode("conv2")->getOperator());
        REQUIRE(g1->getNode("conv2_w")->getOperator() == g2->getNode("conv2_w")->getOperator());
        REQUIRE(g1->getNode("conv2_b")->getOperator() == g2->getNode("conv2_b")->getOperator());
        REQUIRE(g1->getNode("conv3")->getOperator() == g2->getNode("conv3")->getOperator());
        REQUIRE(g1->getNode("conv3_w")->getOperator() == g2->getNode("conv3_w")->getOperator());
        REQUIRE(g1->getNode("conv3_b")->getOperator() == g2->getNode("conv3_b")->getOperator());
    }

    SECTION("Check input-output connections") {
        // Because conv1's operator is shared, g1's dataProvider output is still
        // the clone's conv1 data input.
        REQUIRE(dataProvider->getOperator()->getRawOutput(0) == g2->getNode("conv1")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(1) == g2->getNode("conv1_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawInput(2) == g2->getNode("conv1_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv1")->getOperator()->getRawOutput(0) == g2->getNode("conv2")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(1) == g2->getNode("conv2_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawInput(2) == g2->getNode("conv2_b")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv2")->getOperator()->getRawOutput(0) == g2->getNode("conv3")->getOperator()->getRawInput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(1) == g2->getNode("conv3_w")->getOperator()->getRawOutput(0));
        REQUIRE(g2->getNode("conv3")->getOperator()->getRawInput(2) == g2->getNode("conv3_b")->getOperator()->getRawOutput(0));
    }
}

TEST_CASE("[core/graph] GraphView(insertParent)") {
    auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider");
    auto conv1 = Conv(3, 32, {3, 3}, "conv1");
    auto conv2 = Conv(32, 64, {3, 3}, "conv2");
    auto conv3 = Conv(32, 64, {1, 1}, "conv3");
    auto g = std::make_shared<GraphView>("TestGraph");
    dataProvider->addChild(conv1, 0);
    g->add(conv1);
    g->addChild(conv2, conv1, 0);
    g->addChild(conv3, conv1, 0);
    g->save("graphForwardDims");
    g->forwardDims();

    auto newConv = Conv(32, 32, {1, 1}, "newConv");

    SECTION("Check insertParent conv2 then insertParent conv3") {
        g->insertParent(conv2, newConv, 0, 0, 0);

        std::set<NodePtr> expectedConv1Children = {conv3, newConv};
        std::set<NodePtr> expectedNewConvChildren = {conv2};
        REQUIRE(conv1->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == newConv->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) != conv2->getOperator()->getRawInput(0));
        REQUIRE(newConv->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE((newConv->getChildren()) == expectedNewConvChildren);
        REQUIRE((conv1->getChildren()) == expectedConv1Children);

        g->insertParent(conv3, newConv, 0, 0, 0);

        std::set<NodePtr> expectedConv1Children2 = {newConv};
        std::set<NodePtr> expectedNewConvChildren2 = {conv2, conv3};

        REQUIRE(conv1->getOperator()->getRawOutput(0) != conv3->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) == newConv->getOperator()->getRawInput(0));
        REQUIRE(conv1->getOperator()->getRawOutput(0) != conv2->getOperator()->getRawInput(0));
        REQUIRE(newConv->getOperator()->getRawOutput(0) == conv2->getOperator()->getRawInput(0));
        REQUIRE(newConv->getOperator()->getRawOutput(0) == conv3->getOperator()->getRawInput(0));
        REQUIRE((newConv->getChildren()) == expectedNewConvChildren2);
        REQUIRE((conv1->getChildren()) == expectedConv1Children2);

    }