Newer
Older
#include <catch2/catch_test_macros.hpp>
#include "aidge/graphRegex/GraphRegex.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"

Maxence Naud
committed
#include "aidge/recipies/Recipies.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
TEST_CASE("GraphRegexUser") {
SECTION("INIT") {
const std::string query = "Conv->FC";
std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
std::shared_ptr<Node> fc = GenericOperator("FC", 1, 0, 1, "c1");
std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
std::shared_ptr<Node> fc2 = GenericOperator("FC", 1, 0, 1, "c3");
g1->add(conv);
g1->addChild(fc, "c");
g1->addChild(conv2, "c1");
g1->addChild(fc2, "c2");
sut->setKeyFromGraph(g1);
sut->addQuery(query);
for (const auto& solution : sut->match(g1)) {
REQUIRE(solution->getQuery() == query);
if(solution->getStartNode() == std::vector<NodePtr>{conv}){
REQUIRE(solution->at("Conv") == std::set<NodePtr>{conv} );
REQUIRE(solution->at("FC") == std::set<NodePtr>{fc} );
}else if (solution->getStartNode() == std::vector<NodePtr>{conv2})
{
REQUIRE(solution->at("Conv") == std::set<NodePtr>{conv2} );
REQUIRE(solution->at("FC") == std::set<NodePtr>{fc2} );
}
}
//REQUIRE( sut->match(g1)[1]->getAll() == std::set<NodePtr>{conv,fc});
vincent lorrain
committed
SECTION("2 query") {
std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 0, 1, "c3");
g1->add(conv);
g1->addChild(conv1, "c");
g1->addChild(conv2, "c1");
g1->addChild(conv3, "c2");
sut->setKeyFromGraph(g1);
const std::string query = "Conv->Conv";
const std::string query2 = "Conv->FC";
sut->addQuery(query);
sut->addQuery(query2);
for (const auto& solution : sut->match(g1)) {
REQUIRE(solution->getQuery() == query);
}
vincent lorrain
committed
SECTION("Not define node Test") {
//test if the FC is not define only match query not query2
std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>();
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");

Maxence Naud
committed
std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 0, 1, "c");
std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 0, 1, "c1");
std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 0, 1, "c2");
std::shared_ptr<Node> conv3 = GenericOperator("FC", 1, 0, 1, "c3");
vincent lorrain
committed
g1->add(conv);
g1->addChild(conv1, "c");
g1->addChild(conv2, "c1");
g1->addChild(conv3, "c2");
//sut->setKeyFromGraph(g1);
const std::string query = "Conv->Conv";
const std::string query2 = "Conv->FC";
sut->setNodeKey("Conv","getType($) =='Conv'");
sut->addQuery(query);
sut->addQuery(query2);
for (const auto& solution : sut->match(g1)) {
REQUIRE(solution->getQuery() == query);
}
vincent lorrain
committed
}
SECTION("Applied Recipes"){
// generate the original GraphView

Maxence Naud
committed
auto add0 = Add(2, "add0");

Maxence Naud
committed
auto add1 = Add(2, "add1");
auto b0 = Producer({5}, "B0");
auto w0 = Producer({5, 5}, "W0");
auto b1 = Producer({5}, "B1");
auto w1 = Producer({5,5},"W1");
auto input = Producer({2,5}, "input");
input->addChild(matmul0, 0, 0);
w0->addChild(matmul0, 0, 1);
matmul0->addChild(add0, 0, 0);
b0->addChild(add0, 0, 1);
add0->addChild(matmul1, 0, 0);
w1->addChild(matmul1, 0, 1);
matmul1->addChild(add1, 0, 0);
b1->addChild(add1, 0, 1);

Maxence Naud
committed
auto fc = GenericOperator("FC", 1, 0, 1, "c");
auto fl = GenericOperator("Flatten", 1, 0, 1, "c");
g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fl,fc});
std::shared_ptr<GraphRegex> kitchenBook = std::make_shared<GraphRegex>();
kitchenBook->setNodeKey("Add","getType($) =='Add'");
kitchenBook->setNodeKey("MatMul","getType($) =='MatMul'");
kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'");
kitchenBook->setNodeKey("FC","getType($) =='FC'");
kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd));
kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten));
kitchenBook->appliedRecipes(g);
std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1,fc}));
//REQUIRE(newNodes.size() == 6);
}