From 4540e79c2833c7c5dd4ae49f6d62cc9b518cf4fb Mon Sep 17 00:00:00 2001 From: vl241552 <vincent.lorrain@cea.fr> Date: Mon, 13 Nov 2023 13:13:40 +0000 Subject: [PATCH] clean and remove old graph matching --- include/aidge/aidge.hpp | 1 - include/aidge/graphmatching/GRegex.hpp | 63 ---- include/aidge/graphmatching/Match.hpp | 44 --- include/aidge/graphmatching/NodeRegex.hpp | 41 --- include/aidge/graphmatching/SeqStm.hpp | 127 ------- include/aidge/graphmatching/StmFactory.hpp | 55 --- include/aidge/graphmatching/Utile.hpp | 50 --- .../graphmatching/pybind_GRegex.cpp | 42 --- python_binding/graphmatching/pybind_Match.cpp | 34 -- .../graphmatching/pybind_NodeRegex.cpp | 28 -- src/graphmatching/GRegex.cpp | 301 ----------------- src/graphmatching/Match.cpp | 37 -- src/graphmatching/NodeRegex.cpp | 46 --- src/graphmatching/SeqStm.cpp | 247 -------------- src/graphmatching/StmFactory.cpp | 150 --------- src/recipies/FuseBatchNorm.cpp | 35 +- src/recipies/FuseMulAdd.cpp | 23 -- src/recipies/RemoveFlatten.cpp | 27 +- unit_tests/graphMatching/Test_GRegex.cpp | 318 ------------------ unit_tests/graphMatching/Test_NodeRegex.cpp | 44 --- unit_tests/graphMatching/Test_SeqStm.cpp | 167 --------- unit_tests/graphMatching/Test_StmFactory.cpp | 204 ----------- 22 files changed, 6 insertions(+), 2078 deletions(-) delete mode 100644 include/aidge/graphmatching/GRegex.hpp delete mode 100644 include/aidge/graphmatching/Match.hpp delete mode 100644 include/aidge/graphmatching/NodeRegex.hpp delete mode 100755 include/aidge/graphmatching/SeqStm.hpp delete mode 100644 include/aidge/graphmatching/StmFactory.hpp delete mode 100644 include/aidge/graphmatching/Utile.hpp delete mode 100644 python_binding/graphmatching/pybind_GRegex.cpp delete mode 100644 python_binding/graphmatching/pybind_Match.cpp delete mode 100644 python_binding/graphmatching/pybind_NodeRegex.cpp delete mode 100644 src/graphmatching/GRegex.cpp delete mode 100644 src/graphmatching/Match.cpp delete mode 100644 src/graphmatching/NodeRegex.cpp delete mode 100755 src/graphmatching/SeqStm.cpp delete mode 100644 src/graphmatching/StmFactory.cpp delete mode 100644 unit_tests/graphMatching/Test_GRegex.cpp delete mode 100644 unit_tests/graphMatching/Test_NodeRegex.cpp delete mode 100644 unit_tests/graphMatching/Test_SeqStm.cpp delete mode 100644 unit_tests/graphMatching/Test_StmFactory.cpp diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index cc8763580..65dc7f70c 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -20,7 +20,6 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" #include "aidge/graph/OpArgs.hpp" -#include "aidge/graphmatching/GRegex.hpp" #include "aidge/graphmatching/Match.hpp" #include "aidge/graphmatching/NodeRegex.hpp" #include "aidge/graphmatching/SeqStm.hpp" diff --git a/include/aidge/graphmatching/GRegex.hpp b/include/aidge/graphmatching/GRegex.hpp deleted file mode 100644 index fd2d0c52a..000000000 --- a/include/aidge/graphmatching/GRegex.hpp +++ /dev/null @@ -1,63 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - - -#ifndef AIDGE_GREGEX_H_ -#define AIDGE_GREGEX_H_ - -#include <stdexcept> // for exception, runtime_error, out_of_range -#include <regex> -#include <memory> // for shared_ptr -#include <algorithm> // for next_permutation - -#include "aidge/graphmatching/Utile.hpp" -#include "aidge/graphmatching/StmFactory.hpp" -#include "aidge/graphmatching/SeqStm.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/graphmatching/Match.hpp" - - -namespace Aidge{ - -class GRegex { -// __init__(self,nodes_regex:dict,seq_regexps:list) - - StmFactory mStmFab; - std::vector<SeqStm*> mStmInit; - -public: - GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps ); - - std::set<NodeTmp> matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch); - - bool walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm); - - bool walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm); - - bool walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm); - - std::set<NodeTmp> get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm); - - std::vector<SeqStm*> getStmInit() const { - return mStmInit; - } - - StmFactory getStmFab() const { - return mStmFab; - } - - //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> match(const std::shared_ptr<GraphView> graphToMatch); - Match match(const std::shared_ptr<GraphView> graphToMatch); - -}; - -} -#endif //AIDGE_GREGEX_H_ \ No newline at end of file diff --git a/include/aidge/graphmatching/Match.hpp b/include/aidge/graphmatching/Match.hpp deleted file mode 100644 index fc617a228..000000000 --- a/include/aidge/graphmatching/Match.hpp +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_MATCH_H_ -#define AIDGE_MATCH_H_ - -#include <vector> -#include <set> -#include <iostream> -#include <cassert> -#include "aidge/graphmatching/Utile.hpp" - - -namespace Aidge{ - -class Match { - -public: - Match(); - - size_t getNbMatch(); - - void insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes); - - std::vector<std::vector<NodeTmp>> getStartNodes(); - - std::vector<std::set<NodeTmp>> getMatchNodes(); - -protected: - std::vector<std::vector<NodeTmp>> mStartNodes; - std::vector<std::set<NodeTmp>> mMatchNodes; - -}; - -} -#endif //AIDGE_MATCH_H_ \ No newline at end of file diff --git a/include/aidge/graphmatching/NodeRegex.hpp b/include/aidge/graphmatching/NodeRegex.hpp deleted file mode 100644 index 10ba72258..000000000 --- a/include/aidge/graphmatching/NodeRegex.hpp +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_NODEREGEX_H_ -#define AIDGE_NODEREGEX_H_ -#include <cstdlib> -#include <iostream> -#include <cstring> -#include "aidge/graph/Node.hpp" - - -namespace Aidge { - -class NodeRegex -{ - public: - std::string mCondition; - - NodeRegex(const std::string c){ - mCondition = c; - }; - - // Version 1 - Only test the type of the node (no need for a lexer) - // Input : Node_op - // Output : bool - // return mCondition == Node_op.type - bool _is(std::shared_ptr<Node> &Node_op); - bool isA(std::string NodeType); -}; - -} - -#endif /* _AIDGE_NODEREGEX_H__ */ \ No newline at end of file diff --git a/include/aidge/graphmatching/SeqStm.hpp b/include/aidge/graphmatching/SeqStm.hpp deleted file mode 100755 index 0823b5fc0..000000000 --- a/include/aidge/graphmatching/SeqStm.hpp +++ /dev/null @@ -1,127 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_SEQSTM_H_ -#define AIDGE_SEQSTM_H_ - -#include <iostream> -#include <map> -#include <regex> -#include <set> -#include <stdexcept> // for exception, runtime_error, out_of_range -#include <string> -#include <utility> -#include <vector> - - -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/graphmatching/Utile.hpp" - - -namespace Aidge { - -class SeqStm { - -private: - const int mStmIdx; - const std::vector<std::vector<int>> mTransitionMatrix; - // str key of type like 'A' that ce use in the A->B .. extpr - const std::map<std::string, NodeRegex *> mNodesRegex; - // mTypeToIdxTransition.first = std::pair node_type , common_tag - // mTypeToIdxTransition.segond = idx in trans matrix - const std::map<NodeTypeKey, int> mTypeToIdxTransition; - - int mActSt; - std::set<NodeTmp> mAllNodeValidated; - std::set<NodeTmp> mAllNodeTested; - std::set<std::pair<NodeTmp, std::string>> mAllCommonNode; - bool mStmIsValid; - - std::pair<NodeRegex *, std::string> getNodeRegexAndCommonAt(int idxType); - - /** - * @brief test the stm on a type - * @return the common tag - */ - std::string transitionOnNodeType(NodeType nodeType); - -public: - SeqStm(const int mStmIdx, - const std::vector<std::vector<int>> &mTransitionMatrix, - const std::map<std::string, NodeRegex *> &mNodesRegex, - const std::map<NodeTypeKey, int> &mTypeToIdxTransition, int mActSt, - std::set<NodeTmp> mAllNodeValidated, std::set<NodeTmp> mAllNodeTested, - std::set<std::pair<NodeTmp, std::string>> mAllCommonNode, - bool mStmIsValid); - - ////////////////////////////////////// - // STM test - ///////////////////////////////////// - - /** - * @brief get if a st is a valide one - * @return bool - */ - bool isAValidSt(int st) { - std::size_t size = mTransitionMatrix.size(); - return st == static_cast<int>(size - 1) ? true : false; - } - - /** - * @brief true if the stm is blocked into st - * @return bool - */ - bool isStmBlocked() { return mActSt == -1 ? true : false; } - - /** - * @brief true if the stm into valide st - * @return bool - */ - bool isValid() { return mStmIsValid; } - - ///////////////////////////////////// - // utile - ///////////////////////////////////// - /** - * @brief extract from a node is type - * @return bool - */ - NodeType getTheNodeType(NodeTmp node); - - void drawStm(); - ///////////////////////////////////// - // geter - ///////////////////////////////////// - - std::set<std::pair<NodeTmp, std::string>> getAllCommonNode() { - return mAllCommonNode; - } - std::set<NodeTmp> getAllNodeTested() { return mAllNodeTested; } - - std::set<NodeTmp> getAllNodeValidated() { return mAllNodeValidated; } - - SeqStm *duplicateStm(); - - int getStmIdx() { return mStmIdx; } - - int getState() { return mActSt; } - ////////////////////////////////////////// - // USE - ////////////////////////////////////////// - /** - * @brief test the stm on a node - * @return pair new stm state, the common tag - */ - std::pair<int, std::string> testNode(const NodeTmp node); -}; -} // namespace Aidge - -#endif /* AIDGE_SEQSTM_H_ */ \ No newline at end of file diff --git a/include/aidge/graphmatching/StmFactory.hpp b/include/aidge/graphmatching/StmFactory.hpp deleted file mode 100644 index b5850e4a0..000000000 --- a/include/aidge/graphmatching/StmFactory.hpp +++ /dev/null @@ -1,55 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_STMFACTORY_H_ -#define AIDGE_STMFACTORY_H_ - -#include <map> -#include <utility> -#include <set> -#include <string> -#include <vector> -#include <iostream> -#include <stdexcept> // for exception, runtime_error, out_of_range -#include <regex> - -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/graphmatching/SeqStm.hpp" -#include "aidge/graphmatching/Utile.hpp" - -namespace Aidge{ - - - -class StmFactory { - - const std::map<std::string,NodeRegex*>& mNodesRegex; - std::size_t mCmptStm = 0; -public: - StmFactory(const std::map<std::string,NodeRegex*>& nodesRegex); - //StmFactory(){}; - - SeqStm* makeNewStm(const std::string& sequRegex); - SeqStm* duplicateStm(SeqStm* stm); - - std::size_t getNumberOfStm(){ - return mCmptStm; - } -private: - - ParsingReturn initParsingSequRegex(const std::string& sequRegex); - - std::vector<std::vector<int>> initTransitionMatrix(ParsingReturn& parsing); - -}; -} - -#endif //AIDGE_STMFACTORY_H_ \ No newline at end of file diff --git a/include/aidge/graphmatching/Utile.hpp b/include/aidge/graphmatching/Utile.hpp deleted file mode 100644 index acda78cd1..000000000 --- a/include/aidge/graphmatching/Utile.hpp +++ /dev/null @@ -1,50 +0,0 @@ - -/** - * @file - * @brief - * @version file 1.0.0 - * @author vl241552 - * @copyright - * Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. - * All rights reserved. - */ - -#ifndef _utile_H_ -#define _utile_H_ - -#include <map> - -#include "aidge/graph/Node.hpp" -#include <map> - -namespace Aidge { - -using NodeTmp = std::shared_ptr<Node>; -using NodeType = std::string; -using CommonTag = std::string; -using NodeTypeKey = std::pair<NodeType, CommonTag>; - -// type def -// struct NodeTypeKey { -// NodeType nodeType; -// std::string commonTag; - -// // for map find -// bool operator<(const NodeTypeKey& other) const { -// if (nodeType != other.nodeType or commonTag != other.commonTag) { -// return false; -// } else { -// return true; -// } -// } - -// }; - -struct ParsingReturn { - std::map<NodeTypeKey, int> typeToIdxTransition; - std::vector<std::pair<NodeTypeKey, std::string>> transition; -}; - -} // namespace Aidge - -#endif //_utile_H_ \ No newline at end of file diff --git a/python_binding/graphmatching/pybind_GRegex.cpp b/python_binding/graphmatching/pybind_GRegex.cpp deleted file mode 100644 index 48d0e19ff..000000000 --- a/python_binding/graphmatching/pybind_GRegex.cpp +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <pybind11/pybind11.h> -#include <pybind11/stl.h> -#include "aidge/graph/GraphView.hpp" -#include "aidge/graphmatching/GRegex.hpp" - -namespace py = pybind11; -namespace Aidge { -void init_GRegex(py::module& m){ - py::class_<GRegex, std::shared_ptr<GRegex>>(m, "GRegex", "GRegex class combines a Node Regex and a list of Graph Regex that together describes a graph pattern as a graph regular expression. GRegex find patterns in a given graph that matches the graph regular expression.") - .def(py::init<const std::map<std::string,NodeRegex*>&, std::vector<std::string>&>(), py::arg("nodesRegex"), py::arg("seqRegexps"), R"mydelimiter( - Constructor of GRegex - - :param nodesRegex: Describe the conditions an operator has to fulfill. - :type nodesRegex: Dict[str,:py:class:`aidge_core.NodeRegex`] - :param seqRegexps: Describe the graph topological pattern. List of Graph Regex as strings. - :type seqRegexps: List[str] - - )mydelimiter") - .def("match", &GRegex::match, py::arg("graphToMatch"), R"mydelimiter( - Launch the graph matching algorithm on a given graph. - - :param graphToMatch: The graph to perform the matching algorithm on. - :type graphToMatch: :py:class:`aidge_core.GraphView` - - :returns: Matched graph patterns. - :rtype: :py:class:`aidge_core.Match` - - )mydelimiter") - ; -} -} diff --git a/python_binding/graphmatching/pybind_Match.cpp b/python_binding/graphmatching/pybind_Match.cpp deleted file mode 100644 index a2d2654f4..000000000 --- a/python_binding/graphmatching/pybind_Match.cpp +++ /dev/null @@ -1,34 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <pybind11/pybind11.h> -#include <pybind11/stl.h> -#include "aidge/graphmatching/Match.hpp" - -namespace py = pybind11; -namespace Aidge { -void init_Match(py::module& m){ - py::class_<Match, std::shared_ptr<Match>>(m, "Match", "Match class stores the matched patterns resulting from a graph matching query. A matched pattern is the combinaison of the graph pattern start nodes and the set of all the nodes in the matched pattern (including the start nodes)") - .def(py::init<>()) - .def("get_nb_match", &Match::getNbMatch, R"mydelimiter( - :returns: The number of graph patterns matched - :rtype: int - )mydelimiter") - .def("get_start_nodes", &Match::getStartNodes, R"mydelimiter( - :returns: All matched graph patterns start nodes - :rtype: List[List[:py:class:`aidge_core.Nodes`]] - )mydelimiter") - .def("get_match_nodes", &Match::getMatchNodes, R"mydelimiter( - :returns: All matched graph patterns sets of matched nodes - :rtype: List[Set[:py:class:`aidge_core.Nodes`]] - )mydelimiter"); -} -} diff --git a/python_binding/graphmatching/pybind_NodeRegex.cpp b/python_binding/graphmatching/pybind_NodeRegex.cpp deleted file mode 100644 index 034987f9c..000000000 --- a/python_binding/graphmatching/pybind_NodeRegex.cpp +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <pybind11/pybind11.h> -#include "aidge/graphmatching/NodeRegex.hpp" - -namespace py = pybind11; -namespace Aidge { -void init_NodeRegex(py::module& m){ - py::class_<NodeRegex, std::shared_ptr<NodeRegex>>(m, "NodeRegex", "NodeRegex class describes a condition to test on any operator. Current version only supports testing the type of the operator.") - .def(py::init<const std::string>(), py::arg("condition"), R"mydelimiter( - Constructor of NodeRegex - - :param condition: Condition to be fulfilled by an operator. - :type condition: str - - )mydelimiter") - ; -} -} diff --git a/src/graphmatching/GRegex.cpp b/src/graphmatching/GRegex.cpp deleted file mode 100644 index 6b54c5a47..000000000 --- a/src/graphmatching/GRegex.cpp +++ /dev/null @@ -1,301 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graph/GraphView.hpp" - -using namespace Aidge; - -GRegex::GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps ):mStmFab(nodesRegex){ - - - //setup all the STM - for (const std::string& sequRegex : seqRegexps) { - mStmInit.push_back(mStmFab.makeNewStm(sequRegex)); - } - -} - -bool GRegex::walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm){ - //test if all stm type are in a valid state - std::vector<int> number_of_valid; - number_of_valid.resize(all_stm.size()); - - for (std::size_t i = 0; i < all_stm.size(); ++i) { - number_of_valid[i] = 0; - for (auto it = all_stm[i].begin(); it != all_stm[i].end(); ++it) { - SeqStm* stm = *it; - if (stm->isValid()){ - number_of_valid[i] +=1; - } - } - } - - for (std::size_t i = 0; i < number_of_valid.size(); ++i) { - if (number_of_valid[i] == 0) { - //std::cout << "NO MATCH at least one stm are not valid" << std::endl; - return false; - } - if (number_of_valid[i] > 1) { - //std::cout << "NO MATCH multiple brach match of stm (// quantification)" << std::endl; - return false; - } - } - return true; -} - -bool GRegex::walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm){ - std::set<NodeTmp> all_stm_node_tested; - std::set<NodeTmp> all_stm_node_validated; - - for (std::size_t i = 0; i < all_stm.size(); ++i) { - //std::cout << "all stm index " << i << " on dimension 1 of size " << all_stm.size() <<std::endl; - for (std::size_t j = 0; j < all_stm[i].size(); ++j) { - //std::cout << "all stm index " << j << " on dimension 2 of size " << all_stm[i].size() <<std::endl; - - std::set<NodeTmp> stm_node_tested = all_stm[i][j]->getAllNodeTested(); - std::set<NodeTmp> stm_node_validated = all_stm[i][j]->getAllNodeValidated(); - - all_stm_node_tested.insert(stm_node_tested.begin(), stm_node_tested.end()); - all_stm_node_validated.insert(stm_node_validated.begin(), stm_node_validated.end()); - } - } - - - std::set<NodeTmp> test_but_not_valid; - for (const auto& x : all_stm_node_tested) { - if (all_stm_node_validated.find(x) == all_stm_node_validated.end()) { - test_but_not_valid.insert(x); - } - } - - - if (!test_but_not_valid.empty()) { - std::cout << "NO MATCH. The node(s) "; - for (const auto& x : test_but_not_valid) { - std::cout << x.get() << ", "; - } - std::cout << " have been tested but not validated." << std::endl; - return false; - } - return true; - -} - -bool GRegex::walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm){ - std::map<NodeTmp, std::pair<std::string,int>> node_to_common_tag; - for (std::size_t i = 0; i < all_stm.size(); ++i) { - for (auto it = all_stm[i].begin(); it != all_stm[i].end(); ++it) { - SeqStm* stm = *it; - - if (!stm->isValid()){ - continue; - } - - for (const auto& pair : stm->getAllCommonNode()) { - const NodeTmp node = pair.first; - const std::string common_tag = pair.second; - - if (node_to_common_tag.find(node) != node_to_common_tag.end()) { - std::string tag = node_to_common_tag[node].first; - int& occurence = node_to_common_tag[node].second; - if (tag!=common_tag){ - std::cout << "NO MATCH. The node " << node << " have two different tags "<< tag << " and " << common_tag << std::endl; - return false; - } else { - occurence += 1; - } - } else { - node_to_common_tag.insert(std::make_pair(node, std::make_pair(common_tag, 1))); - } - } - } - } - /*std::cout << "Node to common tag "; - for (const auto& x : node_to_common_tag) { - std::cout << "(" << x.first << ", " << "[" << x.second.first << ", " << x.second.second << "]" << ") ; "; - } - std::cout << std::endl;*/ - - - for (const auto& pair : node_to_common_tag) { - const std::pair<std::string, int> tag_occurence_pair = pair.second; - if (tag_occurence_pair.second < 1){ - //std::cout << "NO MATCH. The common tag " << tag_occurence_pair.first << " did not match " << std::endl; - return false; - } - } - - return true; -} - -std::set<NodeTmp> GRegex::get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm){ - std::set<NodeTmp> all_stm_node_validated; - - for (std::size_t i = 0; i < all_stm.size(); ++i) { - for (std::size_t j = 0; j < all_stm[i].size(); ++j) { - std::set<NodeTmp> stm_node_validated = all_stm[i][j]->getAllNodeValidated(); - all_stm_node_validated.insert(stm_node_validated.begin(), stm_node_validated.end()); - } - } - return all_stm_node_validated; -} - - -std::set<NodeTmp> GRegex::matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch){ - std::set<NodeTmp> empty_set_return; - //ASSERT - if(startNodes.size() != mStmInit.size()){ - throw std::runtime_error ("bad GRegex start nodes"); - } - - //init the walk - std::vector<std::vector<SeqStm*>> allStm; - std::vector<std::pair<NodeTmp,SeqStm*>> currentWalk; - - for (SeqStm* seqStmPtr : mStmInit) { - SeqStm* newStm = mStmFab.duplicateStm(seqStmPtr); - std::size_t idxStart = newStm->getStmIdx(); - currentWalk.push_back(std::make_pair(startNodes[idxStart],newStm)); - allStm.push_back(std::vector<SeqStm*>()); - } - - //walk - while (currentWalk.size()!=0) - { - std::vector<std::pair<NodeTmp,SeqStm*>> newWalk; - for (const auto& pair : currentWalk) { - const NodeTmp node = pair.first; - SeqStm* stmPtr = pair.second; - - std::pair<int,std::string> test = stmPtr->testNode(node); - int res = test.first; - std::string commonTag = test.second; - - std::set<NodeTmp> next_nodes = graphToMatch->getChildren(node); - - /*std::cout << "Next nodes : " ; - for (const auto& x : next_nodes) { - std::cout << x->name() << ", "; - } - std::cout << std::endl;*/ - - // Test Match - if (commonTag == "" && next_nodes.size() > 1) { - std::cout << "NO MATCH. The node " << node.get() << " is not common and has more than one child" << std::endl; - return empty_set_return; - } - - // If there is no more nodes --> Archive the branch - if (res == -1 || next_nodes.empty()) { - int indexToInsert = stmPtr->getStmIdx(); - allStm[indexToInsert].push_back(stmPtr); - //std::cout << "No more nodes --> STM archived : " << indexToInsert << std::endl; - continue; // TODEV : replace this with 'else' that encapsulate the rest of the function ? - } - - bool first = true; - - // Use an iterator to read through the next_nodes - std::set<NodeTmp>::iterator it; - for (it = next_nodes.begin(); it != next_nodes.end(); ++it) { - // Access the current element using the iterator - std::shared_ptr<Aidge::Node> next_node = *it; - if (first){ - newWalk.push_back(std::make_pair(next_node, stmPtr)); - first = false; - } else { - SeqStm* new_stmPtr = mStmFab.duplicateStm(stmPtr); - newWalk.push_back(std::make_pair(next_node, new_stmPtr)); - } - } - } - currentWalk = newWalk; - } - - //std::cout << "Walk finished" << std::endl; - - if (!walk_validation_all_stm_are_valid(allStm)){ - return empty_set_return; - } - //std::cout << "walk_validation_all_stm_are_valid finished" << std::endl; - - - if (!walk_validation_all_node_read_validate_by_one_stm(allStm)){ - return empty_set_return; - } - //std::cout << "walk_validation_all_node_read_validate_by_one_stm finished" << std::endl; - - - if (!walk_validation_common_nodes_same_tag_for_all_stm(allStm)){ - return empty_set_return; - } - //std::cout << "walk_validation_common_nodes_same_tag_for_all_stm finished" << std::endl; - - //std::cout << "MATCH" << std::endl; - - return get_all_validate_nodes(allStm); - -} - - - -Match GRegex::match(const std::shared_ptr<GraphView> graphToMatch){ - - //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matches; - //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matches; - Match matches; - std::size_t nbStartNodes = mStmInit.size(); - std::set<NodeTmp> allNodes = graphToMatch->getNodes(); - std::size_t nbAllNodes = allNodes.size(); - - std::vector<std::size_t> indices(nbStartNodes, 0); - - while (true) { - // Generate all permutations of the current combination - do { - std::vector<NodeTmp> startNodes; - //std::cout <<"start nodes :"; - for (std::size_t i = 0; i < nbStartNodes; ++i) { - auto it = std::begin(allNodes); - std::advance(it, indices[i]); - //std::cout << (*it).get() << " "; - startNodes.push_back(*it); - } - //std::cout <<"\n"; - - std::set<NodeTmp> match = matchFromStartNodes(startNodes, graphToMatch); - //std::cout << "match size : " << match.size() << " "; - if(match.size() != 0){ - //matches.push_back(std::make_pair(startNodes,match)); - //matches.insert(std::make_pair(startNodes,match)); - matches.insert(startNodes,match); - } - - } while (std::next_permutation(indices.begin(), indices.end())); - - // Generate the next combination with replacement - std::size_t i = nbStartNodes - 1; - while (true) { - if (indices[i] < nbAllNodes - 1) { - ++indices[i]; - break; - } - if (i == 0) { - return matches; - } - --i; - } - std::fill(indices.begin() + i + 1, indices.end(), indices[i]); - } - - return matches; -} \ No newline at end of file diff --git a/src/graphmatching/Match.cpp b/src/graphmatching/Match.cpp deleted file mode 100644 index 6c08b30b1..000000000 --- a/src/graphmatching/Match.cpp +++ /dev/null @@ -1,37 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/Match.hpp" - -using namespace Aidge; - -Match::Match(){ - //ctr -} - -size_t Match::getNbMatch(){ - assert(mStartNodes.size() == mMatchNodes.size() && "Match corrupted"); - return mStartNodes.size(); -} - -void Match::insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes){ - assert(mStartNodes.size() == mMatchNodes.size() && "Match corrupted"); - mStartNodes.push_back(startnodes); - mMatchNodes.push_back(matchnodes); -} - -std::vector<std::vector<NodeTmp>> Match::getStartNodes(){ - return mStartNodes; -} - -std::vector<std::set<NodeTmp>> Match::getMatchNodes(){ - return mMatchNodes; -} \ No newline at end of file diff --git a/src/graphmatching/NodeRegex.cpp b/src/graphmatching/NodeRegex.cpp deleted file mode 100644 index 9bf164f60..000000000 --- a/src/graphmatching/NodeRegex.cpp +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/NodeRegex.hpp" - - -// Verification done by the Attribute system - - -// Version 1 - Only test the type of the node (no need for a lexer) -// Input : Node_op -// Output : bool -// return mCondition == Node_op.type -bool Aidge::NodeRegex::_is(std::shared_ptr<Node> &Node_op){ - - std::string NodeType = Node_op->type(); - - return strcmp(NodeType.c_str(), mCondition.c_str()) == 0; -} - - -bool Aidge::NodeRegex::isA(std::string NodeType){ - - return strcmp(NodeType.c_str(), mCondition.c_str()) == 0; -} - -// Version 2 - Test the node to an advanced condition -// Input : Node_op -// Output : bool -// return mCondition applied on Node -/**bool NodeRegex::_is(string &Node_op){ - // Parsing the condition is done in the initialization of the NodeRegex - - // assert attributes exist in the node with the attribute function hasAttr() - - // get the attributes - -}*/ diff --git a/src/graphmatching/SeqStm.cpp b/src/graphmatching/SeqStm.cpp deleted file mode 100755 index 84553cb44..000000000 --- a/src/graphmatching/SeqStm.cpp +++ /dev/null @@ -1,247 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/SeqStm.hpp" - -using namespace Aidge; - - - - - /////////////////////////////////////////////////////// - - SeqStm::SeqStm( - const int stmIdx, - const std::vector<std::vector<int>>& transitionMatrix, - const std::map<std::string,NodeRegex*>& nodesRegex, - const std::map<NodeTypeKey,int>& typeToIdxTransition, - int actSt, - std::set<NodeTmp> allNodeValidated, - std::set<NodeTmp> allNodeTested, - std::set<std::pair<NodeTmp,std::string>> allCommonNode, - bool stmIsValid):mStmIdx(stmIdx), - mTransitionMatrix(transitionMatrix), - mNodesRegex(nodesRegex), - mTypeToIdxTransition(typeToIdxTransition) - { - - //assert - if (transitionMatrix.size() == 0){ - throw std::runtime_error ("no transitionMatrix"); - } - if(transitionMatrix[0].size() == 0 || transitionMatrix[0].size() != typeToIdxTransition.size()){ - throw std::runtime_error ("bad transitionMatrix"); - } - int size = static_cast<int>(transitionMatrix.size()); - if (actSt >= size){ - throw std::runtime_error ("bad actSt"); - } - - - mActSt = actSt; - mAllNodeValidated = allNodeValidated; - mAllNodeTested = allNodeTested; - mAllCommonNode = allCommonNode; - mStmIsValid = stmIsValid; - - } - - SeqStm* SeqStm::duplicateStm(){ - - //deep copy of the set - // std::set<Node> cAllNodeValidated(mAllNodeValidated.begin(), mAllNodeValidated.end()); - // std::set<Node> cAllNodeTested(mAllNodeTested.begin(), mAllNodeTested.end()); - - // std::set<std::pair<Node,std::string>> cAllCommonNode; - // for (const auto& p : mAllCommonNode) { - // cAllCommonNode.insert(p); - // } - - auto newStm = new SeqStm( - mStmIdx, - mTransitionMatrix, - mNodesRegex, - mTypeToIdxTransition, - mActSt, - mAllNodeValidated, - mAllNodeTested, - mAllCommonNode, - mStmIsValid - ); - - return newStm; - } - - - std::pair<NodeRegex*,std::string> SeqStm::getNodeRegexAndCommonAt(int idxType) - { - //std::cout << "!" << idxType << "\n"; - for (auto const& x : mTypeToIdxTransition) - { - //x.second is the value : idx in mTransitionMatrix for the type - //x.first pair of the node regex class and a string that is the common tag '',#,#n - if (x.second == idxType ){ - - if (mNodesRegex.find(x.first.first) != mNodesRegex.end()){ - return std::make_pair(mNodesRegex.find(x.first.first)->second, x.first.second); - }else{ - throw std::runtime_error ("a type is not define in NodesRegex"); - } - } - } - throw std::runtime_error ("bad idx in mNodesRegex"); - return std::make_pair(nullptr,nullptr); - } - - - NodeType SeqStm::getTheNodeType(NodeTmp node) - { - //the node is a str of '{type}{idx}' and we juste want type - // // std::regex re("([a-zA-Z]+)[0-9]+"); - // // std::smatch match; - // // if (std::regex_search(node, match, re) == true) { - // // return match.str(1); - // // } - // // throw std::runtime_error ("Type node not found"); - // // return ""; - - //return node->name(); - return node->type(); - } - - - std::string SeqStm::transitionOnNodeType(NodeType nodeType){ - - if (!isStmBlocked()){ - int idxType = 0; - for (auto & nextSt : mTransitionMatrix[mActSt]) { - // There are a next step for this type - //std::cout << "transition matrix next state -> "<< nextSt<<"\n" ; - if (nextSt != -1){ - //std::cout << "next -> "<< nextSt<< " "<< isAValidSt(nextSt) <<"\n" ; - auto nodeRegex = getNodeRegexAndCommonAt(idxType); - //std::cout << "-> "<< nodeRegex.second<<"\n" ; - if (nodeRegex.first->isA(nodeType)){ - //std::cout << "nodetype tested !"<<"\n" ; - if(isAValidSt(nextSt)){ - //std::cout << "Valid state !"<<"\n" ; - mStmIsValid = true; - } - mActSt = nextSt; - return nodeRegex.second; - } - - } - idxType += 1; - } - - mActSt =-1; - } - - return ""; - } - - - std::pair<int,std::string> SeqStm::testNode(const NodeTmp node){ - - std::string commonTag = ""; - //std::cout << "0\n" ; - if (!isStmBlocked()){ - bool isNextStEnd = std::all_of(mTransitionMatrix[mActSt].begin(), mTransitionMatrix[mActSt].end(), [&](int x){ return x == -1; }); - //std::cout << "1:"<< isNextStEnd <<"\n" ; - //if the next state if full of -1 can we relay add the node test to all node tested - // oker y test it but it sure that not be valid - if(!isNextStEnd){ - mAllNodeTested.insert(node); - } - //std::cout << "2\n" ; - //recurtion avoidance - if(mAllNodeValidated.find(node) == mAllNodeValidated.end()){ - - NodeType nodeType = getTheNodeType(node); - //std::cout << "3 " << nodeType << "\n" ; - commonTag = transitionOnNodeType(nodeType); - //after the transition test, if the node is != -1 the node is valid for the stm - //std::cout << " mActSt = " << mActSt << "\n" ; - if( mActSt != -1 ){ - mAllNodeValidated.insert(node); - } - }else{ - mActSt = -1; - } - } - - if(commonTag != ""){ - mAllCommonNode.insert(std::make_pair(node,commonTag)); - } - return std::make_pair(mActSt,commonTag); - } - - -void SeqStm::drawStm(){ - - //mTransitionMatrix - // Find the maximum width of each column - std::vector<std::size_t> max_widths(mTransitionMatrix[0].size(), 0); - for (std::size_t i = 0; i < mTransitionMatrix.size(); ++i) - { - for (std::size_t j = 0; j < mTransitionMatrix[i].size(); ++j) - { - std::size_t width = std::to_string(mTransitionMatrix[i][j]).length(); - if (width > max_widths[j]) - { - max_widths[j] = width; - } - } - } - - // Print the vector with aligned columns - for (std::size_t i = 0; i < mTransitionMatrix.size(); ++i) - { - for (std::size_t j = 0; j < mTransitionMatrix[i].size(); ++j) - { - int i_int = static_cast<int>(i); - if (mActSt == -1 ){ - if(mStmIsValid){ - std::cout << "\033[48;5;40m"; - }else{ - std::cout << "\033[48;5;9m"; - } - } - else if (mActSt == i_int){ - std::cout << "\033[48;5;30m"; - }else{ - std::cout << "\033[48;5;27m"; - } - - // Pad the value with spaces to align it with the maximum width - std::size_t width = std::to_string(mTransitionMatrix[i][j]).length(); - std::string padding(max_widths[j] - width, ' '); - std::cout << padding << mTransitionMatrix[i][j] << " "; - std::cout << "\033[0m"; - } - std::cout << "\n"; - } - - std::cout << "mAllNodeTested : "; - for (const auto& x : mAllNodeTested) { - std::cout << x << ", "; - } - std::cout << "\n"; - - - std::cout << "mAllNodeValidated : "; - for (const auto& x : mAllNodeValidated) { - std::cout << x << ", "; - } - std::cout << "\n"; -} - diff --git a/src/graphmatching/StmFactory.cpp b/src/graphmatching/StmFactory.cpp deleted file mode 100644 index 30b1fad81..000000000 --- a/src/graphmatching/StmFactory.cpp +++ /dev/null @@ -1,150 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/StmFactory.hpp" - -using namespace Aidge; - -StmFactory::StmFactory(const std::map<std::string, NodeRegex *> &nodesRegex) - : mNodesRegex(nodesRegex) {} - -SeqStm *StmFactory::duplicateStm(SeqStm *stm) { return stm->duplicateStm(); } - -SeqStm *StmFactory::makeNewStm(const std::string &sequRegex) { - - ParsingReturn parsing = initParsingSequRegex(sequRegex); - std::vector<std::vector<int>> transitionMatrix = - initTransitionMatrix(parsing); - - std::set<NodeTmp> allNodeValidated; - std::set<NodeTmp> allNodeTested; - std::set<std::pair<NodeTmp, std::string>> allCommonNode; - - SeqStm *newStm = new SeqStm(static_cast<int>(mCmptStm), transitionMatrix, mNodesRegex, - parsing.typeToIdxTransition, 0, allNodeValidated, - allNodeTested, allCommonNode, false); - mCmptStm += 1; - - return newStm; -} - -ParsingReturn StmFactory::initParsingSequRegex(const std::string &sequRegex) { - - std::string toMatch; - std::regex re("\\s*([A-Za-z]+)(#\\d*)?([+*])?\\s*(->|;)"); - std::smatch matches; - - int idxType = 0; - // return - ParsingReturn parsing; - // std::map<std::pair<NodeType,std::string>,int> typeToIdxTransition; - // std::vector<std::pair<std::pair<NodeType,std::string>,std::string>> - // transition; - // assert - std::map<NodeType, std::string> assertCommonNodeTypes; - - for (std::size_t i = 0; i < sequRegex.length(); i++) { - toMatch += sequRegex[i]; - if (std::regex_match(toMatch, matches, re)) { - - std::string type = matches.str(1); - std::string commonTag = matches.str(2); - std::string quantification = matches.str(3); - - if ((commonTag != "") && (quantification != "")) { - throw std::runtime_error("bad commonTag and quantification"); - } - - // make the typeToIdxTransition - NodeTypeKey typeTag = std::make_pair(type, commonTag); - /*std::cout << " typeTag: " << type << " " << commonTag - << parsing.typeToIdxTransition.size() << std::endl;*/ - if (parsing.typeToIdxTransition.find(typeTag) == - parsing.typeToIdxTransition.end()) { - parsing.typeToIdxTransition[typeTag] = idxType; - idxType += 1; - } - //////////////////////////////////////////////////////////// - // ASSERT - // SAME Common node in the sequ - if (commonTag != "") { - if (assertCommonNodeTypes.find(type) != assertCommonNodeTypes.end()) { - if (assertCommonNodeTypes[type] == commonTag) { - throw std::runtime_error("same common node in the sequ regex"); - } - } else { - assertCommonNodeTypes[type] = commonTag; - } - } - - // save all transition - parsing.transition.push_back(std::make_pair(typeTag, quantification)); - - /*std::cout << "Match found: " << matches.str() << std::endl; - std::cout << "Type: " << matches.str(1) << std::endl; - std::cout << "Common tag: " << matches.str(2) << std::endl; - std::cout << "Quantification: " << matches.str(3) << std::endl;*/ - - toMatch = ""; - } - } - if (parsing.transition.size() == 0) { - throw std::runtime_error("Bad Parsing SequRegex "); - } - - return parsing; -} - -std::vector<std::vector<int>> -StmFactory::initTransitionMatrix(ParsingReturn &parsing) { - - // std::pair<NodeTypeKey,std::string> - std::vector<std::vector<int>> transitionMatrix; - std::size_t numberOfType = parsing.typeToIdxTransition.size(); - - if (numberOfType == 0) { - throw std::runtime_error("Bad number Of Type "); - } - // init start st - transitionMatrix.push_back(std::vector<int>(numberOfType, -1)); - - std::size_t idxTransition = 0; - int idxState = 0; - for (const auto &pair : parsing.transition) { - const NodeTypeKey &nodeTypeKey = pair.first; - const std::string &quant = pair.second; - - /*std::cout << "Key: {" << nodeTypeKey.first << ", " << nodeTypeKey.second - << "}, Value: " << quant << std::endl; - std::cout << "idxState " << idxState << " TM: " << transitionMatrix.size() - << std::endl;*/ - std::size_t idxType = parsing.typeToIdxTransition[nodeTypeKey]; - /*std::cout << "idxType " << idxType << " TM: " << transitionMatrix[0].size() - << "type" << numberOfType << std::endl;*/ - - if (quant == "*") { - transitionMatrix[idxTransition][idxType] = idxState; - } else if (quant == "+") { - idxState += 1; - transitionMatrix[idxTransition][idxType] = idxState; - transitionMatrix.push_back(std::vector<int>(numberOfType, -1)); - idxTransition += 1; - transitionMatrix[idxTransition][idxType] = idxState; - } else { - - idxState += 1; - transitionMatrix[idxTransition][idxType] = idxState; - transitionMatrix.push_back(std::vector<int>(numberOfType, -1)); - idxTransition += 1; - } - } - return transitionMatrix; -} \ No newline at end of file diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp index 0d86d8789..6e345a647 100644 --- a/src/recipies/FuseBatchNorm.cpp +++ b/src/recipies/FuseBatchNorm.cpp @@ -21,9 +21,7 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/GenericOperator.hpp" -// Graph Regex -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" + //Graph Regex #include "aidge/graphRegex/GraphRegex.hpp" @@ -32,21 +30,7 @@ using namespace Aidge; void Aidge::fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm){ - // assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); - - // // Assert the nodes types are correct to be fused - // std::shared_ptr<Node> conv; - // std::shared_ptr<Node> batchnorm; - // for (const auto& element : nodes) { - // assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace"); - // if (element->type() == "Conv"){ - // conv = element; - // } - // else if (element->type() == "BatchNorm") { - // batchnorm = element; - // } - // } - // TODO : check if batchnorm is the only child of the Conv or FC + std::shared_ptr<Tensor> scale = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second); std::shared_ptr<Tensor> shift = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second); @@ -146,20 +130,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<MatchSolution> solution){ } void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){ - // std::map<std::string,NodeRegex*> nodesRegex ; - // nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm"); - // nodesRegex["Conv"] = new NodeRegex("Conv"); - // nodesRegex["FC"] = new NodeRegex("FC"); - - - // std::vector<std::string> seqRegex; - // seqRegex.push_back("Conv -> BatchNorm;"); // TODO: Add (Conv | FC) - // GRegex GReg(nodesRegex, seqRegex); - // Match matches = GReg.match(graphView); - // std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); - // for (size_t i = 0; i < matches.getNbMatch(); ++i) { - // fuseBatchNorm(matchNodes[i]); - // } + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); regex->setNodeKey("BatchNorm","getType($) =='BatchNorm'"); diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index 806165967..df0fb5eff 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -22,9 +22,6 @@ #include "aidge/operator/GenericOperator.hpp" #include "aidge/utils/ErrorHandling.hpp" -// Graph Regex -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" //Graph Regex #include "aidge/graphRegex/GraphRegex.hpp" @@ -84,9 +81,6 @@ void Aidge::fuseMulAdd(std::shared_ptr<MatchSolution> solution){ void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){ - // std::map<std::string,NodeRegex*> nodesRegex ; - // nodesRegex["MatMul"] = new NodeRegex("MatMul"); - // nodesRegex["Add"] = new NodeRegex("Add"); std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); regex->setNodeKey("Add","getType($) =='Add'"); @@ -97,26 +91,9 @@ void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){ fuseMulAdd(solution); - // // solution->at("MatMul"); - // // solution->at("Add"); - // assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n"); - // assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n"); - // for (const auto& matmul : solution->at("MatMul")) { - // for (const auto& add : solution->at("Add")) { - // fuseMulAdd(matmul,add); - // } - // } } - // std::vector<std::string> seqRegex; - // seqRegex.push_back("MatMul -> Add;"); - // GRegex GReg(nodesRegex, seqRegex); - // Match matches = GReg.match(graphView); - // std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); - // for (size_t i = 0; i < matches.getNbMatch(); ++i) { - // fuseMulAdd(matchNodes[i]); - // } } diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index 452c32b92..0dc8d856f 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -15,24 +15,14 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/utils/Recipies.hpp" -// Graph Regex -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" + //Graph Regex #include "aidge/graphRegex/GraphRegex.hpp" namespace Aidge { void removeFlatten(std::shared_ptr<Node> flatten) { - // assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); - // std::shared_ptr<Node> flatten; - // for (const auto& element : nodes) { - // assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace"); - // if (element->type() == "Flatten"){ - // flatten = element; - // } - // } - + GraphView::replace({flatten}, {}); } @@ -49,18 +39,7 @@ namespace Aidge { void removeFlatten(std::shared_ptr<GraphView> graphView){ - // std::map<std::string,NodeRegex*> nodesRegex ; - // nodesRegex["Flatten"] = new NodeRegex("Flatten"); - // nodesRegex["FC"] = new NodeRegex("FC"); - // std::vector<std::string> seqRegex; - // seqRegex.push_back("Flatten->FC;"); - // GRegex GReg(nodesRegex, seqRegex); - // Match matches = GReg.match(graphView); - // std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); - // for (size_t i = 0; i < matches.getNbMatch(); ++i) { - // removeFlatten(matchNodes[i]); - // } - + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); regex->setNodeKey("Flatten","getType($) =='Flatten'"); diff --git a/unit_tests/graphMatching/Test_GRegex.cpp b/unit_tests/graphMatching/Test_GRegex.cpp deleted file mode 100644 index 2c5907d82..000000000 --- a/unit_tests/graphMatching/Test_GRegex.cpp +++ /dev/null @@ -1,318 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <iostream> -#include <map> -#include <memory> -#include <vector> -#include <utility> -#include <cassert> - -#include <catch2/catch_test_macros.hpp> -//test -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graphmatching/StmFactory.hpp" -#include "aidge/graphmatching/SeqStm.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/graphmatching/Match.hpp" -//use -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" -#include "aidge/graph/GraphView.hpp" - -using namespace Aidge; - -TEST_CASE("Create good init GRegex", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("A->B;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - // Perform tests - REQUIRE(GReg.getStmInit().size() == 1); - REQUIRE(GReg.getStmFab().getNumberOfStm() == 1); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - - -TEST_CASE("Function matchFromStartNodes | One Match of Nodes sequence", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"Conv","BN","ReLU"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("Conv->BN->ReLU;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1); - std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1); - std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); - std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); - std::shared_ptr<Node> Random2 = GenericOperator("Random2", 1, 1, 1); - - - g1->add(Conv1); - g1->addChild(BN1, Conv1); - g1->addChild(ReLU1, BN1); - g1->addChild(Random, ReLU1); - //g1->addChild(BN1, Random2); - - std::vector<std::shared_ptr<Node>> startNodes1; - std::set<std::shared_ptr<Node>> result; - - startNodes1.push_back(Conv1); - result = GReg.matchFromStartNodes(startNodes1, g1); - - std::set<std::shared_ptr<Node>> true_result; - true_result.insert(Conv1); - true_result.insert(BN1); - true_result.insert(ReLU1); - - // Perform tests - REQUIRE(result == true_result); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"Add","FC","Conv"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("Add#->Conv;"); - seqRegex.push_back("Add#->FC;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - // Instanciate a graphView - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); - std::shared_ptr<Node> Add1 = GenericOperator("Add", 1, 1, 1); - std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1); - std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1); - std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); - std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1); - std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); - - g1->add(Random0); - g1->addChild(Add1, Random0); - g1->addChild(Conv1, Add1); - g1->addChild(BN1, Conv1); - g1->addChild(ReLU1, BN1); - g1->addChild(FC1, Add1); - g1->addChild(Random, FC1); - - // Test 1 : Find the match - std::vector<std::shared_ptr<Node>> startNodes; - std::set<std::shared_ptr<Node>> result; - - startNodes.push_back(Add1); - startNodes.push_back(Add1); - result = GReg.matchFromStartNodes(startNodes, g1); - - std::set<std::shared_ptr<Node>> true_result; - true_result.insert(Add1); - true_result.insert(Conv1); - true_result.insert(FC1); - - // Test 2 : Return an empty set when the start nodes are wrong - std::vector<std::shared_ptr<Node>> wrong_startNodes; - std::set<std::shared_ptr<Node>> wrong_start_result; - std::set<std::shared_ptr<Node>> empty_result; - - wrong_startNodes.push_back(Random0); - wrong_startNodes.push_back(Random0); - wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1); - - // Perform tests - REQUIRE(result == true_result); - REQUIRE(wrong_start_result == empty_result); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -/* -TEST_CASE("Function matchFromStartNodes | Match a sequence with quantifier ", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"FC"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("FC+;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - - // Instanciate a graphView - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); - std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1); - std::shared_ptr<Node> FC2 = GenericOperator("FC", 1, 1, 1); - std::shared_ptr<Node> FC3 = GenericOperator("FC", 1, 1, 1); - std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); - - g1->add(Random0); - g1->addChild(FC1, Random0); - g1->addChild(FC2, FC1); - g1->addChild(FC3, FC2); - g1->addChild(ReLU1, FC3); - - // Test 1 : Find the match - std::vector<std::shared_ptr<Node>> startNodes; - std::set<std::shared_ptr<Node>> result; - - startNodes.push_back(FC1); - result = GReg.matchFromStartNodes(startNodes, g1); - - std::set<std::shared_ptr<Node>> true_result; - true_result.insert(FC1); - true_result.insert(FC2); - true_result.insert(FC3); - - // Test 2 : Return an empty set when the start nodes are wrong - std::vector<std::shared_ptr<Node>> wrong_startNodes; - std::set<std::shared_ptr<Node>> wrong_start_result; - std::set<std::shared_ptr<Node>> empty_result; - - wrong_startNodes.push_back(Random0); - wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1); - - // Perform tests - REQUIRE(result == true_result); - REQUIRE(wrong_start_result == empty_result); -} -*/ - -TEST_CASE("Function match | ALL matches of Nodes sequence", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"GEMM"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("GEMM;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - //init the input graph - std::shared_ptr<GraphView> graphToMatch = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); - std::shared_ptr<Node> GEMM1 = GenericOperator("GEMM", 1, 1, 1); - std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); - std::shared_ptr<Node> GEMM2 = GenericOperator("GEMM", 1, 1, 1); - std::shared_ptr<Node> GEMM3 = GenericOperator("GEMM", 1, 1, 1); - std::shared_ptr<Node> ReLU2 = GenericOperator("ReLU", 1, 1, 1); - std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); - - graphToMatch->add(Random0); - graphToMatch->addChild(GEMM1, Random0); - graphToMatch->addChild(ReLU1, GEMM1); - graphToMatch->addChild(GEMM2, ReLU1); - graphToMatch->addChild(GEMM3, GEMM2); - graphToMatch->addChild(ReLU2, GEMM3); - graphToMatch->addChild(Random, ReLU2); - - - //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch); - //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch); - Match matches = GReg.match(graphToMatch); - - size_t nb = matches.getNbMatch(); - std::vector<std::vector<NodeTmp>> gm_startnodes = matches.getStartNodes(); - std::vector<std::set<NodeTmp>> gm_matchnodes = matches.getMatchNodes(); - - std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs; - - for (size_t i = 0; i < nb; ++i) { - matchs.insert(std::make_pair(gm_startnodes[i], gm_matchnodes[i])); - } - - //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ; - std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ; - // Carefull : as the assert is on a vector, the Order of match matters - std::vector<NodeTmp> startNode = {GEMM1}; - std::set<NodeTmp> matchNode = {GEMM1}; - //toMatchs.push_back(std::make_pair(startNode,matchNode)); - toMatchs.insert(std::make_pair(startNode,matchNode)); - - std::vector<NodeTmp> startNode2 = {GEMM2}; - std::set<NodeTmp> matchNode2 = {GEMM2}; - //toMatchs.push_back(std::make_pair(startNode2,matchNode2)); - toMatchs.insert(std::make_pair(startNode2,matchNode2)); - - std::vector<NodeTmp> startNode3 = {GEMM3}; - std::set<NodeTmp> matchNode3 = {GEMM3}; - //toMatchs.push_back(std::make_pair(startNode3,matchNode3)); - toMatchs.insert(std::make_pair(startNode3,matchNode3)); - - REQUIRE(matchs == toMatchs); - REQUIRE(nb == 3); -} - - diff --git a/unit_tests/graphMatching/Test_NodeRegex.cpp b/unit_tests/graphMatching/Test_NodeRegex.cpp deleted file mode 100644 index 2866642bf..000000000 --- a/unit_tests/graphMatching/Test_NodeRegex.cpp +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <iostream> -#include <map> -#include <memory> -#include <cassert> - -#include <catch2/catch_test_macros.hpp> - -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/operator/GenericOperator.hpp" - - -using namespace Aidge; - -TEST_CASE("Create Noderegex", "[Noderegex]") { - std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("conv"); -} - -TEST_CASE("Test _is function", "[Noderegex]") { - // Create Noderegex with only condition on the name of the Node - // Create several operators to pass into Noderegex _is function - // Assert Noderegex._is(operators) are correct - std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("Conv"); - - std::shared_ptr<Node> Conv = GenericOperator("Conv", 1, 1, 1); - std::shared_ptr<Node> FC = GenericOperator("FC", 1, 1, 1); - - REQUIRE(nr->_is(Conv) == true); - REQUIRE(nr->_is(FC) == false); - REQUIRE(nr->isA("Conv") == true); - REQUIRE(nr->isA("FC") == false); - -} \ No newline at end of file diff --git a/unit_tests/graphMatching/Test_SeqStm.cpp b/unit_tests/graphMatching/Test_SeqStm.cpp deleted file mode 100644 index db8662e33..000000000 --- a/unit_tests/graphMatching/Test_SeqStm.cpp +++ /dev/null @@ -1,167 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <iostream> -#include <map> -#include <memory> -#include <vector> -#include <utility> -#include <cassert> - -#include <catch2/catch_test_macros.hpp> -//test -#include "aidge/graphmatching/SeqStm.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -//use -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" - -using namespace Aidge; - -TEST_CASE("Create good init SeqStm", "[SeqStm]") { - //init all iniput for SeqStm - - - int stmIdx = 0; - //matrix that in B->C - std::vector<std::vector<int>> transitionMatrix { - { -1, 1, -1 }, - { -1, -1, 2 }, - { -1, -1, -1 } }; - - //std::cout << transitionMatrix.size() << "\n"; - // init the nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - // - - std::map<NodeTypeKey,int> typeToIdxTransition; - std::vector<NodeTypeKey> nodeTypeCommonTag {{"A",""},{"B",""},{"C",""}}; - //init nodeTypeCommonTag - int idx = 0; - for (const NodeTypeKey& key : nodeTypeCommonTag) { - typeToIdxTransition[key] = idx; - idx += 1; - } - - int actSt = 0; - std::set<NodeTmp> allNodeValidated; - std::set<NodeTmp> allNodeTested; - std::set<std::pair<NodeTmp,std::string>> allCommonNode; - bool stmIsValid =false; - - - SeqStm stm( - stmIdx, - transitionMatrix, - nodesRegex, - typeToIdxTransition, - actSt, - allNodeValidated, - allNodeTested, - allCommonNode, - stmIsValid); - - REQUIRE(stm.getStmIdx() == 0); - REQUIRE(stm.isValid() == false); - REQUIRE(stm.getAllCommonNode().size() == 0); - REQUIRE(stm.getAllNodeTested().size() == 0); - REQUIRE(stm.getAllNodeValidated().size() == 0); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -TEST_CASE("Test testNode function", "[SeqStm]") { - - int stmIdx = 0; - std::map<NodeTypeKey,int> typeToIdxTransition; - std::vector<NodeTypeKey> nodeTypeCommonTag {{"A",""},{"B",""},{"C",""}}; - //init nodeTypeCommonTag - int idx = 0; - for (const NodeTypeKey& key : nodeTypeCommonTag) { - typeToIdxTransition[key] = idx; - idx += 1; - } - //matrix that in B->C - std::vector<std::vector<int>> transitionMatrix { - { -1, 1, -1 }, - { -1, -1, 2 }, - { -1, -1, -1 } }; - - //std::cout << transitionMatrix.size() << "\n"; - // init the nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - // - int actSt = 0; - std::set<NodeTmp> allNodeValidated; - std::set<NodeTmp> allNodeTested; - std::set<std::pair<NodeTmp,std::string>> allCommonNode; - bool stmIsValid =false; - - SeqStm stm( - stmIdx, - transitionMatrix, - nodesRegex, - typeToIdxTransition, - actSt, - allNodeValidated, - allNodeTested, - allCommonNode, - stmIsValid); - REQUIRE(stm.getStmIdx() == 0); - //test a node - std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); - std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); - - //set use to test the state of the smt - std::set<NodeTmp> testAllNodeTested; - std::set<NodeTmp> testAllNodeValidated; - - stm.testNode(nodeB); - REQUIRE(stm.isValid() == false); - REQUIRE(stm.getState() == 1); - REQUIRE(stm.isStmBlocked() == false); - testAllNodeTested.insert(nodeB); - testAllNodeValidated.insert(nodeB); - REQUIRE(stm.getAllNodeTested() == testAllNodeTested); - REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); - - - stm.testNode(nodeC); - REQUIRE(stm.isValid() == true); - REQUIRE(stm.getState() == 2); - REQUIRE(stm.isStmBlocked() == false); - testAllNodeTested.insert(nodeC); - testAllNodeValidated.insert(nodeC); - REQUIRE(stm.getAllNodeTested() == testAllNodeTested); - REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); - - stm.testNode(nodeC); - REQUIRE(stm.isValid() == true); - REQUIRE(stm.getState() == -1); - REQUIRE(stm.isStmBlocked() == true); - REQUIRE(stm.getAllNodeTested() == testAllNodeTested); - REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} \ No newline at end of file diff --git a/unit_tests/graphMatching/Test_StmFactory.cpp b/unit_tests/graphMatching/Test_StmFactory.cpp deleted file mode 100644 index 3c66d0fa8..000000000 --- a/unit_tests/graphMatching/Test_StmFactory.cpp +++ /dev/null @@ -1,204 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <iostream> -#include <map> -#include <memory> -#include <vector> -#include <utility> -#include <cassert> - -#include <catch2/catch_test_macros.hpp> -//test -#include "aidge/graphmatching/StmFactory.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -//use -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" - -using namespace Aidge; - -TEST_CASE("Create good init StmFactory", "[StmFactory]") { - // init the nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - StmFactory stmF(nodesRegex); - REQUIRE(stmF.getNumberOfStm() == 0); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -TEST_CASE("Test in makeNewStm the getStmIdx StmFactory", "[SeqStm]") { - - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - StmFactory stmF(nodesRegex); - - std::string seq1 = "A->B+->A#;"; - SeqStm* stm = stmF.makeNewStm(seq1); - REQUIRE(stm->getStmIdx() == 0); - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getAllCommonNode().size() == 0); - REQUIRE(stm->getAllNodeTested().size() == 0); - REQUIRE(stm->getAllNodeValidated().size() == 0); - - std::string seq2 = "A->B;"; - SeqStm* stm2 = stmF.makeNewStm(seq2); - REQUIRE(stm2->getStmIdx() == 1); - REQUIRE(stm2->isValid() == false); - REQUIRE(stm2->getAllCommonNode().size() == 0); - REQUIRE(stm2->getAllNodeTested().size() == 0); - REQUIRE(stm2->getAllNodeValidated().size() == 0); - - //test the number of stm - REQUIRE(stmF.getNumberOfStm() == 2); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -TEST_CASE("Test in makeNewStm the stm StmFactory", "[SeqStm]") { - - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - - StmFactory stmF(nodesRegex); - std::string seq1 = "B->C;"; - SeqStm* stm = stmF.makeNewStm(seq1); - //test the number of stm - REQUIRE(stmF.getNumberOfStm() == 1); - - //std::shared_ptr<Node> nodeB = GenericOperator("B",1,1,1); - //std::shared_ptr<Node> nodeC = GenericiOperator("C",1,1,1); - std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); - std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); - //set use to test the state of the smt - std::set<NodeTmp> testAllNodeTested; - std::set<NodeTmp> testAllNodeValidated; - - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getState() == 0); - REQUIRE(stm->isStmBlocked() == false); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - stm->testNode(nodeB); - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getState() == 1); - REQUIRE(stm->isStmBlocked() == false); - testAllNodeTested.insert(nodeB); - testAllNodeValidated.insert(nodeB); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - - stm->testNode(nodeC); - REQUIRE(stm->isValid() == true); - REQUIRE(stm->getState() == 2); - REQUIRE(stm->isStmBlocked() == false); - testAllNodeTested.insert(nodeC); - testAllNodeValidated.insert(nodeC); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - stm->testNode(nodeC); - REQUIRE(stm->isValid() == true); - REQUIRE(stm->getState() == -1); - REQUIRE(stm->isStmBlocked() == true); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } - -} - -TEST_CASE("Test in duplicateStm StmFactory", "[SeqStm]") { - - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - - StmFactory stmF(nodesRegex); - std::string seq1 = "B->C;"; - SeqStm* stm = stmF.makeNewStm(seq1); - SeqStm* stmD = stmF.duplicateStm(stm); - - std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); - std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); - //set use to test the state of the smt - std::set<NodeTmp> testAllNodeTested; - std::set<NodeTmp> testAllNodeValidated; - - //run the stm - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getState() == 0); - REQUIRE(stm->isStmBlocked() == false); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - stm->testNode(nodeB); - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getState() == 1); - REQUIRE(stm->isStmBlocked() == false); - testAllNodeTested.insert(nodeB); - testAllNodeValidated.insert(nodeB); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - - stm->testNode(nodeC); - REQUIRE(stm->isValid() == true); - REQUIRE(stm->getState() == 2); - REQUIRE(stm->isStmBlocked() == false); - testAllNodeTested.insert(nodeC); - testAllNodeValidated.insert(nodeC); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - stm->testNode(nodeC); - REQUIRE(stm->isValid() == true); - REQUIRE(stm->getState() == -1); - REQUIRE(stm->isStmBlocked() == true); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - //check if stmD not move - REQUIRE(stmD->isValid() == false); - REQUIRE(stmD->getState() == 0); - REQUIRE(stmD->isStmBlocked() == false); - REQUIRE(stmD->getAllNodeTested().size() == 0); - REQUIRE(stmD->getAllNodeValidated().size() == 0); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -- GitLab