Skip to content
Snippets Groups Projects
Commit 4540e79c authored by vincent  lorrain's avatar vincent lorrain
Browse files

clean and remove old graph matching

parent 717a683e
No related branches found
No related tags found
1 merge request!48Refactor/recipies
Pipeline #34204 failed
This commit is part of merge request !48. Comments created here will be created in the context of that merge request.
Showing
with 6 additions and 1707 deletions
...@@ -20,7 +20,6 @@ ...@@ -20,7 +20,6 @@
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/OpArgs.hpp" #include "aidge/graph/OpArgs.hpp"
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/Match.hpp" #include "aidge/graphmatching/Match.hpp"
#include "aidge/graphmatching/NodeRegex.hpp" #include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/SeqStm.hpp" #include "aidge/graphmatching/SeqStm.hpp"
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_GREGEX_H_
#define AIDGE_GREGEX_H_
#include <stdexcept> // for exception, runtime_error, out_of_range
#include <regex>
#include <memory> // for shared_ptr
#include <algorithm> // for next_permutation
#include "aidge/graphmatching/Utile.hpp"
#include "aidge/graphmatching/StmFactory.hpp"
#include "aidge/graphmatching/SeqStm.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/Match.hpp"
namespace Aidge{
class GRegex {
// __init__(self,nodes_regex:dict,seq_regexps:list)
StmFactory mStmFab;
std::vector<SeqStm*> mStmInit;
public:
GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps );
std::set<NodeTmp> matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch);
bool walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm);
bool walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm);
bool walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm);
std::set<NodeTmp> get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm);
std::vector<SeqStm*> getStmInit() const {
return mStmInit;
}
StmFactory getStmFab() const {
return mStmFab;
}
//std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> match(const std::shared_ptr<GraphView> graphToMatch);
Match match(const std::shared_ptr<GraphView> graphToMatch);
};
}
#endif //AIDGE_GREGEX_H_
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_MATCH_H_
#define AIDGE_MATCH_H_
#include <vector>
#include <set>
#include <iostream>
#include <cassert>
#include "aidge/graphmatching/Utile.hpp"
namespace Aidge{
class Match {
public:
Match();
size_t getNbMatch();
void insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes);
std::vector<std::vector<NodeTmp>> getStartNodes();
std::vector<std::set<NodeTmp>> getMatchNodes();
protected:
std::vector<std::vector<NodeTmp>> mStartNodes;
std::vector<std::set<NodeTmp>> mMatchNodes;
};
}
#endif //AIDGE_MATCH_H_
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_NODEREGEX_H_
#define AIDGE_NODEREGEX_H_
#include <cstdlib>
#include <iostream>
#include <cstring>
#include "aidge/graph/Node.hpp"
namespace Aidge {
class NodeRegex
{
public:
std::string mCondition;
NodeRegex(const std::string c){
mCondition = c;
};
// Version 1 - Only test the type of the node (no need for a lexer)
// Input : Node_op
// Output : bool
// return mCondition == Node_op.type
bool _is(std::shared_ptr<Node> &Node_op);
bool isA(std::string NodeType);
};
}
#endif /* _AIDGE_NODEREGEX_H__ */
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_SEQSTM_H_
#define AIDGE_SEQSTM_H_
#include <iostream>
#include <map>
#include <regex>
#include <set>
#include <stdexcept> // for exception, runtime_error, out_of_range
#include <string>
#include <utility>
#include <vector>
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/Utile.hpp"
namespace Aidge {
class SeqStm {
private:
const int mStmIdx;
const std::vector<std::vector<int>> mTransitionMatrix;
// str key of type like 'A' that ce use in the A->B .. extpr
const std::map<std::string, NodeRegex *> mNodesRegex;
// mTypeToIdxTransition.first = std::pair node_type , common_tag
// mTypeToIdxTransition.segond = idx in trans matrix
const std::map<NodeTypeKey, int> mTypeToIdxTransition;
int mActSt;
std::set<NodeTmp> mAllNodeValidated;
std::set<NodeTmp> mAllNodeTested;
std::set<std::pair<NodeTmp, std::string>> mAllCommonNode;
bool mStmIsValid;
std::pair<NodeRegex *, std::string> getNodeRegexAndCommonAt(int idxType);
/**
* @brief test the stm on a type
* @return the common tag
*/
std::string transitionOnNodeType(NodeType nodeType);
public:
SeqStm(const int mStmIdx,
const std::vector<std::vector<int>> &mTransitionMatrix,
const std::map<std::string, NodeRegex *> &mNodesRegex,
const std::map<NodeTypeKey, int> &mTypeToIdxTransition, int mActSt,
std::set<NodeTmp> mAllNodeValidated, std::set<NodeTmp> mAllNodeTested,
std::set<std::pair<NodeTmp, std::string>> mAllCommonNode,
bool mStmIsValid);
//////////////////////////////////////
// STM test
/////////////////////////////////////
/**
* @brief get if a st is a valide one
* @return bool
*/
bool isAValidSt(int st) {
std::size_t size = mTransitionMatrix.size();
return st == static_cast<int>(size - 1) ? true : false;
}
/**
* @brief true if the stm is blocked into st
* @return bool
*/
bool isStmBlocked() { return mActSt == -1 ? true : false; }
/**
* @brief true if the stm into valide st
* @return bool
*/
bool isValid() { return mStmIsValid; }
/////////////////////////////////////
// utile
/////////////////////////////////////
/**
* @brief extract from a node is type
* @return bool
*/
NodeType getTheNodeType(NodeTmp node);
void drawStm();
/////////////////////////////////////
// geter
/////////////////////////////////////
std::set<std::pair<NodeTmp, std::string>> getAllCommonNode() {
return mAllCommonNode;
}
std::set<NodeTmp> getAllNodeTested() { return mAllNodeTested; }
std::set<NodeTmp> getAllNodeValidated() { return mAllNodeValidated; }
SeqStm *duplicateStm();
int getStmIdx() { return mStmIdx; }
int getState() { return mActSt; }
//////////////////////////////////////////
// USE
//////////////////////////////////////////
/**
* @brief test the stm on a node
* @return pair new stm state, the common tag
*/
std::pair<int, std::string> testNode(const NodeTmp node);
};
} // namespace Aidge
#endif /* AIDGE_SEQSTM_H_ */
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_STMFACTORY_H_
#define AIDGE_STMFACTORY_H_
#include <map>
#include <utility>
#include <set>
#include <string>
#include <vector>
#include <iostream>
#include <stdexcept> // for exception, runtime_error, out_of_range
#include <regex>
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/SeqStm.hpp"
#include "aidge/graphmatching/Utile.hpp"
namespace Aidge{
class StmFactory {
const std::map<std::string,NodeRegex*>& mNodesRegex;
std::size_t mCmptStm = 0;
public:
StmFactory(const std::map<std::string,NodeRegex*>& nodesRegex);
//StmFactory(){};
SeqStm* makeNewStm(const std::string& sequRegex);
SeqStm* duplicateStm(SeqStm* stm);
std::size_t getNumberOfStm(){
return mCmptStm;
}
private:
ParsingReturn initParsingSequRegex(const std::string& sequRegex);
std::vector<std::vector<int>> initTransitionMatrix(ParsingReturn& parsing);
};
}
#endif //AIDGE_STMFACTORY_H_
\ No newline at end of file
/**
* @file
* @brief
* @version file 1.0.0
* @author vl241552
* @copyright
* Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory.
* All rights reserved.
*/
#ifndef _utile_H_
#define _utile_H_
#include <map>
#include "aidge/graph/Node.hpp"
#include <map>
namespace Aidge {
using NodeTmp = std::shared_ptr<Node>;
using NodeType = std::string;
using CommonTag = std::string;
using NodeTypeKey = std::pair<NodeType, CommonTag>;
// type def
// struct NodeTypeKey {
// NodeType nodeType;
// std::string commonTag;
// // for map find
// bool operator<(const NodeTypeKey& other) const {
// if (nodeType != other.nodeType or commonTag != other.commonTag) {
// return false;
// } else {
// return true;
// }
// }
// };
struct ParsingReturn {
std::map<NodeTypeKey, int> typeToIdxTransition;
std::vector<std::pair<NodeTypeKey, std::string>> transition;
};
} // namespace Aidge
#endif //_utile_H_
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graphmatching/GRegex.hpp"
namespace py = pybind11;
namespace Aidge {
void init_GRegex(py::module& m){
py::class_<GRegex, std::shared_ptr<GRegex>>(m, "GRegex", "GRegex class combines a Node Regex and a list of Graph Regex that together describes a graph pattern as a graph regular expression. GRegex find patterns in a given graph that matches the graph regular expression.")
.def(py::init<const std::map<std::string,NodeRegex*>&, std::vector<std::string>&>(), py::arg("nodesRegex"), py::arg("seqRegexps"), R"mydelimiter(
Constructor of GRegex
:param nodesRegex: Describe the conditions an operator has to fulfill.
:type nodesRegex: Dict[str,:py:class:`aidge_core.NodeRegex`]
:param seqRegexps: Describe the graph topological pattern. List of Graph Regex as strings.
:type seqRegexps: List[str]
)mydelimiter")
.def("match", &GRegex::match, py::arg("graphToMatch"), R"mydelimiter(
Launch the graph matching algorithm on a given graph.
:param graphToMatch: The graph to perform the matching algorithm on.
:type graphToMatch: :py:class:`aidge_core.GraphView`
:returns: Matched graph patterns.
:rtype: :py:class:`aidge_core.Match`
)mydelimiter")
;
}
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/graphmatching/Match.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Match(py::module& m){
py::class_<Match, std::shared_ptr<Match>>(m, "Match", "Match class stores the matched patterns resulting from a graph matching query. A matched pattern is the combinaison of the graph pattern start nodes and the set of all the nodes in the matched pattern (including the start nodes)")
.def(py::init<>())
.def("get_nb_match", &Match::getNbMatch, R"mydelimiter(
:returns: The number of graph patterns matched
:rtype: int
)mydelimiter")
.def("get_start_nodes", &Match::getStartNodes, R"mydelimiter(
:returns: All matched graph patterns start nodes
:rtype: List[List[:py:class:`aidge_core.Nodes`]]
)mydelimiter")
.def("get_match_nodes", &Match::getMatchNodes, R"mydelimiter(
:returns: All matched graph patterns sets of matched nodes
:rtype: List[Set[:py:class:`aidge_core.Nodes`]]
)mydelimiter");
}
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/graphmatching/NodeRegex.hpp"
namespace py = pybind11;
namespace Aidge {
void init_NodeRegex(py::module& m){
py::class_<NodeRegex, std::shared_ptr<NodeRegex>>(m, "NodeRegex", "NodeRegex class describes a condition to test on any operator. Current version only supports testing the type of the operator.")
.def(py::init<const std::string>(), py::arg("condition"), R"mydelimiter(
Constructor of NodeRegex
:param condition: Condition to be fulfilled by an operator.
:type condition: str
)mydelimiter")
;
}
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graph/GraphView.hpp"
using namespace Aidge;
GRegex::GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps ):mStmFab(nodesRegex){
//setup all the STM
for (const std::string& sequRegex : seqRegexps) {
mStmInit.push_back(mStmFab.makeNewStm(sequRegex));
}
}
bool GRegex::walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm){
//test if all stm type are in a valid state
std::vector<int> number_of_valid;
number_of_valid.resize(all_stm.size());
for (std::size_t i = 0; i < all_stm.size(); ++i) {
number_of_valid[i] = 0;
for (auto it = all_stm[i].begin(); it != all_stm[i].end(); ++it) {
SeqStm* stm = *it;
if (stm->isValid()){
number_of_valid[i] +=1;
}
}
}
for (std::size_t i = 0; i < number_of_valid.size(); ++i) {
if (number_of_valid[i] == 0) {
//std::cout << "NO MATCH at least one stm are not valid" << std::endl;
return false;
}
if (number_of_valid[i] > 1) {
//std::cout << "NO MATCH multiple brach match of stm (// quantification)" << std::endl;
return false;
}
}
return true;
}
bool GRegex::walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm){
std::set<NodeTmp> all_stm_node_tested;
std::set<NodeTmp> all_stm_node_validated;
for (std::size_t i = 0; i < all_stm.size(); ++i) {
//std::cout << "all stm index " << i << " on dimension 1 of size " << all_stm.size() <<std::endl;
for (std::size_t j = 0; j < all_stm[i].size(); ++j) {
//std::cout << "all stm index " << j << " on dimension 2 of size " << all_stm[i].size() <<std::endl;
std::set<NodeTmp> stm_node_tested = all_stm[i][j]->getAllNodeTested();
std::set<NodeTmp> stm_node_validated = all_stm[i][j]->getAllNodeValidated();
all_stm_node_tested.insert(stm_node_tested.begin(), stm_node_tested.end());
all_stm_node_validated.insert(stm_node_validated.begin(), stm_node_validated.end());
}
}
std::set<NodeTmp> test_but_not_valid;
for (const auto& x : all_stm_node_tested) {
if (all_stm_node_validated.find(x) == all_stm_node_validated.end()) {
test_but_not_valid.insert(x);
}
}
if (!test_but_not_valid.empty()) {
std::cout << "NO MATCH. The node(s) ";
for (const auto& x : test_but_not_valid) {
std::cout << x.get() << ", ";
}
std::cout << " have been tested but not validated." << std::endl;
return false;
}
return true;
}
bool GRegex::walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm){
std::map<NodeTmp, std::pair<std::string,int>> node_to_common_tag;
for (std::size_t i = 0; i < all_stm.size(); ++i) {
for (auto it = all_stm[i].begin(); it != all_stm[i].end(); ++it) {
SeqStm* stm = *it;
if (!stm->isValid()){
continue;
}
for (const auto& pair : stm->getAllCommonNode()) {
const NodeTmp node = pair.first;
const std::string common_tag = pair.second;
if (node_to_common_tag.find(node) != node_to_common_tag.end()) {
std::string tag = node_to_common_tag[node].first;
int& occurence = node_to_common_tag[node].second;
if (tag!=common_tag){
std::cout << "NO MATCH. The node " << node << " have two different tags "<< tag << " and " << common_tag << std::endl;
return false;
} else {
occurence += 1;
}
} else {
node_to_common_tag.insert(std::make_pair(node, std::make_pair(common_tag, 1)));
}
}
}
}
/*std::cout << "Node to common tag ";
for (const auto& x : node_to_common_tag) {
std::cout << "(" << x.first << ", " << "[" << x.second.first << ", " << x.second.second << "]" << ") ; ";
}
std::cout << std::endl;*/
for (const auto& pair : node_to_common_tag) {
const std::pair<std::string, int> tag_occurence_pair = pair.second;
if (tag_occurence_pair.second < 1){
//std::cout << "NO MATCH. The common tag " << tag_occurence_pair.first << " did not match " << std::endl;
return false;
}
}
return true;
}
std::set<NodeTmp> GRegex::get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm){
std::set<NodeTmp> all_stm_node_validated;
for (std::size_t i = 0; i < all_stm.size(); ++i) {
for (std::size_t j = 0; j < all_stm[i].size(); ++j) {
std::set<NodeTmp> stm_node_validated = all_stm[i][j]->getAllNodeValidated();
all_stm_node_validated.insert(stm_node_validated.begin(), stm_node_validated.end());
}
}
return all_stm_node_validated;
}
std::set<NodeTmp> GRegex::matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch){
std::set<NodeTmp> empty_set_return;
//ASSERT
if(startNodes.size() != mStmInit.size()){
throw std::runtime_error ("bad GRegex start nodes");
}
//init the walk
std::vector<std::vector<SeqStm*>> allStm;
std::vector<std::pair<NodeTmp,SeqStm*>> currentWalk;
for (SeqStm* seqStmPtr : mStmInit) {
SeqStm* newStm = mStmFab.duplicateStm(seqStmPtr);
std::size_t idxStart = newStm->getStmIdx();
currentWalk.push_back(std::make_pair(startNodes[idxStart],newStm));
allStm.push_back(std::vector<SeqStm*>());
}
//walk
while (currentWalk.size()!=0)
{
std::vector<std::pair<NodeTmp,SeqStm*>> newWalk;
for (const auto& pair : currentWalk) {
const NodeTmp node = pair.first;
SeqStm* stmPtr = pair.second;
std::pair<int,std::string> test = stmPtr->testNode(node);
int res = test.first;
std::string commonTag = test.second;
std::set<NodeTmp> next_nodes = graphToMatch->getChildren(node);
/*std::cout << "Next nodes : " ;
for (const auto& x : next_nodes) {
std::cout << x->name() << ", ";
}
std::cout << std::endl;*/
// Test Match
if (commonTag == "" && next_nodes.size() > 1) {
std::cout << "NO MATCH. The node " << node.get() << " is not common and has more than one child" << std::endl;
return empty_set_return;
}
// If there is no more nodes --> Archive the branch
if (res == -1 || next_nodes.empty()) {
int indexToInsert = stmPtr->getStmIdx();
allStm[indexToInsert].push_back(stmPtr);
//std::cout << "No more nodes --> STM archived : " << indexToInsert << std::endl;
continue; // TODEV : replace this with 'else' that encapsulate the rest of the function ?
}
bool first = true;
// Use an iterator to read through the next_nodes
std::set<NodeTmp>::iterator it;
for (it = next_nodes.begin(); it != next_nodes.end(); ++it) {
// Access the current element using the iterator
std::shared_ptr<Aidge::Node> next_node = *it;
if (first){
newWalk.push_back(std::make_pair(next_node, stmPtr));
first = false;
} else {
SeqStm* new_stmPtr = mStmFab.duplicateStm(stmPtr);
newWalk.push_back(std::make_pair(next_node, new_stmPtr));
}
}
}
currentWalk = newWalk;
}
//std::cout << "Walk finished" << std::endl;
if (!walk_validation_all_stm_are_valid(allStm)){
return empty_set_return;
}
//std::cout << "walk_validation_all_stm_are_valid finished" << std::endl;
if (!walk_validation_all_node_read_validate_by_one_stm(allStm)){
return empty_set_return;
}
//std::cout << "walk_validation_all_node_read_validate_by_one_stm finished" << std::endl;
if (!walk_validation_common_nodes_same_tag_for_all_stm(allStm)){
return empty_set_return;
}
//std::cout << "walk_validation_common_nodes_same_tag_for_all_stm finished" << std::endl;
//std::cout << "MATCH" << std::endl;
return get_all_validate_nodes(allStm);
}
Match GRegex::match(const std::shared_ptr<GraphView> graphToMatch){
//std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matches;
//std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matches;
Match matches;
std::size_t nbStartNodes = mStmInit.size();
std::set<NodeTmp> allNodes = graphToMatch->getNodes();
std::size_t nbAllNodes = allNodes.size();
std::vector<std::size_t> indices(nbStartNodes, 0);
while (true) {
// Generate all permutations of the current combination
do {
std::vector<NodeTmp> startNodes;
//std::cout <<"start nodes :";
for (std::size_t i = 0; i < nbStartNodes; ++i) {
auto it = std::begin(allNodes);
std::advance(it, indices[i]);
//std::cout << (*it).get() << " ";
startNodes.push_back(*it);
}
//std::cout <<"\n";
std::set<NodeTmp> match = matchFromStartNodes(startNodes, graphToMatch);
//std::cout << "match size : " << match.size() << " ";
if(match.size() != 0){
//matches.push_back(std::make_pair(startNodes,match));
//matches.insert(std::make_pair(startNodes,match));
matches.insert(startNodes,match);
}
} while (std::next_permutation(indices.begin(), indices.end()));
// Generate the next combination with replacement
std::size_t i = nbStartNodes - 1;
while (true) {
if (indices[i] < nbAllNodes - 1) {
++indices[i];
break;
}
if (i == 0) {
return matches;
}
--i;
}
std::fill(indices.begin() + i + 1, indices.end(), indices[i]);
}
return matches;
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graphmatching/Match.hpp"
using namespace Aidge;
Match::Match(){
//ctr
}
size_t Match::getNbMatch(){
assert(mStartNodes.size() == mMatchNodes.size() && "Match corrupted");
return mStartNodes.size();
}
void Match::insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes){
assert(mStartNodes.size() == mMatchNodes.size() && "Match corrupted");
mStartNodes.push_back(startnodes);
mMatchNodes.push_back(matchnodes);
}
std::vector<std::vector<NodeTmp>> Match::getStartNodes(){
return mStartNodes;
}
std::vector<std::set<NodeTmp>> Match::getMatchNodes(){
return mMatchNodes;
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graphmatching/NodeRegex.hpp"
// Verification done by the Attribute system
// Version 1 - Only test the type of the node (no need for a lexer)
// Input : Node_op
// Output : bool
// return mCondition == Node_op.type
bool Aidge::NodeRegex::_is(std::shared_ptr<Node> &Node_op){
std::string NodeType = Node_op->type();
return strcmp(NodeType.c_str(), mCondition.c_str()) == 0;
}
bool Aidge::NodeRegex::isA(std::string NodeType){
return strcmp(NodeType.c_str(), mCondition.c_str()) == 0;
}
// Version 2 - Test the node to an advanced condition
// Input : Node_op
// Output : bool
// return mCondition applied on Node
/**bool NodeRegex::_is(string &Node_op){
// Parsing the condition is done in the initialization of the NodeRegex
// assert attributes exist in the node with the attribute function hasAttr()
// get the attributes
}*/
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graphmatching/SeqStm.hpp"
using namespace Aidge;
///////////////////////////////////////////////////////
SeqStm::SeqStm(
const int stmIdx,
const std::vector<std::vector<int>>& transitionMatrix,
const std::map<std::string,NodeRegex*>& nodesRegex,
const std::map<NodeTypeKey,int>& typeToIdxTransition,
int actSt,
std::set<NodeTmp> allNodeValidated,
std::set<NodeTmp> allNodeTested,
std::set<std::pair<NodeTmp,std::string>> allCommonNode,
bool stmIsValid):mStmIdx(stmIdx),
mTransitionMatrix(transitionMatrix),
mNodesRegex(nodesRegex),
mTypeToIdxTransition(typeToIdxTransition)
{
//assert
if (transitionMatrix.size() == 0){
throw std::runtime_error ("no transitionMatrix");
}
if(transitionMatrix[0].size() == 0 || transitionMatrix[0].size() != typeToIdxTransition.size()){
throw std::runtime_error ("bad transitionMatrix");
}
int size = static_cast<int>(transitionMatrix.size());
if (actSt >= size){
throw std::runtime_error ("bad actSt");
}
mActSt = actSt;
mAllNodeValidated = allNodeValidated;
mAllNodeTested = allNodeTested;
mAllCommonNode = allCommonNode;
mStmIsValid = stmIsValid;
}
SeqStm* SeqStm::duplicateStm(){
//deep copy of the set
// std::set<Node> cAllNodeValidated(mAllNodeValidated.begin(), mAllNodeValidated.end());
// std::set<Node> cAllNodeTested(mAllNodeTested.begin(), mAllNodeTested.end());
// std::set<std::pair<Node,std::string>> cAllCommonNode;
// for (const auto& p : mAllCommonNode) {
// cAllCommonNode.insert(p);
// }
auto newStm = new SeqStm(
mStmIdx,
mTransitionMatrix,
mNodesRegex,
mTypeToIdxTransition,
mActSt,
mAllNodeValidated,
mAllNodeTested,
mAllCommonNode,
mStmIsValid
);
return newStm;
}
std::pair<NodeRegex*,std::string> SeqStm::getNodeRegexAndCommonAt(int idxType)
{
//std::cout << "!" << idxType << "\n";
for (auto const& x : mTypeToIdxTransition)
{
//x.second is the value : idx in mTransitionMatrix for the type
//x.first pair of the node regex class and a string that is the common tag '',#,#n
if (x.second == idxType ){
if (mNodesRegex.find(x.first.first) != mNodesRegex.end()){
return std::make_pair(mNodesRegex.find(x.first.first)->second, x.first.second);
}else{
throw std::runtime_error ("a type is not define in NodesRegex");
}
}
}
throw std::runtime_error ("bad idx in mNodesRegex");
return std::make_pair(nullptr,nullptr);
}
NodeType SeqStm::getTheNodeType(NodeTmp node)
{
//the node is a str of '{type}{idx}' and we juste want type
// // std::regex re("([a-zA-Z]+)[0-9]+");
// // std::smatch match;
// // if (std::regex_search(node, match, re) == true) {
// // return match.str(1);
// // }
// // throw std::runtime_error ("Type node not found");
// // return "";
//return node->name();
return node->type();
}
std::string SeqStm::transitionOnNodeType(NodeType nodeType){
if (!isStmBlocked()){
int idxType = 0;
for (auto & nextSt : mTransitionMatrix[mActSt]) {
// There are a next step for this type
//std::cout << "transition matrix next state -> "<< nextSt<<"\n" ;
if (nextSt != -1){
//std::cout << "next -> "<< nextSt<< " "<< isAValidSt(nextSt) <<"\n" ;
auto nodeRegex = getNodeRegexAndCommonAt(idxType);
//std::cout << "-> "<< nodeRegex.second<<"\n" ;
if (nodeRegex.first->isA(nodeType)){
//std::cout << "nodetype tested !"<<"\n" ;
if(isAValidSt(nextSt)){
//std::cout << "Valid state !"<<"\n" ;
mStmIsValid = true;
}
mActSt = nextSt;
return nodeRegex.second;
}
}
idxType += 1;
}
mActSt =-1;
}
return "";
}
std::pair<int,std::string> SeqStm::testNode(const NodeTmp node){
std::string commonTag = "";
//std::cout << "0\n" ;
if (!isStmBlocked()){
bool isNextStEnd = std::all_of(mTransitionMatrix[mActSt].begin(), mTransitionMatrix[mActSt].end(), [&](int x){ return x == -1; });
//std::cout << "1:"<< isNextStEnd <<"\n" ;
//if the next state if full of -1 can we relay add the node test to all node tested
// oker y test it but it sure that not be valid
if(!isNextStEnd){
mAllNodeTested.insert(node);
}
//std::cout << "2\n" ;
//recurtion avoidance
if(mAllNodeValidated.find(node) == mAllNodeValidated.end()){
NodeType nodeType = getTheNodeType(node);
//std::cout << "3 " << nodeType << "\n" ;
commonTag = transitionOnNodeType(nodeType);
//after the transition test, if the node is != -1 the node is valid for the stm
//std::cout << " mActSt = " << mActSt << "\n" ;
if( mActSt != -1 ){
mAllNodeValidated.insert(node);
}
}else{
mActSt = -1;
}
}
if(commonTag != ""){
mAllCommonNode.insert(std::make_pair(node,commonTag));
}
return std::make_pair(mActSt,commonTag);
}
void SeqStm::drawStm(){
//mTransitionMatrix
// Find the maximum width of each column
std::vector<std::size_t> max_widths(mTransitionMatrix[0].size(), 0);
for (std::size_t i = 0; i < mTransitionMatrix.size(); ++i)
{
for (std::size_t j = 0; j < mTransitionMatrix[i].size(); ++j)
{
std::size_t width = std::to_string(mTransitionMatrix[i][j]).length();
if (width > max_widths[j])
{
max_widths[j] = width;
}
}
}
// Print the vector with aligned columns
for (std::size_t i = 0; i < mTransitionMatrix.size(); ++i)
{
for (std::size_t j = 0; j < mTransitionMatrix[i].size(); ++j)
{
int i_int = static_cast<int>(i);
if (mActSt == -1 ){
if(mStmIsValid){
std::cout << "\033[48;5;40m";
}else{
std::cout << "\033[48;5;9m";
}
}
else if (mActSt == i_int){
std::cout << "\033[48;5;30m";
}else{
std::cout << "\033[48;5;27m";
}
// Pad the value with spaces to align it with the maximum width
std::size_t width = std::to_string(mTransitionMatrix[i][j]).length();
std::string padding(max_widths[j] - width, ' ');
std::cout << padding << mTransitionMatrix[i][j] << " ";
std::cout << "\033[0m";
}
std::cout << "\n";
}
std::cout << "mAllNodeTested : ";
for (const auto& x : mAllNodeTested) {
std::cout << x << ", ";
}
std::cout << "\n";
std::cout << "mAllNodeValidated : ";
for (const auto& x : mAllNodeValidated) {
std::cout << x << ", ";
}
std::cout << "\n";
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graphmatching/StmFactory.hpp"
using namespace Aidge;
StmFactory::StmFactory(const std::map<std::string, NodeRegex *> &nodesRegex)
: mNodesRegex(nodesRegex) {}
SeqStm *StmFactory::duplicateStm(SeqStm *stm) { return stm->duplicateStm(); }
SeqStm *StmFactory::makeNewStm(const std::string &sequRegex) {
ParsingReturn parsing = initParsingSequRegex(sequRegex);
std::vector<std::vector<int>> transitionMatrix =
initTransitionMatrix(parsing);
std::set<NodeTmp> allNodeValidated;
std::set<NodeTmp> allNodeTested;
std::set<std::pair<NodeTmp, std::string>> allCommonNode;
SeqStm *newStm = new SeqStm(static_cast<int>(mCmptStm), transitionMatrix, mNodesRegex,
parsing.typeToIdxTransition, 0, allNodeValidated,
allNodeTested, allCommonNode, false);
mCmptStm += 1;
return newStm;
}
ParsingReturn StmFactory::initParsingSequRegex(const std::string &sequRegex) {
std::string toMatch;
std::regex re("\\s*([A-Za-z]+)(#\\d*)?([+*])?\\s*(->|;)");
std::smatch matches;
int idxType = 0;
// return
ParsingReturn parsing;
// std::map<std::pair<NodeType,std::string>,int> typeToIdxTransition;
// std::vector<std::pair<std::pair<NodeType,std::string>,std::string>>
// transition;
// assert
std::map<NodeType, std::string> assertCommonNodeTypes;
for (std::size_t i = 0; i < sequRegex.length(); i++) {
toMatch += sequRegex[i];
if (std::regex_match(toMatch, matches, re)) {
std::string type = matches.str(1);
std::string commonTag = matches.str(2);
std::string quantification = matches.str(3);
if ((commonTag != "") && (quantification != "")) {
throw std::runtime_error("bad commonTag and quantification");
}
// make the typeToIdxTransition
NodeTypeKey typeTag = std::make_pair(type, commonTag);
/*std::cout << " typeTag: " << type << " " << commonTag
<< parsing.typeToIdxTransition.size() << std::endl;*/
if (parsing.typeToIdxTransition.find(typeTag) ==
parsing.typeToIdxTransition.end()) {
parsing.typeToIdxTransition[typeTag] = idxType;
idxType += 1;
}
////////////////////////////////////////////////////////////
// ASSERT
// SAME Common node in the sequ
if (commonTag != "") {
if (assertCommonNodeTypes.find(type) != assertCommonNodeTypes.end()) {
if (assertCommonNodeTypes[type] == commonTag) {
throw std::runtime_error("same common node in the sequ regex");
}
} else {
assertCommonNodeTypes[type] = commonTag;
}
}
// save all transition
parsing.transition.push_back(std::make_pair(typeTag, quantification));
/*std::cout << "Match found: " << matches.str() << std::endl;
std::cout << "Type: " << matches.str(1) << std::endl;
std::cout << "Common tag: " << matches.str(2) << std::endl;
std::cout << "Quantification: " << matches.str(3) << std::endl;*/
toMatch = "";
}
}
if (parsing.transition.size() == 0) {
throw std::runtime_error("Bad Parsing SequRegex ");
}
return parsing;
}
std::vector<std::vector<int>>
StmFactory::initTransitionMatrix(ParsingReturn &parsing) {
// std::pair<NodeTypeKey,std::string>
std::vector<std::vector<int>> transitionMatrix;
std::size_t numberOfType = parsing.typeToIdxTransition.size();
if (numberOfType == 0) {
throw std::runtime_error("Bad number Of Type ");
}
// init start st
transitionMatrix.push_back(std::vector<int>(numberOfType, -1));
std::size_t idxTransition = 0;
int idxState = 0;
for (const auto &pair : parsing.transition) {
const NodeTypeKey &nodeTypeKey = pair.first;
const std::string &quant = pair.second;
/*std::cout << "Key: {" << nodeTypeKey.first << ", " << nodeTypeKey.second
<< "}, Value: " << quant << std::endl;
std::cout << "idxState " << idxState << " TM: " << transitionMatrix.size()
<< std::endl;*/
std::size_t idxType = parsing.typeToIdxTransition[nodeTypeKey];
/*std::cout << "idxType " << idxType << " TM: " << transitionMatrix[0].size()
<< "type" << numberOfType << std::endl;*/
if (quant == "*") {
transitionMatrix[idxTransition][idxType] = idxState;
} else if (quant == "+") {
idxState += 1;
transitionMatrix[idxTransition][idxType] = idxState;
transitionMatrix.push_back(std::vector<int>(numberOfType, -1));
idxTransition += 1;
transitionMatrix[idxTransition][idxType] = idxState;
} else {
idxState += 1;
transitionMatrix[idxTransition][idxType] = idxState;
transitionMatrix.push_back(std::vector<int>(numberOfType, -1));
idxTransition += 1;
}
}
return transitionMatrix;
}
\ No newline at end of file
...@@ -21,9 +21,7 @@ ...@@ -21,9 +21,7 @@
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/GenericOperator.hpp"
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
//Graph Regex //Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp" #include "aidge/graphRegex/GraphRegex.hpp"
...@@ -32,21 +30,7 @@ using namespace Aidge; ...@@ -32,21 +30,7 @@ using namespace Aidge;
void Aidge::fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm){ void Aidge::fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm){
// assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
// // Assert the nodes types are correct to be fused
// std::shared_ptr<Node> conv;
// std::shared_ptr<Node> batchnorm;
// for (const auto& element : nodes) {
// assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace");
// if (element->type() == "Conv"){
// conv = element;
// }
// else if (element->type() == "BatchNorm") {
// batchnorm = element;
// }
// }
// TODO : check if batchnorm is the only child of the Conv or FC
std::shared_ptr<Tensor> scale = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second); std::shared_ptr<Tensor> scale = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second);
std::shared_ptr<Tensor> shift = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second); std::shared_ptr<Tensor> shift = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second);
...@@ -146,20 +130,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<MatchSolution> solution){ ...@@ -146,20 +130,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<MatchSolution> solution){
} }
void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){ void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){
// std::map<std::string,NodeRegex*> nodesRegex ;
// nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm");
// nodesRegex["Conv"] = new NodeRegex("Conv");
// nodesRegex["FC"] = new NodeRegex("FC");
// std::vector<std::string> seqRegex;
// seqRegex.push_back("Conv -> BatchNorm;"); // TODO: Add (Conv | FC)
// GRegex GReg(nodesRegex, seqRegex);
// Match matches = GReg.match(graphView);
// std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
// for (size_t i = 0; i < matches.getNbMatch(); ++i) {
// fuseBatchNorm(matchNodes[i]);
// }
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("BatchNorm","getType($) =='BatchNorm'"); regex->setNodeKey("BatchNorm","getType($) =='BatchNorm'");
......
...@@ -22,9 +22,6 @@ ...@@ -22,9 +22,6 @@
#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/GenericOperator.hpp"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
//Graph Regex //Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp" #include "aidge/graphRegex/GraphRegex.hpp"
...@@ -84,9 +81,6 @@ void Aidge::fuseMulAdd(std::shared_ptr<MatchSolution> solution){ ...@@ -84,9 +81,6 @@ void Aidge::fuseMulAdd(std::shared_ptr<MatchSolution> solution){
void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){ void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){
// std::map<std::string,NodeRegex*> nodesRegex ;
// nodesRegex["MatMul"] = new NodeRegex("MatMul");
// nodesRegex["Add"] = new NodeRegex("Add");
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("Add","getType($) =='Add'"); regex->setNodeKey("Add","getType($) =='Add'");
...@@ -97,26 +91,9 @@ void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){ ...@@ -97,26 +91,9 @@ void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){
fuseMulAdd(solution); fuseMulAdd(solution);
// // solution->at("MatMul");
// // solution->at("Add");
// assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n");
// assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n");
// for (const auto& matmul : solution->at("MatMul")) {
// for (const auto& add : solution->at("Add")) {
// fuseMulAdd(matmul,add);
// }
// }
} }
// std::vector<std::string> seqRegex;
// seqRegex.push_back("MatMul -> Add;");
// GRegex GReg(nodesRegex, seqRegex);
// Match matches = GReg.match(graphView);
// std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
// for (size_t i = 0; i < matches.getNbMatch(); ++i) {
// fuseMulAdd(matchNodes[i]);
// }
} }
...@@ -15,24 +15,14 @@ ...@@ -15,24 +15,14 @@
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/utils/Recipies.hpp" #include "aidge/utils/Recipies.hpp"
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
//Graph Regex //Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp" #include "aidge/graphRegex/GraphRegex.hpp"
namespace Aidge { namespace Aidge {
void removeFlatten(std::shared_ptr<Node> flatten) { void removeFlatten(std::shared_ptr<Node> flatten) {
// assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
// std::shared_ptr<Node> flatten;
// for (const auto& element : nodes) {
// assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace");
// if (element->type() == "Flatten"){
// flatten = element;
// }
// }
GraphView::replace({flatten}, {}); GraphView::replace({flatten}, {});
} }
...@@ -49,18 +39,7 @@ namespace Aidge { ...@@ -49,18 +39,7 @@ namespace Aidge {
void removeFlatten(std::shared_ptr<GraphView> graphView){ void removeFlatten(std::shared_ptr<GraphView> graphView){
// std::map<std::string,NodeRegex*> nodesRegex ;
// nodesRegex["Flatten"] = new NodeRegex("Flatten");
// nodesRegex["FC"] = new NodeRegex("FC");
// std::vector<std::string> seqRegex;
// seqRegex.push_back("Flatten->FC;");
// GRegex GReg(nodesRegex, seqRegex);
// Match matches = GReg.match(graphView);
// std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
// for (size_t i = 0; i < matches.getNbMatch(); ++i) {
// removeFlatten(matchNodes[i]);
// }
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("Flatten","getType($) =='Flatten'"); regex->setNodeKey("Flatten","getType($) =='Flatten'");
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <iostream>
#include <map>
#include <memory>
#include <vector>
#include <utility>
#include <cassert>
#include <catch2/catch_test_macros.hpp>
//test
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/StmFactory.hpp"
#include "aidge/graphmatching/SeqStm.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/Match.hpp"
//use
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/graph/GraphView.hpp"
using namespace Aidge;
TEST_CASE("Create good init GRegex", "[GRegex]") {
// init all input for GRegex
// Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex
// Sequential Regex vector : std::vector<std::string>& seqRegexps
// init the Nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"A","B","C"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
// init the Sequential Regex vector
std::vector<std::string> seqRegex;
seqRegex.push_back("A->B;");
// Instanciate a GRegex
GRegex GReg(nodesRegex, seqRegex);
// Perform tests
REQUIRE(GReg.getStmInit().size() == 1);
REQUIRE(GReg.getStmFab().getNumberOfStm() == 1);
for (const std::string& key : nodeTypeKey) {
delete nodesRegex[key];
}
}
TEST_CASE("Function matchFromStartNodes | One Match of Nodes sequence", "[GRegex]") {
// init all input for GRegex
// Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex
// Sequential Regex vector : std::vector<std::string>& seqRegexps
// init the Nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"Conv","BN","ReLU"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
// init the Sequential Regex vector
std::vector<std::string> seqRegex;
seqRegex.push_back("Conv->BN->ReLU;");
// Instanciate a GRegex
GRegex GReg(nodesRegex, seqRegex);
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1);
std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1);
std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1);
std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1);
std::shared_ptr<Node> Random2 = GenericOperator("Random2", 1, 1, 1);
g1->add(Conv1);
g1->addChild(BN1, Conv1);
g1->addChild(ReLU1, BN1);
g1->addChild(Random, ReLU1);
//g1->addChild(BN1, Random2);
std::vector<std::shared_ptr<Node>> startNodes1;
std::set<std::shared_ptr<Node>> result;
startNodes1.push_back(Conv1);
result = GReg.matchFromStartNodes(startNodes1, g1);
std::set<std::shared_ptr<Node>> true_result;
true_result.insert(Conv1);
true_result.insert(BN1);
true_result.insert(ReLU1);
// Perform tests
REQUIRE(result == true_result);
for (const std::string& key : nodeTypeKey) {
delete nodesRegex[key];
}
}
TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GRegex]") {
// init all input for GRegex
// Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex
// Sequential Regex vector : std::vector<std::string>& seqRegexps
// init the Nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"Add","FC","Conv"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
// init the Sequential Regex vector
std::vector<std::string> seqRegex;
seqRegex.push_back("Add#->Conv;");
seqRegex.push_back("Add#->FC;");
// Instanciate a GRegex
GRegex GReg(nodesRegex, seqRegex);
// Instanciate a graphView
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1);
std::shared_ptr<Node> Add1 = GenericOperator("Add", 1, 1, 1);
std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1);
std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1);
std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1);
std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1);
std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1);
g1->add(Random0);
g1->addChild(Add1, Random0);
g1->addChild(Conv1, Add1);
g1->addChild(BN1, Conv1);
g1->addChild(ReLU1, BN1);
g1->addChild(FC1, Add1);
g1->addChild(Random, FC1);
// Test 1 : Find the match
std::vector<std::shared_ptr<Node>> startNodes;
std::set<std::shared_ptr<Node>> result;
startNodes.push_back(Add1);
startNodes.push_back(Add1);
result = GReg.matchFromStartNodes(startNodes, g1);
std::set<std::shared_ptr<Node>> true_result;
true_result.insert(Add1);
true_result.insert(Conv1);
true_result.insert(FC1);
// Test 2 : Return an empty set when the start nodes are wrong
std::vector<std::shared_ptr<Node>> wrong_startNodes;
std::set<std::shared_ptr<Node>> wrong_start_result;
std::set<std::shared_ptr<Node>> empty_result;
wrong_startNodes.push_back(Random0);
wrong_startNodes.push_back(Random0);
wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1);
// Perform tests
REQUIRE(result == true_result);
REQUIRE(wrong_start_result == empty_result);
for (const std::string& key : nodeTypeKey) {
delete nodesRegex[key];
}
}
/*
TEST_CASE("Function matchFromStartNodes | Match a sequence with quantifier ", "[GRegex]") {
// init all input for GRegex
// Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex
// Sequential Regex vector : std::vector<std::string>& seqRegexps
// init the Nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"FC"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
// init the Sequential Regex vector
std::vector<std::string> seqRegex;
seqRegex.push_back("FC+;");
// Instanciate a GRegex
GRegex GReg(nodesRegex, seqRegex);
// Instanciate a graphView
std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1);
std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1);
std::shared_ptr<Node> FC2 = GenericOperator("FC", 1, 1, 1);
std::shared_ptr<Node> FC3 = GenericOperator("FC", 1, 1, 1);
std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1);
g1->add(Random0);
g1->addChild(FC1, Random0);
g1->addChild(FC2, FC1);
g1->addChild(FC3, FC2);
g1->addChild(ReLU1, FC3);
// Test 1 : Find the match
std::vector<std::shared_ptr<Node>> startNodes;
std::set<std::shared_ptr<Node>> result;
startNodes.push_back(FC1);
result = GReg.matchFromStartNodes(startNodes, g1);
std::set<std::shared_ptr<Node>> true_result;
true_result.insert(FC1);
true_result.insert(FC2);
true_result.insert(FC3);
// Test 2 : Return an empty set when the start nodes are wrong
std::vector<std::shared_ptr<Node>> wrong_startNodes;
std::set<std::shared_ptr<Node>> wrong_start_result;
std::set<std::shared_ptr<Node>> empty_result;
wrong_startNodes.push_back(Random0);
wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1);
// Perform tests
REQUIRE(result == true_result);
REQUIRE(wrong_start_result == empty_result);
}
*/
TEST_CASE("Function match | ALL matches of Nodes sequence", "[GRegex]") {
// init all input for GRegex
// Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex
// Sequential Regex vector : std::vector<std::string>& seqRegexps
// init the Nodes Regex map
std::map<std::string,NodeRegex*> nodesRegex ;
std::vector<std::string> nodeTypeKey {"GEMM"};
for (const std::string& key : nodeTypeKey) {
nodesRegex[key] = new NodeRegex(key);
}
// init the Sequential Regex vector
std::vector<std::string> seqRegex;
seqRegex.push_back("GEMM;");
// Instanciate a GRegex
GRegex GReg(nodesRegex, seqRegex);
//init the input graph
std::shared_ptr<GraphView> graphToMatch = std::make_shared<GraphView>("TestGraph");
std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1);
std::shared_ptr<Node> GEMM1 = GenericOperator("GEMM", 1, 1, 1);
std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1);
std::shared_ptr<Node> GEMM2 = GenericOperator("GEMM", 1, 1, 1);
std::shared_ptr<Node> GEMM3 = GenericOperator("GEMM", 1, 1, 1);
std::shared_ptr<Node> ReLU2 = GenericOperator("ReLU", 1, 1, 1);
std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1);
graphToMatch->add(Random0);
graphToMatch->addChild(GEMM1, Random0);
graphToMatch->addChild(ReLU1, GEMM1);
graphToMatch->addChild(GEMM2, ReLU1);
graphToMatch->addChild(GEMM3, GEMM2);
graphToMatch->addChild(ReLU2, GEMM3);
graphToMatch->addChild(Random, ReLU2);
//std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch);
//std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch);
Match matches = GReg.match(graphToMatch);
size_t nb = matches.getNbMatch();
std::vector<std::vector<NodeTmp>> gm_startnodes = matches.getStartNodes();
std::vector<std::set<NodeTmp>> gm_matchnodes = matches.getMatchNodes();
std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs;
for (size_t i = 0; i < nb; ++i) {
matchs.insert(std::make_pair(gm_startnodes[i], gm_matchnodes[i]));
}
//std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ;
std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ;
// Carefull : as the assert is on a vector, the Order of match matters
std::vector<NodeTmp> startNode = {GEMM1};
std::set<NodeTmp> matchNode = {GEMM1};
//toMatchs.push_back(std::make_pair(startNode,matchNode));
toMatchs.insert(std::make_pair(startNode,matchNode));
std::vector<NodeTmp> startNode2 = {GEMM2};
std::set<NodeTmp> matchNode2 = {GEMM2};
//toMatchs.push_back(std::make_pair(startNode2,matchNode2));
toMatchs.insert(std::make_pair(startNode2,matchNode2));
std::vector<NodeTmp> startNode3 = {GEMM3};
std::set<NodeTmp> matchNode3 = {GEMM3};
//toMatchs.push_back(std::make_pair(startNode3,matchNode3));
toMatchs.insert(std::make_pair(startNode3,matchNode3));
REQUIRE(matchs == toMatchs);
REQUIRE(nb == 3);
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <iostream>
#include <map>
#include <memory>
#include <cassert>
#include <catch2/catch_test_macros.hpp>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/operator/GenericOperator.hpp"
using namespace Aidge;
TEST_CASE("Create Noderegex", "[Noderegex]") {
std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("conv");
}
TEST_CASE("Test _is function", "[Noderegex]") {
// Create Noderegex with only condition on the name of the Node
// Create several operators to pass into Noderegex _is function
// Assert Noderegex._is(operators) are correct
std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("Conv");
std::shared_ptr<Node> Conv = GenericOperator("Conv", 1, 1, 1);
std::shared_ptr<Node> FC = GenericOperator("FC", 1, 1, 1);
REQUIRE(nr->_is(Conv) == true);
REQUIRE(nr->_is(FC) == false);
REQUIRE(nr->isA("Conv") == true);
REQUIRE(nr->isA("FC") == false);
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment