Skip to content
Snippets Groups Projects
Forked from Eclipse Projects / aidge / aidge_core
697 commits behind the upstream repository.
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Test_Matching.cpp 17.67 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <array>
#include <chrono>
#include <map>
#include <memory>
#include <random>
#include <set>
#include <string>

#include <catch2/catch_test_macros.hpp>

#include <fmt/chrono.h>

#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Testing.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/ReLU.hpp"
#include "aidge/operator/MetaOperatorDefs.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/recipes/Recipes.hpp"

using namespace Aidge;

/**
 * Compares the results of a graph-matching query against expectations.
 *
 * @param results  Matches returned by SinglePassGraphMatching::match().
 * @param expected Map from the name of each match's root node to the set of
 *                 node names that match is expected to contain.
 *
 * Checks that the number of matches equals expected.size(), and that every
 * match's root node name is a key of @p expected whose node-name set equals
 * the names of the nodes in the match. Each found match is printed to help
 * diagnose failures.
 */
void checkMatches(const std::set<SinglePassGraphMatching::MatchingResult>& results, const std::map<std::string, std::set<std::string>>& expected) {
    CHECK(results.size() == expected.size());

    for (const auto& result : results) {
        const auto found = nodePtrTo(result.graph->getNodes(), nodePtrToName);
        fmt::print("Found: {}\n", found);

        const auto rootNode = result.graph->rootNode()->name();
        // Use find() rather than at(): an unexpected root should be reported
        // as a readable test failure, not as an uncaught std::out_of_range.
        const auto expectedIt = expected.find(rootNode);
        INFO("Match rooted at node \"" << rootNode << "\"");
        REQUIRE(expectedIt != expected.end());
        REQUIRE(found == expectedIt->second);
    }
}

/**
 * Exercises the SinglePassGraphMatching query language on three kinds of graph:
 *  - g1: a Conv/ReLU/PaddedConv chain with two Add nodes fed by extra
 *    connections (relu3 -> add input #1, conv5 -> add2 input #1). Its
 *    meta-operators are expanded with expandMetaOps(), and the expected
 *    results refer to the resulting "<conv>_pad" / "<conv>_conv" nodes;
 *  - g2: a Conv/BatchNorm/FC chain used for "first/last node of a type" queries;
 *  - randomly generated graphs, used as a smoke/timing test.
 */
TEST_CASE("[core/graph] Matching") {
    auto g1 = Sequential({
        Producer({16, 3, 512, 512}, "dataProvider"),
        Conv(3, 4, {5, 5}, "conv1"),
        ReLU("relu1"),
        PaddedConv(4, 8, {5, 5}, "conv2", {1, 1}, {2, 2, 2, 2}),
        ReLU("relu2"),
        PaddedConv(8, 16, {3, 3}, "conv3", {1, 1}, {2, 2, 2, 2}),
        ReLU("relu3"),
        PaddedConv(16, 16, {5, 5}, "conv4", {1, 1}, {2, 2, 2, 2}),
        Add("add"),
        PaddedConv(16, 16, {5, 5}, "conv5", {1, 1}, {2, 2, 2, 2}),
        ReLU("relu5"),
        Add("add2")
    });

    // Feed the second input of each Add from an earlier node.
    g1->getNode("relu3")->addChild(g1->getNode("add"), 0, 1);
    g1->getNode("conv5")->addChild(g1->getNode("add2"), 0, 1);
    g1->updateInputsOutputs();

    g1->save("Test_examples_before_expand", true);
    expandMetaOps(g1);
    g1->save("Test_examples", true);

    SECTION("Conv2D->(ReLU->Pad2D->Conv2D)*") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D->(ReLU->Pad2D->Conv2D)*");
        checkMatches(results, {
            {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "relu1", "relu2"}},
            {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv"}},
            {"conv4_conv", {"conv4_conv"}},
            {"conv5_conv", {"conv5_conv"}}
        });
    }

    // Invalid queries: the matcher is expected to reject these with an exception.
    SECTION("Conv2D->ReLU;ReLU->Pad2D") {
        REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv2D->ReLU;ReLU->Pad2D"));
    }

    SECTION("Conv2D->ReLU#1;ReLU#2->Pad2D") {
        REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv2D->ReLU#1;ReLU#2->Pad2D"));
    }

    SECTION("Conv2D?->ReLU") {
        REQUIRE_THROWS(SinglePassGraphMatching(g1).match("Conv2D?->ReLU"));
    }

    SECTION("(Add#<*~.)*") {
        REQUIRE_THROWS(SinglePassGraphMatching(g1).match("(Add#<*~.)*"));
    }

    SECTION("Conv2D->(ReLU~>Pad2D->Conv2D)*") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D->(ReLU~>Pad2D->Conv2D)*");

        checkMatches(results, {
            {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
            {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}},
            {"conv3_conv", {"conv3_conv", "conv4_conv", "conv4_pad", "relu3"}},
            {"conv4_conv", {"conv4_conv"}},
            {"conv5_conv", {"conv5_conv"}}
        });
    }

    SECTION("Conv2D->(ReLU~>Pad2D->Conv2D)* [disjoint]") {
        // Second argument `true`: only the disjoint matches are expected back.
        const auto results = SinglePassGraphMatching(g1).match("Conv2D->(ReLU~>Pad2D->Conv2D)*", true);

        checkMatches(results, {
            {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
            {"conv5_conv", {"conv5_conv"}}
        });
    }

    SECTION("Conv2D~>(ReLU~>Pad2D->Conv2D)*") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D~>(ReLU~>Pad2D->Conv2D)*");

        checkMatches(results, {
            {"conv1", {"conv1", "conv2_conv", "conv2_pad", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu1", "relu2", "relu3"}},
            {"conv2_conv", {"conv2_conv", "conv3_conv", "conv3_pad", "conv4_conv", "conv4_pad", "relu2", "relu3"}},
            {"conv3_conv", {"conv3_conv", "conv4_conv", "conv4_pad", "relu3"}},
            {"conv4_conv", {"conv4_conv"}},
            {"conv5_conv", {"conv5_conv"}}
        });
    }

    SECTION("Pad2D->Conv2D#->ReLU;Conv2D#<1-Producer;Conv2D#<2-Producer") {
        const auto results = SinglePassGraphMatching(g1).match("Pad2D->Conv2D#->ReLU;Conv2D#<1-Producer;Conv2D#<2-Producer");

        checkMatches(results, {
            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}
        });
    }

    SECTION("Pad2D->Conv2D#~>ReLU;Conv2D#<1-Producer;Conv2D#<2-Producer") {
        const auto results = SinglePassGraphMatching(g1).match("Pad2D->Conv2D#~>ReLU;Conv2D#<1-Producer;Conv2D#<2-Producer");
        checkMatches(results, {
            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
            {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}}
        });
    }

    SECTION("Pad2D->Conv2D#~>ReLU;(Conv2D#<*-Producer){2}") {
        const auto results = SinglePassGraphMatching(g1).match("Pad2D->Conv2D#~>ReLU;(Conv2D#<*-Producer){2}");

        checkMatches(results, {
            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
            {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}}
        });
    }

    SECTION("Pad2D->Conv2D#->ReLU;(Conv2D#<*-Producer){2}") {
        const auto results = SinglePassGraphMatching(g1).match("Pad2D->Conv2D#->ReLU;(Conv2D#<*-Producer){2}");

        checkMatches(results, {
            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}
        });
    }

    SECTION("Pad2D->Conv2D#~>ReLU;(Conv2D#<*-.){2}") {
        const auto results = SinglePassGraphMatching(g1).match("Pad2D->Conv2D#~>ReLU;(Conv2D#<*-.){2}");

        checkMatches(results, {
            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}},
            {"conv5_pad", {"conv5_b", "conv5_conv", "conv5_pad", "conv5_w", "relu5"}}
        });
    }

    SECTION("Pad2D->Conv2D#->ReLU;(Conv2D#<*-.){2}") {
        const auto results = SinglePassGraphMatching(g1).match("Pad2D->Conv2D#->ReLU;(Conv2D#<*-.){2}");

        checkMatches(results, {
            {"conv2_pad", {"conv2_b", "conv2_conv", "conv2_pad", "conv2_w", "relu2"}},
            {"conv3_pad", {"conv3_b", "conv3_conv", "conv3_pad", "conv3_w", "relu3"}}
        });
    }

    SECTION("Conv2D#~>ReLU*;Conv2D#<-Pad2D*") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#~>ReLU*;Conv2D#<-Pad2D*");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
            {"conv4_conv", {"conv4_conv", "conv4_pad"}},
            {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}}
        });
    }

    SECTION("Conv2D#->ReLU*;Conv2D#<-Pad2D*") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#->ReLU*;Conv2D#<-Pad2D*");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
            {"conv4_conv", {"conv4_conv", "conv4_pad"}},
            {"conv5_conv", {"conv5_conv", "conv5_pad"}}
        });
    }

    SECTION("Conv2D#->ReLU?-*>Add#1?->ReLU?;Conv2D#<-Pad2D?;(Add#1<*-.)?") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#->ReLU?-*>Add#1?->ReLU?;Conv2D#<-Pad2D?;(Add#1<*-.)?");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
            {"conv4_conv", {"add", "conv4_conv", "conv4_pad"}},
            {"conv5_conv", {"conv5_conv", "conv5_pad"}}
        });
    }

    SECTION("Conv2D#~>ReLU?-*>Add#1?~>ReLU?;Conv2D#<-Pad2D?;(Add#1<*-.)?") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#~>ReLU?-*>Add#1?~>ReLU?;Conv2D#<-Pad2D?;(Add#1<*-.)?");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
            {"conv4_conv", {"add", "conv4_conv", "conv4_pad"}},
            {"conv5_conv", {"add2", "conv5_conv", "conv5_pad", "relu5"}}
        });
    }

    SECTION("Conv2D#~>ReLU?~*>Add#1?~>ReLU?;Conv2D#<-Pad2D?;(Add#1<*~.)?") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#~>ReLU?~*>Add#1?~>ReLU?;Conv2D#<-Pad2D?;(Add#1<*~.)?");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"add", "conv3_conv", "conv3_pad", "conv4_conv", "relu3"}},
            {"conv4_conv", {"add", "conv4_conv", "conv4_pad", "relu3"}},
            {"conv5_conv", {"add2", "conv5_conv", "conv5_pad", "relu5"}}
        });
    }

    SECTION("Conv2D#->ReLU?;Conv2D#<-Pad2D?") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#->ReLU?;Conv2D#<-Pad2D?");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
            {"conv4_conv", {"conv4_conv", "conv4_pad"}},
            {"conv5_conv", {"conv5_conv", "conv5_pad"}}
        });
    }

    SECTION("Conv2D#~>ReLU?;Conv2D#<-Pad2D?") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D#~>ReLU?;Conv2D#<-Pad2D?");

        checkMatches(results, {
            {"conv1", {"conv1", "relu1"}},
            {"conv2_conv", {"conv2_conv", "conv2_pad", "relu2"}},
            {"conv3_conv", {"conv3_conv", "conv3_pad", "relu3"}},
            {"conv4_conv", {"conv4_conv", "conv4_pad"}},
            {"conv5_conv", {"conv5_conv", "conv5_pad", "relu5"}}
        });
    }

    SECTION("(Conv2D|ReLU)->Add") {
        const auto results = SinglePassGraphMatching(g1).match("(Conv2D|ReLU)->Add");

        checkMatches(results, {
            {"conv4_conv", {"add", "conv4_conv"}},
            {"relu5", {"add2", "relu5"}}
        });
    }

    SECTION("Add<*-.") {
        const auto results = SinglePassGraphMatching(g1).match("Add<*-.");

        checkMatches(results, {
            {"add", {"add", "conv4_conv"}},
            {"add2", {"add2", "relu5"}}
        });
    }

    SECTION("(Add#<*~.)+") {
        const auto results = SinglePassGraphMatching(g1).match("(Add#<*~.)+");

        checkMatches(results, {
            {"add", {"add", "conv4_conv", "relu3"}},
            {"add2", {"add2", "conv5_conv", "relu5"}}
        });
    }

    SECTION("Conv2D~*>(ReLU&Add)") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D~*>(ReLU&Add)");

        checkMatches(results, {
            {"conv5_conv", {"add2", "conv5_conv", "relu5"}}
        });
    }

    SECTION("Conv2D~>(ReLU&Add)") {
        const auto results = SinglePassGraphMatching(g1).match("Conv2D~>(ReLU&Add)");
        REQUIRE(results.size() == 0);
    }

    SECTION("ReLU~*>((Pad2D->Conv2D-*>Add#)&Add#)") {
        const auto results = SinglePassGraphMatching(g1).match("ReLU~*>((Pad2D->Conv2D-*>Add#)&Add#)");

        checkMatches(results, {
            {"relu3", {"add", "conv4_conv", "conv4_pad", "relu3"}}
        });
    }

    SECTION("ReLU-*>((Pad2D->Conv2D-*>Add)&Add)") {
        const auto results = SinglePassGraphMatching(g1).match("ReLU-*>((Pad2D->Conv2D-*>Add)&Add)");
        REQUIRE(results.size() == 0);
    }

    SECTION("Pad2D->Conv2D[3x3]->ReLU") {
        auto gm = SinglePassGraphMatching(g1);
        gm.addNodeLambda("3x3", [](const NodePtr& node) {
            // Checked downcast: yields nullptr (and a non-match) instead of
            // undefined behaviour if the lambda ever sees a non-Conv_Op<2> node.
            const auto op = std::dynamic_pointer_cast<Conv_Op<2>>(node->getOperator());
            return (op != nullptr && op->kernelDims() == std::array<DimSize_t, 2>({3, 3}));
        });

        const auto results = gm.match("Pad2D->Conv2D[3x3]->ReLU");

        checkMatches(results, {
            {"conv3_pad", {"conv3_conv", "conv3_pad", "relu3"}}
        });
    }

    SECTION(".[test]->Pad2D") {
        auto gm = SinglePassGraphMatching(g1);
        gm.addNodeLambda("test", [](const NodePtr& node) {
            return (node->type() == "Add" || (node->type() == "ReLU" && node->name() == "relu1"));
        });

        const auto results = gm.match(".[test]->Pad2D");

        checkMatches(results, {
            {"add", {"add", "conv5_pad"}},
            {"relu1", {"relu1", "conv2_pad"}}
        });
    }

    auto g2 = Sequential({
        Producer({16, 3, 512, 512}, "dataProvider"),
        Conv(3, 4, {5, 5}, "conv1"),
        BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn1"),
        Conv(4, 4, {5, 5}, "conv2"),
        ReLU("relu2"),
        Conv(4, 4, {5, 5}, "conv3"),
        BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn3"),
        FC(4, 4, false, "fc1"),
        FC(4, 4, false, "fc2"),
        FC(4, 4, false, "fc3"),
        ReLU("relu3"),
        Conv(1, 4, {5, 5}, "conv4")
    });

    SECTION("((Conv2D#->(.[exBN]|$))|(FC#->(.[exFC])*->$))") {
        auto gm = SinglePassGraphMatching(g2);
        gm.addNodeLambda("exBN", [](const NodePtr& node) {
            return (node->type() != "BatchNorm2D");
        });
        gm.addNodeLambda("exFC", [](const NodePtr& node) {
            return (node->type() != "FC");
        });

        const auto results = gm.match("((Conv2D#->(.[exBN]|$))|(FC#->(.[exFC])*->$))");

        checkMatches(results, {
            {"conv2", {"conv2", "relu2"}},
            {"conv4", {"conv4"}},
            {"fc3", {"fc3", "relu3", "conv4"}}
        });
    }

    // Find last node of a type
    SECTION("FC#->(.[exFC])*->$") {
        auto gm = SinglePassGraphMatching(g2);
        gm.addNodeLambda("exFC", [](const NodePtr& node) {
            return (node->type() != "FC");
        });

        const auto results = gm.match("FC#->(.[exFC])*->$");

        checkMatches(results, {
            {"fc3", {"fc3", "relu3", "conv4"}}
        });
    }

    SECTION("Conv2D#->(.[exConv])*->$") {
        auto gm = SinglePassGraphMatching(g2);
        gm.addNodeLambda("exConv", [](const NodePtr& node) {
            return (node->type() != "Conv2D");
        });

        const auto results = gm.match("Conv2D#->(.[exConv])*->$");

        checkMatches(results, {
            {"conv4", {"conv4"}}
        });
    }

    // Find first node of a type
    SECTION("FC#<-(.[exFC])*<-$") {
        auto gm = SinglePassGraphMatching(g2);
        gm.addNodeLambda("exFC", [](const NodePtr& node) {
            return (node->type() != "FC");
        });

        const auto results = gm.match("FC#<-(.[exFC])*<-$");
        checkMatches(results, {
            {"fc1", {"fc1", "bn3", "conv3", "relu2", "conv2", "bn1", "conv1", "dataProvider"}}
        });
    }

    SECTION("(((FC#|Conv2D#)<-(.[exParam])*<-$)|((FC#|Conv2D#)->(.[exParam])*->$));(FC#|Conv2D#)<1-Producer#") {
        auto gm = SinglePassGraphMatching(g2);
        gm.addNodeLambda("exParam", [](const NodePtr& node) {
            return (node->type() != "FC" && node->type() != "Conv2D");
        });

        const auto results = gm.match("(((FC#|Conv2D#)<-(.[exParam])*<-$)|((FC#|Conv2D#)->(.[exParam])*->$));(FC#|Conv2D#)<1-Producer#");

        checkMatches(results, {
            {"conv1", {"conv1", "conv1_w", "dataProvider"}},
            {"conv4", {"conv4", "conv4_w"}}
        });
    }

    SECTION("Conv2D->ReLU [perf]") {
        const size_t nbTests = 3;
        std::mt19937::result_type seed(1);

        // Generator settings are the same for every iteration; only the seed
        // changes, so configure them once.
        RandomGraph randGraph;
        randGraph.types = {"Conv2D", "ReLU", "Dummy"};
        randGraph.typesWeights = {0.4, 0.4, 0.2};
        randGraph.avgIn = 1;
        randGraph.maxIn = 1;
        randGraph.maxOut = 1;
        randGraph.avgOut = 1;
        randGraph.density = 0.9;
        randGraph.acyclic = true;

        Log::setConsoleLevel(Log::Warn);

        for (size_t test = 0; test < nbTests; ++test) {
            // Distinct local name so the outer `g1` is not shadowed.
            const auto randomView = std::make_shared<GraphView>("g1");
            randomView->add(randGraph.gen(seed, 100));
            randomView->save("graph_single_pass");

            auto gm = SinglePassGraphMatching(randomView);

            // steady_clock is monotonic (system_clock may jump), which makes
            // it the right clock for measuring elapsed time.
            const auto start = std::chrono::steady_clock::now();
            const auto results = gm.match("Conv2D->ReLU#;ReLU#->Dummy");
            const auto end = std::chrono::steady_clock::now();
            const auto duration = std::chrono::duration_cast<std::chrono::milliseconds>(end - start);

            REQUIRE(results.size() > 0);
            ++seed;

            fmt::print("Found: {} - duration: {}\n", results.size(), duration);
        }
    }
}