// Test_GraphRegex.cpp — Catch2 unit tests for Aidge::GraphRegex.
// Originally committed by vincent lorrain.

#include <catch2/catch_test_macros.hpp>
#include "aidge/graphRegex/GraphRegex.hpp"


#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"

#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"

using namespace Aidge;

// Catch2 tests for Aidge::GraphRegex:
//  - matching queries on small chains of GenericOperator nodes, with node keys
//    either derived from the graph (setKeyFromGraph) or declared explicitly
//    (setNodeKey);
//  - applying recipes registered through addQuery(query, callback).
TEST_CASE("GraphRegexUser") {

    // Keys derived from the graph itself drive a single "Conv->FC" query.
    SECTION("INIT") {

        const std::string query = "Conv->FC";

        std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();

        // Chain: c (Conv) -> c1 (FC) -> c2 (Conv) -> c3 (FC)
        std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
        std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
        std::shared_ptr<Node> fc = GenericOperator("FC", 1, 0, 1, "c1");
        std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
        std::shared_ptr<Node> fc2 = GenericOperator("FC", 1, 0, 1, "c3");

        g1->add(conv);
        g1->addChild(fc, "c");
        g1->addChild(conv2, "c1");
        g1->addChild(fc2, "c2");
        sut->setKeyFromGraph(g1);
        sut->addQuery(query);

        const auto solutions = sut->match(g1);
        // Guard: otherwise the loop below passes vacuously when nothing matches.
        REQUIRE_FALSE(solutions.empty());

        for (const auto& solution : solutions) {
            REQUIRE(solution->getQuery() == query);
            if (solution->getStartNode() == std::vector<NodePtr>{conv}) {
                REQUIRE(solution->at("Conv") == std::set<NodePtr>{conv});
                REQUIRE(solution->at("FC") == std::set<NodePtr>{fc});
            } else if (solution->getStartNode() == std::vector<NodePtr>{conv2}) {
                REQUIRE(solution->at("Conv") == std::set<NodePtr>{conv2});
                REQUIRE(solution->at("FC") == std::set<NodePtr>{fc2});
            }
        }
        //REQUIRE( sut->match(g1)[1]->getAll() == std::set<NodePtr>{conv,fc});
    }

    // Two queries on an all-Conv chain: only "Conv->Conv" should produce solutions.
    SECTION("2 query") {
        std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();

        // Chain: c -> c1 -> c2 -> c3, all of type Conv
        std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
        std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
        std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
        std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
        std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 0, 1, "c3");

        g1->add(conv);
        g1->addChild(conv1, "c");
        g1->addChild(conv2, "c1");
        g1->addChild(conv3, "c2");

        sut->setKeyFromGraph(g1);

        const std::string query = "Conv->Conv";
        const std::string query2 = "Conv->FC";

        // The graph contains no FC node, so the "FC" key is declared explicitly.
        sut->setNodeKey("FC","getType($) =='FC'");

        sut->addQuery(query);
        sut->addQuery(query2);

        const auto solutions = sut->match(g1);
        REQUIRE_FALSE(solutions.empty());
        for (const auto& solution : solutions) {
            REQUIRE(solution->getQuery() == query);
        }
    }

    // Only the "Conv" key is defined (setKeyFromGraph is disabled), so only
    // `query` is expected to match; `query2` refers to the undefined "FC" key.
    SECTION("Not define node Test") {
        std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();

        // Chain: c (Conv) -> c1 (Conv) -> c2 (Conv) -> c3 (FC)
        std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
        std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
        std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
        std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
        std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 0, 1, "c3");

        g1->add(conv);
        g1->addChild(conv1, "c");
        g1->addChild(conv2, "c1");
        g1->addChild(conv3, "c2");

        //sut->setKeyFromGraph(g1);

        const std::string query = "Conv->Conv";
        const std::string query2 = "Conv->FC";

        sut->setNodeKey("Conv","getType($) =='Conv'");

        sut->addQuery(query);
        sut->addQuery(query2);

        const auto solutions = sut->match(g1);
        REQUIRE_FALSE(solutions.empty());
        for (const auto& solution : solutions) {
            REQUIRE(solution->getQuery() == query);
        }
    }

    SECTION("Applied Recipes"){

        // Generate the original GraphView:
        //   input -> matmul0 -> add0 -> matmul1 -> add1
        // with W0/W1 on input #1 of each MatMul and B0/B1 on input #1 of each Add.
        auto matmul0 = MatMul("matmul0");
        auto add0 = Add(2, "add0");
        auto matmul1 = MatMul("matmul1");
        auto add1 = Add(2, "add1");

        auto b0 = Producer({5}, "B0");
        auto w0 = Producer({5, 5}, "W0");
        auto b1 = Producer({5}, "B1");
        auto w1 = Producer({5,5},"W1");
        auto input = Producer({2,5}, "input");

        input->addChild(matmul0, 0, 0);
        w0->addChild(matmul0, 0, 1);

        matmul0->addChild(add0, 0, 0);
        b0->addChild(add0, 0, 1);

        add0->addChild(matmul1, 0, 0);
        w1->addChild(matmul1, 0, 1);

        matmul1->addChild(add1, 0, 0);
        b1->addChild(add1, 0, 1);

        // Unique names: both nodes previously shared the name "c".
        auto fc = GenericOperator("FC", 1, 0, 1, "fc");
        auto fl = GenericOperator("Flatten", 1, 0, 1, "fl");
        // NOTE(review): fl and fc are not connected, so the "Flatten->FC" query
        // below cannot match this graph and removeFlatten is not exercised.
        // TODO confirm whether they were meant to be chained (fl -> fc).

        // `input` is wired to matmul0 but intentionally not added to `g`.
        auto g = std::make_shared<GraphView>();
        g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1, fl, fc});

        // Copy (not reference) the node set before mutation, so the final
        // check fails if no recipe actually modified the graph.
        const std::set<std::shared_ptr<Node>> originalNodes = g->getNodes();

        std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>();

        kitchenBook->setNodeKey("Add","getType($) =='Add'");
        kitchenBook->setNodeKey("MatMul","getType($) =='MatMul'");
        kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'");
        kitchenBook->setNodeKey("FC","getType($) =='FC'");

        // The casts select the MatchSolution overloads of the recipe functions.
        kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd));
        kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten));

        kitchenBook->appliedRecipes(g);

        const std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
        REQUIRE(newNodes != originalNodes);
        //REQUIRE(newNodes.size() == 6);
    }
}