Skip to content
Snippets Groups Projects
Commit 0686c4ed authored by vincent  lorrain's avatar vincent lorrain
Browse files

Merge branch 'refactor/recipies' into 'main'

Refactor/recipies

See merge request !48
parents 5686be10 f8884a1e
No related branches found
No related tags found
1 merge request!48Refactor/recipies
Pipeline #34241 passed
...@@ -21,28 +21,17 @@ ...@@ -21,28 +21,17 @@
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/GenericOperator.hpp"
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp" //Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
using namespace Aidge; using namespace Aidge;
void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){ void Aidge::fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm){
assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
// Assert the nodes types are correct to be fused
std::shared_ptr<Node> conv;
std::shared_ptr<Node> batchnorm;
for (const auto& element : nodes) {
assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace");
if (element->type() == "Conv"){
conv = element;
}
else if (element->type() == "BatchNorm") {
batchnorm = element;
}
}
// TODO : check if batchnorm is the only child of the Conv or FC
std::shared_ptr<Tensor> scale = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second); std::shared_ptr<Tensor> scale = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second);
std::shared_ptr<Tensor> shift = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second); std::shared_ptr<Tensor> shift = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second);
std::shared_ptr<Tensor> b_mean = batchnorm->input(3).first->getOperator()->getOutput(batchnorm->input(3).second); std::shared_ptr<Tensor> b_mean = batchnorm->input(3).first->getOperator()->getOutput(batchnorm->input(3).second);
...@@ -127,19 +116,32 @@ void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){ ...@@ -127,19 +116,32 @@ void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){
} }
void Aidge::fuseBatchNorm(std::shared_ptr<MatchSolution> solution){
assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n");
assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n");
for (const auto& op : solution->at("OP")) {
for (const auto& batchNorm : solution->at("BatchNorm")) {
fuseBatchNorm(op,batchNorm);
}
}
}
void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){ void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){
std::map<std::string,NodeRegex*> nodesRegex ;
nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm");
nodesRegex["Conv"] = new NodeRegex("Conv"); std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
nodesRegex["FC"] = new NodeRegex("FC"); regex->setNodeKey("BatchNorm","getType($) =='BatchNorm'");
regex->setNodeKey("OP","getType($) =='Conv'");// || getType($) =='FC' ");
std::vector<std::string> seqRegex; regex->addQuery("OP -> BatchNorm");
seqRegex.push_back("Conv -> BatchNorm;"); // TODO: Add (Conv | FC)
GRegex GReg(nodesRegex, seqRegex); for (const auto& solution : regex->match(graphView)) {
Match matches = GReg.match(graphView);
std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); fuseBatchNorm(solution);
for (size_t i = 0; i < matches.getNbMatch(); ++i) {
fuseBatchNorm(matchNodes[i]);
} }
} }
...@@ -22,30 +22,17 @@ ...@@ -22,30 +22,17 @@
#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/GenericOperator.hpp"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
// Graph Regex //Graph Regex
#include "aidge/graphmatching/GRegex.hpp" #include "aidge/graphRegex/GraphRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
using namespace Aidge; using namespace Aidge;
void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ void Aidge::fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add){//std::set<std::shared_ptr<Node>> nodes){
// Fuse Mulmat & Add into FC // Fuse Mulmat & Add into FC
// Inputs : old nodes (pointers on mul & add) // Inputs : old nodes (pointers on mul & add)
assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); assert((matmul->type() == "MatMul" && add->type() == "Add") && "Wrong type for the nodes to replace");
// Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ?
// Step 0 : Assert the nodes types are correct to be fused
std::shared_ptr<Node> add;
std::shared_ptr<Node> matmul;
for (const auto& element : nodes) {
assert((element->type() == "MatMul" || element->type() == "Add") && "Wrong type for the nodes to replace");
if (element->type() == "MatMul"){
matmul = element;
}
else if (element->type() == "Add") {
add = element;
}
}
// Step 1 : Create FC // Step 1 : Create FC
// Fetch the output dimension throught the bias size // Fetch the output dimension throught the bias size
...@@ -78,17 +65,35 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ ...@@ -78,17 +65,35 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
} }
void Aidge::fuseMulAdd(std::shared_ptr<MatchSolution> solution){
assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n");
assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n");
for (const auto& matmul : solution->at("MatMul")) {
for (const auto& add : solution->at("Add")) {
fuseMulAdd(matmul,add);
}
}
}
void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){ void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){
std::map<std::string,NodeRegex*> nodesRegex ;
nodesRegex["MatMul"] = new NodeRegex("MatMul"); std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
nodesRegex["Add"] = new NodeRegex("Add"); regex->setNodeKey("Add","getType($) =='Add'");
std::vector<std::string> seqRegex; regex->setNodeKey("MatMul","getType($) =='MatMul'");
seqRegex.push_back("MatMul -> Add;"); regex->addQuery("MatMul -> Add ;");
GRegex GReg(nodesRegex, seqRegex);
Match matches = GReg.match(graphView); for (const auto& solution : regex->match(graphView)) {
std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
for (size_t i = 0; i < matches.getNbMatch(); ++i) { fuseMulAdd(solution);
fuseMulAdd(matchNodes[i]);
} }
} }
...@@ -15,36 +15,41 @@ ...@@ -15,36 +15,41 @@
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/utils/Recipies.hpp" #include "aidge/utils/Recipies.hpp"
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
namespace Aidge {
void removeFlatten(std::set<std::shared_ptr<Node>> nodes) {
assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
std::shared_ptr<Node> flatten;
for (const auto& element : nodes) {
assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace");
if (element->type() == "Flatten"){
flatten = element;
}
}
namespace Aidge {
void removeFlatten(std::shared_ptr<Node> flatten) {
GraphView::replace({flatten}, {}); GraphView::replace({flatten}, {});
} }
void removeFlatten(std::shared_ptr<MatchSolution> solution){
assert(solution->at("FC").size() == 1 && "Wrong number of nodes FC to replace\n");
assert(solution->at("Flatten").size() == 1 && "Wrong number of nodes Flatten to replace\n");
for (const auto& flatten : solution->at("Flatten")) {
removeFlatten(flatten);
}
}
void removeFlatten(std::shared_ptr<GraphView> graphView){ void removeFlatten(std::shared_ptr<GraphView> graphView){
std::map<std::string,NodeRegex*> nodesRegex ;
nodesRegex["Flatten"] = new NodeRegex("Flatten");
nodesRegex["FC"] = new NodeRegex("FC"); std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
std::vector<std::string> seqRegex; regex->setNodeKey("Flatten","getType($) =='Flatten'");
seqRegex.push_back("Flatten->FC;"); regex->setNodeKey("FC","getType($) =='FC'");
GRegex GReg(nodesRegex, seqRegex); regex->addQuery("Flatten->FC");
Match matches = GReg.match(graphView);
std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); for (const auto& solution : regex->match(graphView)) {
for (size_t i = 0; i < matches.getNbMatch(); ++i) { removeFlatten(solution);
removeFlatten(matchNodes[i]);
} }
} }
} }
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <iostream>
#include <map>
#include <memory>
#include <vector>
#include <utility>
#include <cassert>
#include <catch2/catch_test_macros.hpp>
//test
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/StmFactory.hpp"
#include "aidge/graphmatching/SeqStm.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/Match.hpp"
//use
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/GraphView.hpp"
using namespace Aidge;
TEST_CASE("Create good init GRegex", "[GRegex]") {
// init all input for GRegex
// Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex
// Sequential Regex vector : std::vector<std::string>& seqRegexps
// init the Nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"A","B","C"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
// init the Sequential Regex vector
std::vector<std::string> seqRegex;
seqRegex.push_back("A->B;");
// Instanciate a GRegex
GRegex GReg(nodesRegex, seqRegex);
// Perform tests
REQUIRE(GReg.getStmInit().size() == 1);
REQUIRE(GReg.getStmFab().getNumberOfStm() == 1);
for (const std::string& key : nodeTypeKey) {
delete nodesRegex[key];
}
}
TEST_CASE("Function matchFromStartNodes | One Match of Nodes sequence", "[GRegex]") {
// init all input for GRegex
// Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex
// Sequential Regex vector : std::vector<std::string>& seqRegexps
// init the Nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"Conv","BN","ReLU"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
// init the Sequential Regex vector
std::vector<std::string> seqRegex;
seqRegex.push_back("Conv->BN->ReLU;");
// Instanciate a GRegex
GRegex GReg(nodesRegex, seqRegex);
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1);
std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1);
std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1);
std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1);
std::shared_ptr<Node> Random2 = GenericOperator("Random2", 1, 1, 1);
g1->add(Conv1);
g1->addChild(BN1, Conv1);
g1->addChild(ReLU1, BN1);
g1->addChild(Random, ReLU1);
//g1->addChild(BN1, Random2);
std::vector<std::shared_ptr<Node>> startNodes1;
std::set<std::shared_ptr<Node>> result;
startNodes1.push_back(Conv1);
result = GReg.matchFromStartNodes(startNodes1, g1);
std::set<std::shared_ptr<Node>> true_result;
true_result.insert(Conv1);
true_result.insert(BN1);
true_result.insert(ReLU1);
// Perform tests
REQUIRE(result == true_result);
for (const std::string& key : nodeTypeKey) {
delete nodesRegex[key];
}
}
TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GRegex]") {
// init all input for GRegex
// Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex
// Sequential Regex vector : std::vector<std::string>& seqRegexps
// init the Nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"Add","FC","Conv"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
// init the Sequential Regex vector
std::vector<std::string> seqRegex;
seqRegex.push_back("Add#->Conv;");
seqRegex.push_back("Add#->FC;");
// Instanciate a GRegex
GRegex GReg(nodesRegex, seqRegex);
// Instanciate a graphView
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1);
std::shared_ptr<Node> Add1 = GenericOperator("Add", 1, 1, 1);
std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1);
std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1);
std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1);
std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1);
std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1);
g1->add(Random0);
g1->addChild(Add1, Random0);
g1->addChild(Conv1, Add1);
g1->addChild(BN1, Conv1);
g1->addChild(ReLU1, BN1);
g1->addChild(FC1, Add1);
g1->addChild(Random, FC1);
// Test 1 : Find the match
std::vector<std::shared_ptr<Node>> startNodes;
std::set<std::shared_ptr<Node>> result;
startNodes.push_back(Add1);
startNodes.push_back(Add1);
result = GReg.matchFromStartNodes(startNodes, g1);
std::set<std::shared_ptr<Node>> true_result;
true_result.insert(Add1);
true_result.insert(Conv1);
true_result.insert(FC1);
// Test 2 : Return an empty set when the start nodes are wrong
std::vector<std::shared_ptr<Node>> wrong_startNodes;
std::set<std::shared_ptr<Node>> wrong_start_result;
std::set<std::shared_ptr<Node>> empty_result;
wrong_startNodes.push_back(Random0);
wrong_startNodes.push_back(Random0);
wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1);
// Perform tests
REQUIRE(result == true_result);
REQUIRE(wrong_start_result == empty_result);
for (const std::string& key : nodeTypeKey) {
delete nodesRegex[key];
}
}
/*
TEST_CASE("Function matchFromStartNodes | Match a sequence with quantifier ", "[GRegex]") {
// init all input for GRegex
// Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex
// Sequential Regex vector : std::vector<std::string>& seqRegexps
// init the Nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"FC"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
// init the Sequential Regex vector
std::vector<std::string> seqRegex;
seqRegex.push_back("FC+;");
// Instanciate a GRegex
GRegex GReg(nodesRegex, seqRegex);
// Instanciate a graphView
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1);
std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1);
std::shared_ptr<Node> FC2 = GenericOperator("FC", 1, 1, 1);
std::shared_ptr<Node> FC3 = GenericOperator("FC", 1, 1, 1);
std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1);
g1->add(Random0);
g1->addChild(FC1, Random0);
g1->addChild(FC2, FC1);
g1->addChild(FC3, FC2);
g1->addChild(ReLU1, FC3);
// Test 1 : Find the match
std::vector<std::shared_ptr<Node>> startNodes;
std::set<std::shared_ptr<Node>> result;
startNodes.push_back(FC1);
result = GReg.matchFromStartNodes(startNodes, g1);
std::set<std::shared_ptr<Node>> true_result;
true_result.insert(FC1);
true_result.insert(FC2);
true_result.insert(FC3);
// Test 2 : Return an empty set when the start nodes are wrong
std::vector<std::shared_ptr<Node>> wrong_startNodes;
std::set<std::shared_ptr<Node>> wrong_start_result;
std::set<std::shared_ptr<Node>> empty_result;
wrong_startNodes.push_back(Random0);
wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1);
// Perform tests
REQUIRE(result == true_result);
REQUIRE(wrong_start_result == empty_result);
}
*/
TEST_CASE("Function match | ALL matches of Nodes sequence", "[GRegex]") {
// init all input for GRegex
// Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex
// Sequential Regex vector : std::vector<std::string>& seqRegexps
// init the Nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"GEMM"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
// init the Sequential Regex vector
std::vector<std::string> seqRegex;
seqRegex.push_back("GEMM;");
// Instanciate a GRegex
GRegex GReg(nodesRegex, seqRegex);
//init the input graph
std::shared_ptr<GraphView> graphToMatch = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1);
std::shared_ptr<Node> GEMM1 = GenericOperator("GEMM", 1, 1, 1);
std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1);
std::shared_ptr<Node> GEMM2 = GenericOperator("GEMM", 1, 1, 1);
std::shared_ptr<Node> GEMM3 = GenericOperator("GEMM", 1, 1, 1);
std::shared_ptr<Node> ReLU2 = GenericOperator("ReLU", 1, 1, 1);
std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1);
graphToMatch->add(Random0);
graphToMatch->addChild(GEMM1, Random0);
graphToMatch->addChild(ReLU1, GEMM1);
graphToMatch->addChild(GEMM2, ReLU1);
graphToMatch->addChild(GEMM3, GEMM2);
graphToMatch->addChild(ReLU2, GEMM3);
graphToMatch->addChild(Random, ReLU2);
//std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch);
//std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch);
Match matches = GReg.match(graphToMatch);
size_t nb = matches.getNbMatch();
std::vector<std::vector<NodeTmp>> gm_startnodes = matches.getStartNodes();
std::vector<std::set<NodeTmp>> gm_matchnodes = matches.getMatchNodes();
std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs;
for (size_t i = 0; i < nb; ++i) {
matchs.insert(std::make_pair(gm_startnodes[i], gm_matchnodes[i]));
}
//std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ;
std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ;
// Carefull : as the assert is on a vector, the Order of match matters
std::vector<NodeTmp> startNode = {GEMM1};
std::set<NodeTmp> matchNode = {GEMM1};
//toMatchs.push_back(std::make_pair(startNode,matchNode));
toMatchs.insert(std::make_pair(startNode,matchNode));
std::vector<NodeTmp> startNode2 = {GEMM2};
std::set<NodeTmp> matchNode2 = {GEMM2};
//toMatchs.push_back(std::make_pair(startNode2,matchNode2));
toMatchs.insert(std::make_pair(startNode2,matchNode2));
std::vector<NodeTmp> startNode3 = {GEMM3};
std::set<NodeTmp> matchNode3 = {GEMM3};
//toMatchs.push_back(std::make_pair(startNode3,matchNode3));
toMatchs.insert(std::make_pair(startNode3,matchNode3));
REQUIRE(matchs == toMatchs);
REQUIRE(nb == 3);
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <iostream>
#include <map>
#include <memory>
#include <cassert>
#include <catch2/catch_test_macros.hpp>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/operator/GenericOperator.hpp"
using namespace Aidge;
TEST_CASE("Create Noderegex", "[Noderegex]") {
std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("conv");
}
TEST_CASE("Test _is function", "[Noderegex]") {
// Create Noderegex with only condition on the name of the Node
// Create several operators to pass into Noderegex _is function
// Assert Noderegex._is(operators) are correct
std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("Conv");
std::shared_ptr<Node> Conv = GenericOperator("Conv", 1, 1, 1);
std::shared_ptr<Node> FC = GenericOperator("FC", 1, 1, 1);
REQUIRE(nr->_is(Conv) == true);
REQUIRE(nr->_is(FC) == false);
REQUIRE(nr->isA("Conv") == true);
REQUIRE(nr->isA("FC") == false);
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <iostream>
#include <map>
#include <memory>
#include <vector>
#include <utility>
#include <cassert>
#include <catch2/catch_test_macros.hpp>
//test
#include "aidge/graphmatching/SeqStm.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
//use
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
using namespace Aidge;
TEST_CASE("Create good init SeqStm", "[SeqStm]") {
//init all iniput for SeqStm
int stmIdx = 0;
//matrix that in B->C
std::vector<std::vector<int>> transitionMatrix {
{ -1, 1, -1 },
{ -1, -1, 2 },
{ -1, -1, -1 } };
//std::cout << transitionMatrix.size() << "\n";
// init the nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"A","B","C"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
//
std::map<NodeTypeKey,int> typeToIdxTransition;
std::vector<NodeTypeKey> nodeTypeCommonTag {{"A",""},{"B",""},{"C",""}};
//init nodeTypeCommonTag
int idx = 0;
for (const NodeTypeKey& key : nodeTypeCommonTag) {
typeToIdxTransition[key] = idx;
idx += 1;
}
int actSt = 0;
std::set<NodeTmp> allNodeValidated;
std::set<NodeTmp> allNodeTested;
std::set<std::pair<NodeTmp,std::string>> allCommonNode;
bool stmIsValid =false;
SeqStm stm(
stmIdx,
transitionMatrix,
nodesRegex,
typeToIdxTransition,
actSt,
allNodeValidated,
allNodeTested,
allCommonNode,
stmIsValid);
REQUIRE(stm.getStmIdx() == 0);
REQUIRE(stm.isValid() == false);
REQUIRE(stm.getAllCommonNode().size() == 0);
REQUIRE(stm.getAllNodeTested().size() == 0);
REQUIRE(stm.getAllNodeValidated().size() == 0);
for (const std::string& key : nodeTypeKey) {
delete nodesRegex[key];
}
}
TEST_CASE("Test testNode function", "[SeqStm]") {
int stmIdx = 0;
std::map<NodeTypeKey,int> typeToIdxTransition;
std::vector<NodeTypeKey> nodeTypeCommonTag {{"A",""},{"B",""},{"C",""}};
//init nodeTypeCommonTag
int idx = 0;
for (const NodeTypeKey& key : nodeTypeCommonTag) {
typeToIdxTransition[key] = idx;
idx += 1;
}
//matrix that in B->C
std::vector<std::vector<int>> transitionMatrix {
{ -1, 1, -1 },
{ -1, -1, 2 },
{ -1, -1, -1 } };
//std::cout << transitionMatrix.size() << "\n";
// init the nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"A","B","C"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
//
int actSt = 0;
std::set<NodeTmp> allNodeValidated;
std::set<NodeTmp> allNodeTested;
std::set<std::pair<NodeTmp,std::string>> allCommonNode;
bool stmIsValid =false;
SeqStm stm(
stmIdx,
transitionMatrix,
nodesRegex,
typeToIdxTransition,
actSt,
allNodeValidated,
allNodeTested,
allCommonNode,
stmIsValid);
REQUIRE(stm.getStmIdx() == 0);
//test a node
std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1);
std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1);
//set use to test the state of the smt
std::set<NodeTmp> testAllNodeTested;
std::set<NodeTmp> testAllNodeValidated;
stm.testNode(nodeB);
REQUIRE(stm.isValid() == false);
REQUIRE(stm.getState() == 1);
REQUIRE(stm.isStmBlocked() == false);
testAllNodeTested.insert(nodeB);
testAllNodeValidated.insert(nodeB);
REQUIRE(stm.getAllNodeTested() == testAllNodeTested);
REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated);
stm.testNode(nodeC);
REQUIRE(stm.isValid() == true);
REQUIRE(stm.getState() == 2);
REQUIRE(stm.isStmBlocked() == false);
testAllNodeTested.insert(nodeC);
testAllNodeValidated.insert(nodeC);
REQUIRE(stm.getAllNodeTested() == testAllNodeTested);
REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated);
stm.testNode(nodeC);
REQUIRE(stm.isValid() == true);
REQUIRE(stm.getState() == -1);
REQUIRE(stm.isStmBlocked() == true);
REQUIRE(stm.getAllNodeTested() == testAllNodeTested);
REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated);
for (const std::string& key : nodeTypeKey) {
delete nodesRegex[key];
}
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <iostream>
#include <map>
#include <memory>
#include <vector>
#include <utility>
#include <cassert>
#include <catch2/catch_test_macros.hpp>
//test
#include "aidge/graphmatching/StmFactory.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
//use
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
using namespace Aidge;
TEST_CASE("Create good init StmFactory", "[StmFactory]") {
// init the nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"A","B","C"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
StmFactory stmF(nodesRegex);
REQUIRE(stmF.getNumberOfStm() == 0);
for (const std::string& key : nodeTypeKey) {
delete nodesRegex[key];
}
}
TEST_CASE("Test in makeNewStm the getStmIdx StmFactory", "[SeqStm]") {
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"A","B","C"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
StmFactory stmF(nodesRegex);
std::string seq1 = "A->B+->A#;";
SeqStm* stm = stmF.makeNewStm(seq1);
REQUIRE(stm->getStmIdx() == 0);
REQUIRE(stm->isValid() == false);
REQUIRE(stm->getAllCommonNode().size() == 0);
REQUIRE(stm->getAllNodeTested().size() == 0);
REQUIRE(stm->getAllNodeValidated().size() == 0);
std::string seq2 = "A->B;";
SeqStm* stm2 = stmF.makeNewStm(seq2);
REQUIRE(stm2->getStmIdx() == 1);
REQUIRE(stm2->isValid() == false);
REQUIRE(stm2->getAllCommonNode().size() == 0);
REQUIRE(stm2->getAllNodeTested().size() == 0);
REQUIRE(stm2->getAllNodeValidated().size() == 0);
//test the number of stm
REQUIRE(stmF.getNumberOfStm() == 2);
for (const std::string& key : nodeTypeKey) {
delete nodesRegex[key];
}
}
TEST_CASE("Test in makeNewStm the stm StmFactory", "[SeqStm]") {
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"A","B","C"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
StmFactory stmF(nodesRegex);
std::string seq1 = "B->C;";
SeqStm* stm = stmF.makeNewStm(seq1);
//test the number of stm
REQUIRE(stmF.getNumberOfStm() == 1);
//std::shared_ptr<Node> nodeB = GenericOperator("B",1,1,1);
//std::shared_ptr<Node> nodeC = GenericiOperator("C",1,1,1);
std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1);
std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1);
//set use to test the state of the smt
std::set<NodeTmp> testAllNodeTested;
std::set<NodeTmp> testAllNodeValidated;
REQUIRE(stm->isValid() == false);
REQUIRE(stm->getState() == 0);
REQUIRE(stm->isStmBlocked() == false);
REQUIRE(stm->getAllNodeTested() == testAllNodeTested);
REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated);
stm->testNode(nodeB);
REQUIRE(stm->isValid() == false);
REQUIRE(stm->getState() == 1);
REQUIRE(stm->isStmBlocked() == false);
testAllNodeTested.insert(nodeB);
testAllNodeValidated.insert(nodeB);
REQUIRE(stm->getAllNodeTested() == testAllNodeTested);
REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated);
stm->testNode(nodeC);
REQUIRE(stm->isValid() == true);
REQUIRE(stm->getState() == 2);
REQUIRE(stm->isStmBlocked() == false);
testAllNodeTested.insert(nodeC);
testAllNodeValidated.insert(nodeC);
REQUIRE(stm->getAllNodeTested() == testAllNodeTested);
REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated);
stm->testNode(nodeC);
REQUIRE(stm->isValid() == true);
REQUIRE(stm->getState() == -1);
REQUIRE(stm->isStmBlocked() == true);
REQUIRE(stm->getAllNodeTested() == testAllNodeTested);
REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated);
for (const std::string& key : nodeTypeKey) {
delete nodesRegex[key];
}
}
TEST_CASE("Test in duplicateStm StmFactory", "[SeqStm]") {
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"A","B","C"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
StmFactory stmF(nodesRegex);
std::string seq1 = "B->C;";
SeqStm* stm = stmF.makeNewStm(seq1);
SeqStm* stmD = stmF.duplicateStm(stm);
std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1);
std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1);
//set use to test the state of the smt
std::set<NodeTmp> testAllNodeTested;
std::set<NodeTmp> testAllNodeValidated;
//run the stm
REQUIRE(stm->isValid() == false);
REQUIRE(stm->getState() == 0);
REQUIRE(stm->isStmBlocked() == false);
REQUIRE(stm->getAllNodeTested() == testAllNodeTested);
REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated);
stm->testNode(nodeB);
REQUIRE(stm->isValid() == false);
REQUIRE(stm->getState() == 1);
REQUIRE(stm->isStmBlocked() == false);
testAllNodeTested.insert(nodeB);
testAllNodeValidated.insert(nodeB);
REQUIRE(stm->getAllNodeTested() == testAllNodeTested);
REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated);
stm->testNode(nodeC);
REQUIRE(stm->isValid() == true);
REQUIRE(stm->getState() == 2);
REQUIRE(stm->isStmBlocked() == false);
testAllNodeTested.insert(nodeC);
testAllNodeValidated.insert(nodeC);
REQUIRE(stm->getAllNodeTested() == testAllNodeTested);
REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated);
stm->testNode(nodeC);
REQUIRE(stm->isValid() == true);
REQUIRE(stm->getState() == -1);
REQUIRE(stm->isStmBlocked() == true);
REQUIRE(stm->getAllNodeTested() == testAllNodeTested);
REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated);
//check if stmD not move
REQUIRE(stmD->isValid() == false);
REQUIRE(stmD->getState() == 0);
REQUIRE(stmD->isStmBlocked() == false);
REQUIRE(stmD->getAllNodeTested().size() == 0);
REQUIRE(stmD->getAllNodeValidated().size() == 0);
for (const std::string& key : nodeTypeKey) {
delete nodesRegex[key];
}
}
...@@ -12,13 +12,38 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { ...@@ -12,13 +12,38 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") {
SECTION("custom Lambda") { SECTION("custom Lambda") {
const std::string test = " !toto($) == true " ;
ConditionalInterpreter conditionalParser = ConditionalInterpreter("A",test); ConditionalInterpreter conditionalParserB = ConditionalInterpreter("A"," bad($) == false ");
conditionalParser.insertLambda("toto",+[](NodePtr NodeOp){return false;}); ConditionalInterpreter conditionalParserG = ConditionalInterpreter("A"," good($) == true ");
conditionalParserB.insertLambda("bad",+[](NodePtr NodeOp){return NodeOp->name() == "ZZ";});
conditionalParserG.insertLambda("good",+[](NodePtr NodeOp){return NodeOp->name() == "Gop1";});
std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1");
bool result = conditionalParser.test(nodeOp); REQUIRE(conditionalParserB.test(nodeOp) == true);
REQUIRE(result == true); REQUIRE(conditionalParserG.test(nodeOp) == true);
}
ConditionalInterpreter conditionalParserT = ConditionalInterpreter("A","isConv($)==true");
conditionalParserT.insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";});
std::shared_ptr<Node> zz = GenericOperator("conv", 0, 0, 0, "Gop1");
conditionalParserT.test(zz);
SECTION("Lambdas") {
ConditionalInterpreter conditionalParser = ConditionalInterpreter("OP_test","getType($) =='Conv' || getType($) =='FC' ");
std::shared_ptr<Node> A = GenericOperator("Conv", 0, 0, 0, "A");
REQUIRE(conditionalParser.test(A) == true);
std::shared_ptr<Node> B = GenericOperator("FC", 0, 0, 0, "B");
REQUIRE(conditionalParser.test(B) == true);
std::shared_ptr<Node> C = GenericOperator("A", 0, 0, 0, "C");
conditionalParser.test(C);
REQUIRE(conditionalParser.test(C) == false);
} }
SECTION("syntax error") { SECTION("syntax error") {
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
/*
#include <catch2/catch_test_macros.hpp>
#include <set>
//#include "aidge/backend/cpu/operator/BatchNormImpl.hpp"
//#include "aidge/backend/cpu/operator/ConvImpl.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/utils/Recipies.hpp"
//#include "aidge/backend/TensorImpl.hpp"
//#include "aidge/backend/cpu.hpp"
//#include "aidge/"
#include <cstddef>
namespace Aidge {
TEST_CASE("[FuseBatchNorm] conv") {
auto g1 = Sequential({
Producer({16, 3, 224, 224}, "dataProvider"),
Conv(3, 32, {3, 3}, "conv1"),
BatchNorm<2>()
});
g1->setDatatype(DataType::Float32);
g1->setBackend("cpu");
g1->forwardDims();
// std::set<std::string> availableBackends = Tensor::getAvailableBackends();
// if (availableBackends.find("cpu") != availableBackends.end()){
// g1->setBackend("cpu");
// newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr));
// }else{
// printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n");
// }
fuseBatchNorm(g1);
SECTION("Check resulting nodes") {
// REQUIRE(g1->getNodes().size() == 2);
// REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling");
// REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0));
// REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling");
// REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0));
// REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling");
}
}
}
*/
\ No newline at end of file
...@@ -26,6 +26,7 @@ ...@@ -26,6 +26,7 @@
namespace Aidge { namespace Aidge {
TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
// generate the original GraphView // generate the original GraphView
auto matmul0 = MatMul(5, "matmul0"); auto matmul0 = MatMul(5, "matmul0");
...@@ -74,4 +75,5 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { ...@@ -74,4 +75,5 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") {
REQUIRE(((node->type() == "Producer") || (node->type() == "FC"))); REQUIRE(((node->type() == "Producer") || (node->type() == "FC")));
} }
} }
} // namespace Aidge } // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment