diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py index 754907443530f7e73d1e10ed9549d0c8eb78a011..883cbffc0f22c6a3d009f643dadf0aec9eb3f8fc 100644 --- a/aidge_core/unit_tests/test_recipies.py +++ b/aidge_core/unit_tests/test_recipies.py @@ -34,9 +34,9 @@ class test_recipies(unittest.TestCase): def test_fuse_matmul_add(self): matmul0 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul0") - add0 = aidge_core.Add(name="Add0") + add0 = aidge_core.Add(2, name="Add0") matmul1 = aidge_core.GenericOperator("MatMul", 1, 2, 1, name="MatMul1") - add1 = aidge_core.Add(name="Add1") + add1 = aidge_core.Add(2, name="Add1") graph_view = aidge_core.sequential([matmul0, add0, matmul1, add1]) diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index cc8763580076957d550c7c0702468a593e218569..6782392a77159814c9c363e236e21b87ca5480d9 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -14,21 +14,24 @@ #include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/TensorImpl.hpp" + #include "aidge/data/Data.hpp" #include "aidge/data/Tensor.hpp" + #include "aidge/graph/Connector.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" #include "aidge/graph/OpArgs.hpp" -#include "aidge/graphmatching/GRegex.hpp" #include "aidge/graphmatching/Match.hpp" #include "aidge/graphmatching/NodeRegex.hpp" #include "aidge/graphmatching/SeqStm.hpp" #include "aidge/graphmatching/StmFactory.hpp" #include "aidge/graphmatching/Utile.hpp" + #include "aidge/operator/Add.hpp" #include "aidge/operator/AvgPooling.hpp" #include "aidge/operator/BatchNorm.hpp" +#include "aidge/operator/Concat.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/Div.hpp" @@ -45,14 +48,18 @@ #include "aidge/operator/Pow.hpp" #include "aidge/operator/ReLU.hpp" #include "aidge/operator/Scaling.hpp" +#include "aidge/operator/Slice.hpp" #include "aidge/operator/Softmax.hpp" #include "aidge/operator/Sqrt.hpp" #include "aidge/operator/Sub.hpp" + #include "aidge/scheduler/Scheduler.hpp" + +#include "aidge/recipies/Recipies.hpp" + #include "aidge/utils/Attributes.hpp" #include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/DynamicAttributes.hpp" -#include "aidge/utils/Recipies.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" //#include "aidge/utilsParsing/AstNode.hpp" diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 481099726843146173a37fcddc3bf69723b1a70e..4fb14f66f68954095e6caf79ffe5e3ea3f982169 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -162,6 +162,21 @@ public: std::vector<std::vector<std::pair<NodePtr, IOIndex_t>>> outputs( std::string nodeName) const; + /** + * @brief Assert Datatype, Backend, data format and dimensions along the GraphView are coherent. + * If not, apply the required transformations. + * @details Sets the GraphView ready for computation in four steps: + * 1 - Assert input Tensors' datatype is compatible with each Operator's datatype. + * If not, a conversion Operator is inserted. + * 2 - Assert input Tensors' backend is compatible with each Operator's backend. + * If not, add a Transmitter Operator. + * 3 - Assert data format (NCHW, NHWC, ...) of each Operator's input Tensor is + * compatible with the selected kernel. + * If not, add a Transpose Operator. + * 4 - Propagate Tensor dimensions through the consecutive Operators. + */ + void compile(const std::string& backend, const Aidge::DataType datatype); + /** * @brief Compute dimensions of input/output Tensors for each Operator of the * GraphView object's Nodes. @@ -428,4 +443,4 @@ private: }; } // namespace Aidge -#endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */ diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index 384aa64cd30b05dc22907f578627d902131f5053..b81f5288e387c7016f65928b434fbb8cb41cf6e9 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -187,7 +187,7 @@ public: IOIndex_t getNbFreeDataInputs() const; /** - * @brief List input ids of children liked to outputs of the node + * @brief List input ids of children linked to outputs of the node * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>, * IOIndex_t>>> */ diff --git a/include/aidge/graphRegex/GraphFsmInterpreter.hpp b/include/aidge/graphRegex/GraphFsmInterpreter.hpp index 9e92b6fe8fc9d5e44cb8051e687e33d7192e0eb7..e2fd43b9e641e8cb4a695e3a3eecf5975610d564 100644 --- a/include/aidge/graphRegex/GraphFsmInterpreter.hpp +++ b/include/aidge/graphRegex/GraphFsmInterpreter.hpp @@ -19,13 +19,16 @@ namespace Aidge { std::size_t mActGroupe; std::map<std::string,std::shared_ptr<ConditionalInterpreter>> mNodesCondition; + const std::string mGraphMatchExpr; public: - GraphFsmInterpreter(const std::string graphMatchExpr,std::map<std::string,std::shared_ptr<ConditionalInterpreter>> nodesCondition); + GraphFsmInterpreter(const std::string graphMatchExpr,std::vector<std::shared_ptr<ConditionalInterpreter>> & nodesCondition); virtual ~GraphFsmInterpreter() =default; std::shared_ptr<FsmGraph> interpret(void); + + private: diff --git a/include/aidge/graphRegex/GraphLexer.hpp b/include/aidge/graphRegex/GraphLexer.hpp index e4137ab093c466b7349007da91e032dae48eda51..bd65dfc15d18533676b19e148a98185d3844acbd 100644 --- a/include/aidge/graphRegex/GraphLexer.hpp +++ b/include/aidge/graphRegex/GraphLexer.hpp @@ -36,6 +36,9 @@ namespace Aidge { bool isEnd(void); + const std::string getQuery(); + + /** * @brief Get the representation of the class * @return string @@ -46,7 +49,7 @@ namespace Aidge { /** * @brief Constructs an error message to display the character not understood by the lexer - * @return error mesage + * @return error message */ std::runtime_error badTokenError(const std::string& currentChars,std::size_t position); diff --git a/include/aidge/graphRegex/GraphParser.hpp b/include/aidge/graphRegex/GraphParser.hpp index 73406203a8be87e1df75cc694ab1ff281c27fbfa..29ee8c7b294eae2b8d8196de1702cb7e194cfa84 100644 --- a/include/aidge/graphRegex/GraphParser.hpp +++ b/include/aidge/graphRegex/GraphParser.hpp @@ -30,6 +30,13 @@ class GraphParser{ std::shared_ptr<AstNode<gRegexTokenTypes>> parse(void); + /** + * @brief get the query that be use in the parsing + * @return query + */ + const std::string getQuery(); + + private: /** * @brief restart at the start of the ConditionalExpressions for LEXER and restart mCurrentToken diff --git a/include/aidge/graphRegex/GraphRegex.hpp b/include/aidge/graphRegex/GraphRegex.hpp new file mode 100644 index 0000000000000000000000000000000000000000..12a5139a36135979639d2447b869568b943ee840 --- /dev/null +++ b/include/aidge/graphRegex/GraphRegex.hpp @@ -0,0 +1,87 @@ +#ifndef AIDGE_CORE_GRAPH_REGEX_H_ +#define AIDGE_CORE_GRAPH_REGEX_H_ + +#include <string> + +#include "aidge/graphRegex/matchFsm/MatchResult.hpp" +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" +#include "aidge/graphRegex/GraphFsmInterpreter.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" + +namespace Aidge{ + +/** + * @brief class which is the hight level interface for graph matching, used to simplify match definition + * + */ +class GraphRegex{ + + private: + + std::vector<std::string> mQuery; + std::vector<std::shared_ptr<ConditionalInterpreter>> mAllTest; + std::map<std::string, std::function<bool(NodePtr)>> mAllLambda; + + public: + GraphRegex(){}; + virtual ~GraphRegex() = default; + + /** + * @brief add a topology query to the match + * @param query the topology query to find + **/ + void addQuery(const std::string query); + + /** + * @brief get all the types of a graph and set it as type key in the query + * @param Reference graph use to get all the node types + **/ + void setKeyFromGraph(std::shared_ptr<GraphView> ref); + + /** + * @brief set a node test manually + * @param key the ref of this test used in the query + * @param ConditionalExpressions expression to test the node + **/ + void setNodeKey(const std::string key, const std::string conditionalExpressions ); + + /** + * @brief set a specific lambda that can be used in setQueryKey + * @param key ref to the lambda to use in the + * @param f expression to test the node ConditionalExpressions + **/ + void setNodeKey(const std::string key,std::function<bool(NodePtr)> f); + + /*** + * @brief brief match the queries in the graph + * @param Reference the graph were the querys in search + * @return the result + */ + std::set<std::shared_ptr<MatchSolution>> match(std::shared_ptr<GraphView> ref); + + private: + + void _generateCombinationsStart(const std::set<NodePtr>& elements, std::size_t n, std::size_t index, + std::vector<NodePtr>& current, std::set<std::vector<NodePtr>>& combinations); + + + + void _findLargestCompatibleSet( + const std::vector<std::shared_ptr<MatchSolution>>& solutions, + std::set<std::shared_ptr<MatchSolution>>& currentSet, + std::set<std::shared_ptr<MatchSolution>>& largestSet, + size_t currentIndex + ); + + std::set<std::shared_ptr<MatchSolution>> _findLargestCompatibleSet( + const std::vector<std::shared_ptr<MatchSolution>>& solutions + ); + + void _majConditionalInterpreterLambda(); + +}; +} + + +#endif //AIDGE_CORE_GRAPH_REGEX_H_ \ No newline at end of file diff --git a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp index c3eae528808dbdb8023718c961b7c45cbf4afac9..3e63f92337f6394382f6d92ef9f6dd7b5098a454 100644 --- a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp @@ -87,7 +87,7 @@ namespace Aidge{ * @brief set a new source to the edge * @return FsmNode */ - void reSetSouceNode(const std::shared_ptr<FsmNode>& newSource); + void reSetSourceNode(const std::shared_ptr<FsmNode>& newSource); /** * @brief get dest FsmNode * @return FsmNode diff --git a/include/aidge/graphRegex/matchFsm/FsmGraph.hpp b/include/aidge/graphRegex/matchFsm/FsmGraph.hpp index 0a74551367dd492cb0abb820e4c5ce5a601d071e..d718009e87e5360981ff93ff808124581917c089 100644 --- a/include/aidge/graphRegex/matchFsm/FsmGraph.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmGraph.hpp @@ -18,78 +18,89 @@ class FsmGraph { private: /** - * @brief all node origine + * @brief all node Origin */ - std::set<std::size_t> mAllOrigine; + std::set<std::size_t> mAllOrigin; std::set<std::shared_ptr<FsmEdge>> mEdges; + + + const std::string mQuery; + public: - FsmGraph(/* args */); + + FsmGraph(const std::string query); virtual ~FsmGraph() = default; -std::shared_ptr<MatchResult> test(std::vector<NodePtr>& StartNodes); - - - -const std::set<std::shared_ptr<FsmEdge>>& getEdge(void); -/** - * @brief add edge in the graph, as FsmEdge know the source and dest FsmNode these nodes are also add to the graph -*/ -void addEdge(std::shared_ptr<FsmEdge>& edge); - -/** - * @brief get the liste of the starting states - * @details we need to use a vector because the order of the nodes is important for start node initialization \ref test() -*/ -const std::vector<std::shared_ptr<FsmNode>> getStartNodes(void); - -/** - * @brief get the set of the valide states - * @return set of valide state -*/ -const std::set<std::shared_ptr<FsmNode>> getValidNodes(void); - -/** - * @brief get the set of all the node in the graph - * @return set of all nodes -*/ -const std::set<std::shared_ptr<FsmNode>> getNodes(void); - -/** - * @brief set a groupe idx for all the nodes in the graph -*/ -void setGroupe(std::size_t groupeIdx); - -/** - * @brief make the union beteen this graph and an input graph - * @param fsmGraph graph to union -*/ -void unionG(const std::shared_ptr<FsmGraph> fsmGraph); - - -/** - * @brief make the union beteen this graph and an input graph and merge the valide state to the start state - * @param fsmGraph graph to merge -*/ -void mergeOneStartOneValid(const std::shared_ptr< FsmGraph> fsmGraph); -/** - * @brief get the number of sub FSM - * @return number of sub Fsm -*/ -std::size_t getNbSubFsm(void); - -/** - * @brief increment the origine of all node in the graph - * @param incr the incrémentation value -*/ -void incOrigineAllNodeBy(std::size_t incr); + std::vector<std::shared_ptr<MatchSolution>> test(const std::vector<NodePtr>& StartNodes); -private: -/** - * @brief merge tow node of the graph - * @param node -*/ -void _mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest); + + const std::set<std::shared_ptr<FsmEdge>>& getEdge(void); + /** + * @brief add edge in the graph, as FsmEdge know the source and dest FsmNode these nodes are also add to the graph + */ + void addEdge(std::shared_ptr<FsmEdge>& edge); + + /** + * @brief get the list of the starting states + * @details we need to use a vector because the order of the nodes is important for start node initialization \ref test() + */ + const std::vector<std::shared_ptr<FsmNode>> getStartNodes(void); + + /** + * @brief get the set of the valid states + * @return set of valide state + */ + const std::set<std::shared_ptr<FsmNode>> getValidNodes(void); + + /** + * @brief get the set of all the node in the graph + * @return set of all nodes + */ + const std::set<std::shared_ptr<FsmNode>> getNodes(void); + + /** + * @brief set a groupe idx for all the nodes in the graph + */ + void setGroupe(std::size_t groupeIdx); + + /** + * @brief make the union between this graph and an input graph + * @param fsmGraph graph to union + */ + void unionG(const std::shared_ptr<FsmGraph> fsmGraph); + + + /** + * @brief make the union between this graph and an input graph and merge the valid state to the start state + * @param fsmGraph graph to merge + */ + void mergeOneStartOneValid(const std::shared_ptr< FsmGraph> fsmGraph); + /** + * @brief get the number of sub FSM + * @return number of sub Fsm + */ + std::size_t getNbSubFsm(void); + + /** + * @brief get the number of start state + * @return number of start state + */ + std::size_t getNbStart(void); + + /** + * @brief increment the origin of all nodes in the graph + * @param incr value + */ + void incOriginAllNodeBy(std::size_t incr); + + private: + + /** + * @brief merge tow node of the graph + * @param node + */ + void _mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest); }; diff --git a/include/aidge/graphRegex/matchFsm/FsmNode.hpp b/include/aidge/graphRegex/matchFsm/FsmNode.hpp index 2776ff8eb297fd5ad9a4c425fb386adde0a25269..7987c5ce33522ca7d43de1918d53e68738af6d18 100644 --- a/include/aidge/graphRegex/matchFsm/FsmNode.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmNode.hpp @@ -33,7 +33,7 @@ namespace Aidge{ * @details a state can be and/or : * - a valide state, the match is valide if it stop on this edge * - a start state , the match start on this state - * The state is also define by this origine (is the unique id of it's expretion ) + * The state is also define by this Origin (is the unique id of it's expretion ) * and it's groupe (for inner expression TODO) */ class FsmNode : public std::enable_shared_from_this<FsmNode> @@ -49,8 +49,8 @@ namespace Aidge{ */ std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>> mParents; - std::size_t mOrigineStm = 0; - std::size_t mGroupeStm = 0; + std::size_t mOriginFsm = 0; + std::size_t mGroupeFsm = 0; bool mIsAValid; bool mIsAStart; @@ -59,7 +59,7 @@ namespace Aidge{ FsmNode(bool isAValid,bool isAStart ); virtual ~FsmNode() = default; /** - * @brief use to MAG the actual context , and return all the posible new context + * @brief use to MAG the actual context , and return all the possible new context * @details one input context can generate a multitude of contexts because a graph node * can have more than one child, and each traversal possibility is a new context. * @param actContext the actual context @@ -68,8 +68,8 @@ namespace Aidge{ const std::vector<std::shared_ptr<FsmRunTimeContext>> test( std::shared_ptr<FsmRunTimeContext>); - std::size_t getOrigine(void); - void incOrigine(std::size_t inc); + std::size_t getOrigin(void); + void incOrigin(std::size_t inc); void rmEdge(std::shared_ptr<FsmEdge>); diff --git a/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp index 6f1b9fc2bfe68195f67cfc0bf17d57aed5345219..2f6066ba4cd97284c43b509c9d5eb988b65b53a5 100644 --- a/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp +++ b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp @@ -152,7 +152,7 @@ namespace Aidge{ std::set<NodePtr> getValidNodes(void); std::set<NodePtr> getValidNodesNoCommon(void); - std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> getValid(void); + std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>>& getValid(void); NodePtr getActNode(void); diff --git a/include/aidge/graphRegex/matchFsm/MatchResult.hpp b/include/aidge/graphRegex/matchFsm/MatchResult.hpp index ac2f2a627a9d88b3cabeac4b181af2f3b7566d72..29b9abb616a80899b9c2ad8d5e01e5f00e674757 100644 --- a/include/aidge/graphRegex/matchFsm/MatchResult.hpp +++ b/include/aidge/graphRegex/matchFsm/MatchResult.hpp @@ -11,9 +11,31 @@ namespace Aidge{ +/** + * @brief contained the result of one match and the associate key , the query and the start node +*/ + +class MatchSolution{ +private: + std::map<std::string,std::set<NodePtr>> mSolution; + const std::string mQueryFrom; + const std::vector<NodePtr> mStartNode; + +public: + MatchSolution(std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence,const std::string query,const std::vector<NodePtr> startNode); + const std::set<NodePtr> & at(const std::string key); + const std::set<NodePtr> getAll(); + bool areCompatible(std::shared_ptr<MatchSolution> solution); + + const std::string& getQuery(){ return mQueryFrom ;} + const std::vector<NodePtr>& getStartNode(){ return mStartNode ;} + +}; + + /** * @brief class that old the result of a matching - * give acess to all node ant there tag in the expression + * give access to all node and there tag in the expression */ class MatchResult { @@ -22,17 +44,20 @@ private: std::vector<std::shared_ptr<FsmRunTimeContext>> mAllValid; /* - the Run time of eatch sub FSM , to have a valide match we need a set of one run time per FSM compatible - the id must be contigue + the Run time of each sub FSM , to have a valid match we need a set of one run time per FSM compatible + the id must be continue */ std::vector<std::vector<std::shared_ptr<FsmRunTimeContext>>> mIdToRunTime; - std::vector<std::set<NodePtr>> mSolve; + std::vector<std::shared_ptr<MatchSolution>> mSolve; std::size_t mNbSubStm; + + public: - MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm); + MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm, + const std::string& query,const std::vector<NodePtr>& startNodes); virtual ~MatchResult() = default; @@ -40,16 +65,18 @@ public: * @brief get the set of the node match for une expression * @return the set of node of the graph that corresponding to an expression */ - std::set<NodePtr> getBiggerSolution(void); + std::shared_ptr<MatchSolution> getBiggerSolution(void); + + std::vector<std::shared_ptr<MatchSolution>> getSolutions(void); private: /** - * @brief recurent function use to inite mSolve in the constructor + * @brief recurrent function use to init mSolve in the constructor * **/ -void _generateCombinationd( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence); +void _generateCombination( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence,const std::string& query,const std::vector<NodePtr>& startNodes); }; diff --git a/include/aidge/graphmatching/GRegex.hpp b/include/aidge/graphmatching/GRegex.hpp deleted file mode 100644 index fd2d0c52ab47e0f03b3307bdbcfcb5a7b81d78d9..0000000000000000000000000000000000000000 --- a/include/aidge/graphmatching/GRegex.hpp +++ /dev/null @@ -1,63 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - - -#ifndef AIDGE_GREGEX_H_ -#define AIDGE_GREGEX_H_ - -#include <stdexcept> // for exception, runtime_error, out_of_range -#include <regex> -#include <memory> // for shared_ptr -#include <algorithm> // for next_permutation - -#include "aidge/graphmatching/Utile.hpp" -#include "aidge/graphmatching/StmFactory.hpp" -#include "aidge/graphmatching/SeqStm.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/graphmatching/Match.hpp" - - -namespace Aidge{ - -class GRegex { -// __init__(self,nodes_regex:dict,seq_regexps:list) - - StmFactory mStmFab; - std::vector<SeqStm*> mStmInit; - -public: - GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps ); - - std::set<NodeTmp> matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch); - - bool walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm); - - bool walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm); - - bool walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm); - - std::set<NodeTmp> get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm); - - std::vector<SeqStm*> getStmInit() const { - return mStmInit; - } - - StmFactory getStmFab() const { - return mStmFab; - } - - //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> match(const std::shared_ptr<GraphView> graphToMatch); - Match match(const std::shared_ptr<GraphView> graphToMatch); - -}; - -} -#endif //AIDGE_GREGEX_H_ \ No newline at end of file diff --git a/include/aidge/graphmatching/Match.hpp b/include/aidge/graphmatching/Match.hpp deleted file mode 100644 index fc617a22869fde6531fba67c8641581572cbffc4..0000000000000000000000000000000000000000 --- a/include/aidge/graphmatching/Match.hpp +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_MATCH_H_ -#define AIDGE_MATCH_H_ - -#include <vector> -#include <set> -#include <iostream> -#include <cassert> -#include "aidge/graphmatching/Utile.hpp" - - -namespace Aidge{ - -class Match { - -public: - Match(); - - size_t getNbMatch(); - - void insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes); - - std::vector<std::vector<NodeTmp>> getStartNodes(); - - std::vector<std::set<NodeTmp>> getMatchNodes(); - -protected: - std::vector<std::vector<NodeTmp>> mStartNodes; - std::vector<std::set<NodeTmp>> mMatchNodes; - -}; - -} -#endif //AIDGE_MATCH_H_ \ No newline at end of file diff --git a/include/aidge/graphmatching/NodeRegex.hpp b/include/aidge/graphmatching/NodeRegex.hpp deleted file mode 100644 index 10ba7225834e4abfb7f0f5cd45ffa91b22f2f87d..0000000000000000000000000000000000000000 --- a/include/aidge/graphmatching/NodeRegex.hpp +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_NODEREGEX_H_ -#define AIDGE_NODEREGEX_H_ -#include <cstdlib> -#include <iostream> -#include <cstring> -#include "aidge/graph/Node.hpp" - - -namespace Aidge { - -class NodeRegex -{ - public: - std::string mCondition; - - NodeRegex(const std::string c){ - mCondition = c; - }; - - // Version 1 - Only test the type of the node (no need for a lexer) - // Input : Node_op - // Output : bool - // return mCondition == Node_op.type - bool _is(std::shared_ptr<Node> &Node_op); - bool isA(std::string NodeType); -}; - -} - -#endif /* _AIDGE_NODEREGEX_H__ */ \ No newline at end of file diff --git a/include/aidge/graphmatching/SeqStm.hpp b/include/aidge/graphmatching/SeqStm.hpp deleted file mode 100755 index 0823b5fc0f292d8cf28f7ead53d01bd8dd8adbfe..0000000000000000000000000000000000000000 --- a/include/aidge/graphmatching/SeqStm.hpp +++ /dev/null @@ -1,127 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_SEQSTM_H_ -#define AIDGE_SEQSTM_H_ - -#include <iostream> -#include <map> -#include <regex> -#include <set> -#include <stdexcept> // for exception, runtime_error, out_of_range -#include <string> -#include <utility> -#include <vector> - - -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/graphmatching/Utile.hpp" - - -namespace Aidge { - -class SeqStm { - -private: - const int mStmIdx; - const std::vector<std::vector<int>> mTransitionMatrix; - // str key of type like 'A' that ce use in the A->B .. extpr - const std::map<std::string, NodeRegex *> mNodesRegex; - // mTypeToIdxTransition.first = std::pair node_type , common_tag - // mTypeToIdxTransition.segond = idx in trans matrix - const std::map<NodeTypeKey, int> mTypeToIdxTransition; - - int mActSt; - std::set<NodeTmp> mAllNodeValidated; - std::set<NodeTmp> mAllNodeTested; - std::set<std::pair<NodeTmp, std::string>> mAllCommonNode; - bool mStmIsValid; - - std::pair<NodeRegex *, std::string> getNodeRegexAndCommonAt(int idxType); - - /** - * @brief test the stm on a type - * @return the common tag - */ - std::string transitionOnNodeType(NodeType nodeType); - -public: - SeqStm(const int mStmIdx, - const std::vector<std::vector<int>> &mTransitionMatrix, - const std::map<std::string, NodeRegex *> &mNodesRegex, - const std::map<NodeTypeKey, int> &mTypeToIdxTransition, int mActSt, - std::set<NodeTmp> mAllNodeValidated, std::set<NodeTmp> mAllNodeTested, - std::set<std::pair<NodeTmp, std::string>> mAllCommonNode, - bool mStmIsValid); - - ////////////////////////////////////// - // STM test - ///////////////////////////////////// - - /** - * @brief get if a st is a valide one - * @return bool - */ - bool isAValidSt(int st) { - std::size_t size = mTransitionMatrix.size(); - return st == static_cast<int>(size - 1) ? true : false; - } - - /** - * @brief true if the stm is blocked into st - * @return bool - */ - bool isStmBlocked() { return mActSt == -1 ? true : false; } - - /** - * @brief true if the stm into valide st - * @return bool - */ - bool isValid() { return mStmIsValid; } - - ///////////////////////////////////// - // utile - ///////////////////////////////////// - /** - * @brief extract from a node is type - * @return bool - */ - NodeType getTheNodeType(NodeTmp node); - - void drawStm(); - ///////////////////////////////////// - // geter - ///////////////////////////////////// - - std::set<std::pair<NodeTmp, std::string>> getAllCommonNode() { - return mAllCommonNode; - } - std::set<NodeTmp> getAllNodeTested() { return mAllNodeTested; } - - std::set<NodeTmp> getAllNodeValidated() { return mAllNodeValidated; } - - SeqStm *duplicateStm(); - - int getStmIdx() { return mStmIdx; } - - int getState() { return mActSt; } - ////////////////////////////////////////// - // USE - ////////////////////////////////////////// - /** - * @brief test the stm on a node - * @return pair new stm state, the common tag - */ - std::pair<int, std::string> testNode(const NodeTmp node); -}; -} // namespace Aidge - -#endif /* AIDGE_SEQSTM_H_ */ \ No newline at end of file diff --git a/include/aidge/graphmatching/StmFactory.hpp b/include/aidge/graphmatching/StmFactory.hpp deleted file mode 100644 index b5850e4a00691ef6c808554a86a6ceec8c38ad19..0000000000000000000000000000000000000000 --- a/include/aidge/graphmatching/StmFactory.hpp +++ /dev/null @@ -1,55 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_STMFACTORY_H_ -#define AIDGE_STMFACTORY_H_ - -#include <map> -#include <utility> -#include <set> -#include <string> -#include <vector> -#include <iostream> -#include <stdexcept> // for exception, runtime_error, out_of_range -#include <regex> - -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/graphmatching/SeqStm.hpp" -#include "aidge/graphmatching/Utile.hpp" - -namespace Aidge{ - - - -class StmFactory { - - const std::map<std::string,NodeRegex*>& mNodesRegex; - std::size_t mCmptStm = 0; -public: - StmFactory(const std::map<std::string,NodeRegex*>& nodesRegex); - //StmFactory(){}; - - SeqStm* makeNewStm(const std::string& sequRegex); - SeqStm* duplicateStm(SeqStm* stm); - - std::size_t getNumberOfStm(){ - return mCmptStm; - } -private: - - ParsingReturn initParsingSequRegex(const std::string& sequRegex); - - std::vector<std::vector<int>> initTransitionMatrix(ParsingReturn& parsing); - -}; -} - -#endif //AIDGE_STMFACTORY_H_ \ No newline at end of file diff --git a/include/aidge/graphmatching/Utile.hpp b/include/aidge/graphmatching/Utile.hpp deleted file mode 100644 index acda78cd181519c86ab0b14d5b01bf91223cec9d..0000000000000000000000000000000000000000 --- a/include/aidge/graphmatching/Utile.hpp +++ /dev/null @@ -1,50 +0,0 @@ - -/** - * @file - * @brief - * @version file 1.0.0 - * @author vl241552 - * @copyright - * Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. - * All rights reserved. - */ - -#ifndef _utile_H_ -#define _utile_H_ - -#include <map> - -#include "aidge/graph/Node.hpp" -#include <map> - -namespace Aidge { - -using NodeTmp = std::shared_ptr<Node>; -using NodeType = std::string; -using CommonTag = std::string; -using NodeTypeKey = std::pair<NodeType, CommonTag>; - -// type def -// struct NodeTypeKey { -// NodeType nodeType; -// std::string commonTag; - -// // for map find -// bool operator<(const NodeTypeKey& other) const { -// if (nodeType != other.nodeType or commonTag != other.commonTag) { -// return false; -// } else { -// return true; -// } -// } - -// }; - -struct ParsingReturn { - std::map<NodeTypeKey, int> typeToIdxTransition; - std::vector<std::pair<NodeTypeKey, std::string>> transition; -}; - -} // namespace Aidge - -#endif //_utile_H_ \ No newline at end of file diff --git a/include/aidge/hook/ExecTime.hpp b/include/aidge/hook/ExecTime.hpp index 212fef58696be702e89c8ad973dcc0dd0fc389ae..0964d9575b7ad345d5e07c9f19c7e56a3b69c813 100644 --- a/include/aidge/hook/ExecTime.hpp +++ b/include/aidge/hook/ExecTime.hpp @@ -18,7 +18,7 @@ #define execTime_H_ #include "aidge/operator/Operator.hpp" -#include "aidge/hook/hook.hpp" +#include "aidge/hook/Hook.hpp" #include <memory> #include <chrono> #include <vector> diff --git a/include/aidge/hook/OutputRange.hpp b/include/aidge/hook/OutputRange.hpp index a2da2a997d594c0ef78fb7c31f33b32c3495c4eb..355f4aaa15a6bcd77d99ec2dad344a45f8f9edc0 100644 --- a/include/aidge/hook/OutputRange.hpp +++ b/include/aidge/hook/OutputRange.hpp @@ -18,7 +18,7 @@ #define AIDGE_CORE_HOOK_OUTPUTRANGE_H_ #include "aidge/operator/Operator.hpp" -#include "aidge/hook/hook.hpp" +#include "aidge/hook/Hook.hpp" #include <memory> #include <chrono> #include <vector> diff --git a/include/aidge/nodeTester/ConditionalInterpreter.hpp b/include/aidge/nodeTester/ConditionalInterpreter.hpp index 165fac1c2ae98bf76b73c039de9fc975e9845cc9..af6a3b920bb9ca389724860d55250d7ef4540677 100644 --- a/include/aidge/nodeTester/ConditionalInterpreter.hpp +++ b/include/aidge/nodeTester/ConditionalInterpreter.hpp @@ -22,7 +22,7 @@ namespace Aidge{ ///////////////////////////// /** * @brief class used to register any lambda function without context, - * it encapsulates the source lambda in a lambda which takes as argument ConditionalData* which are any type. + * it encapsulates the source lambda in a lambda which takes as argument std::shared_ptr<ConditionalData> which are any type. * @see ConditionalData */ class ConditionalRegisterFunction { @@ -31,12 +31,12 @@ class ConditionalRegisterFunction { ////////////////////////// /** - * @brief recast the ConditionalData* to the argument type of the lambda + * @brief recast the std::shared_ptr<ConditionalData> to the argument type of the lambda * @tparam T type of the lambda argument * @see ConditionalData */ template <typename T> - T safeCastInput(ConditionalData* data) { + T safeCastInput( std::shared_ptr<ConditionalData> data) { //cnvertion and type cheking if (data->isTypeEqualTo<T>()){ return data->getValue<T>(); @@ -48,14 +48,14 @@ class ConditionalRegisterFunction { /** - * @brief recaste the output of the lambda to a ConditionalData* + * @brief recaste the output of the lambda to a std::shared_ptr<ConditionalData> * @tparam T type of the lambda return * @see ConditionalData */ template <typename T> - ConditionalData* safeCastOutput(T data) { + std::shared_ptr<ConditionalData> safeCastOutput(T data) { - ConditionalData* out = new ConditionalData; + std::shared_ptr<ConditionalData> out = std::make_shared<ConditionalData>(); out->setValue<T>(data); return out; @@ -111,11 +111,11 @@ class ConditionalRegisterFunction { }; ///////////////////// - //change the function to ConditionalData*(std::vector<ConditionalData*>) + //change the function to std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>) ///////////////////// /** - * @brief Converts a function to a ConditionalData*(std::vector<ConditionalData*>). + * @brief Converts a function to a std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>). * @tparam F The type of the function to convert. * @tparam ParamsIdx The indices of the function parameters. * @param f The function to convert. @@ -124,25 +124,31 @@ class ConditionalRegisterFunction { template <class F, std::size_t... ParamsIdx> auto funcPointer(F f, std::index_sequence<ParamsIdx...>) { //wrapp the lambda in a new one that as ConditionalData as inputs and output - return [this,f](std::vector<ConditionalData*> &args) { - if (args.size() != sizeof...(ParamsIdx)){ + return [this,f](std::vector< std::shared_ptr<ConditionalData>> &args) { + if (args.size() < sizeof...(ParamsIdx)){ std::ostringstream errorMessage; errorMessage << "bad Number of argument: get " << args.size() << " need " << sizeof...(ParamsIdx) << "\n"; throw std::runtime_error(errorMessage.str()); } - //assert(args.size() == sizeof...(ParamsIdx));//the size of the vector valide + //we used std::vector< std::shared_ptr<ConditionalData>> as a fifo + std::size_t offset = args.size()-sizeof...(ParamsIdx); using FuncTraits = function_traits<decltype(f)>; using outType = typename FuncTraits::return_type; - outType result = f(safeCastInput<typename FuncTraits::template argument<ParamsIdx>::type>(args[ParamsIdx])...); + outType result = f(safeCastInput<typename FuncTraits::template argument<ParamsIdx>::type>(args[offset+ParamsIdx])...); + + //suppress what we used + for (size_t i = 0; i < sizeof...(ParamsIdx); ++i) { + args.pop_back(); + } //typename return safeCastOutput<outType>(result); }; } /** - * @brief Converts a function pointer to a ConditionalData*(std::vector<ConditionalData*>). + * @brief Converts a function pointer to a std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>). * @tparam R The return type of the function. * @tparam Params The parameter types of the function. * @param f The function pointer to convert. @@ -154,7 +160,7 @@ class ConditionalRegisterFunction { } /** - * @brief Converts a std::function to a ConditionalData*(std::vector<ConditionalData*>). + * @brief Converts a std::function to a std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>). * @tparam R The return type of the function. * @tparam Params The parameter types of the function. * @param f The function pointer to convert. @@ -196,11 +202,18 @@ class ConditionalRegisterFunction { * @param datas The vector of input data. * @return A pointer to the output ConditionalData object. */ - ConditionalData* run(const std::string key,std::vector<ConditionalData*> & datas); + std::shared_ptr<ConditionalData> run(const std::string key,std::vector< std::shared_ptr<ConditionalData>> & datas); + + bool isLambdaRegister(const std::string &key) { + if(mWlambda.find(key) != mWlambda.end()){ + return true; + } + return false; + } private: /// @brief map of name and the converted function. - std::map<const std::string, std::function<ConditionalData*(std::vector<ConditionalData*> &)>> mWlambda; + std::map<const std::string, std::function< std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>> &)>> mWlambda; }; /////////////////// @@ -227,28 +240,38 @@ class ConditionalInterpreter * @brief the registery for the lambda fuction * @see ConditionalRegisterFunction */ - ConditionalRegisterFunction mLambdaRegiter; + ConditionalRegisterFunction mLambdaRegister; - std::vector<ConditionalData*> mResolution ; + std::vector< std::shared_ptr<ConditionalData>> mResolution ; - void clearRes(){ + // void clearRes(){ - for (std::size_t i = 0; i < mResolution.size(); ++i) { - delete mResolution[i]; - } - mResolution.clear(); - } + // for (std::size_t i = 0; i < mResolution.size(); ++i) { + // delete mResolution[i]; + // } + // mResolution.clear(); + // } public: + + const std::string mKey; + /** * @brief Constructor * @param ConditionalExpressions The expression of the test to be performed on the nodes */ - ConditionalInterpreter(const std::string ConditionalExpressions); + ConditionalInterpreter(const std::string key,const std::string ConditionalExpressions); - ~ConditionalInterpreter(){clearRes();} + ~ConditionalInterpreter(){} + + /** + * @brief get the condition key + * @return the key + */ + + const std::string& getKey(); /** * @brief Test a node depending of the ConditionalExpressions @@ -266,7 +289,7 @@ class ConditionalInterpreter */ void insertLambda(const std::string key,std::function<bool(Aidge::NodePtr)> f); - + bool isLambdaRegister(const std::string &key); ///// private: @@ -276,12 +299,12 @@ class ConditionalInterpreter * @param NodeOp The node currently being tested * @param nodes The AST given by the parsing process */ - std::vector<ConditionalData*> visit(const ASTNodeCh& nodes, const NodePtr NodeOp ); + std::vector< std::shared_ptr<ConditionalData>> visit(const ASTNodeCh& nodes, const NodePtr NodeOp ); /** * @defgroup ASTnodeInterpreterF Functions for interpreting AST nodes * @brief For each node type in the AST, function defines the processing to be performed - * they return a std::vector<ConditionalData*> which corresponds to the value(s) obtained + * they return a std::vector< std::shared_ptr<ConditionalData>> which corresponds to the value(s) obtained */ /** @@ -291,38 +314,38 @@ class ConditionalInterpreter void fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); /** * @ingroup ASTnodeInterpreterF - * @brief Converted the lexeme to a int and to ConditionalData* + * @brief Converted the lexeme to a int and to std::shared_ptr<ConditionalData> */ void fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); /** * @ingroup ASTnodeInterpreterF - * @brief Converted the lexeme to a float and to ConditionalData* + * @brief Converted the lexeme to a float and to std::shared_ptr<ConditionalData> */ void fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); /** * @ingroup ASTnodeInterpreterF - * @brief Converted the lexeme to a str and to ConditionalData* + * @brief Converted the lexeme to a str and to std::shared_ptr<ConditionalData> */ void fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); /** * @ingroup ASTnodeInterpreterF - * @brief makes the == operation between two previously converted ConditionalData* + * @brief makes the == operation between two previously converted std::shared_ptr<ConditionalData> */ void fEq(void); /** * @ingroup ASTnodeInterpreterF - * @brief makes the != operation between two previously converted ConditionalData* + * @brief makes the != operation between two previously converted std::shared_ptr<ConditionalData> */ void fNeq(void); /** * @ingroup ASTnodeInterpreterF - * @brief makes the && operation between two previously converted ConditionalData* in bool + * @brief makes the && operation between two previously converted std::shared_ptr<ConditionalData> in bool */ void fAnd(void); /** * @ingroup ASTnodeInterpreterF - * @brief makes the || operation between two previously converted ConditionalData* in bool + * @brief makes the || operation between two previously converted std::shared_ptr<ConditionalData> in bool */ void fOr(void); diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index 65c7e8ce0e47bd470e2a1499a682ed2f2c8c2dbc..b5e37f9bc52d4a74dabf68022235feadc384748f 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -16,7 +16,7 @@ #include <vector> #include <cmath> #include <memory> -#include <array> +#include <vector> #include "aidge/utils/Registrar.hpp" #include "aidge/operator/Operator.hpp" @@ -26,24 +26,21 @@ namespace Aidge { -template <std::size_t NUM> class Add_Op : public Operator, - public Registrable<Add_Op<NUM>, std::string, std::unique_ptr<OperatorImpl>(const Add_Op<NUM>&)> { -public: + public Registrable<Add_Op, std::string, std::unique_ptr<OperatorImpl>(const Add_Op&)> { +private: // FIXME: change accessibility - std::array<std::shared_ptr<Tensor>, NUM> mInputs; + std::vector<std::shared_ptr<Tensor>> mInputs; const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); public: static constexpr const char* Type = "Add"; - constexpr Add_Op() - : Operator(Type) + Add_Op(const IOIndex_t nbIn) + : Operator(Type), + mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>())) { - assert(NUM > 0 && "Add should have at least one input"); - for (std::size_t i = 0; i<NUM; ++i) { - mInputs[i] = std::make_shared<Tensor>(); - } + assert(nbIn > 0 && "Add should have at least one input"); setDatatype(DataType::Float32); } @@ -51,17 +48,15 @@ public: * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ - Add_Op(const Add_Op<NUM>& op) + Add_Op(const Add_Op& op) : Operator(Type), + mInputs(std::vector<std::shared_ptr<Tensor>>(op.nbInputs())), mOutput(std::make_shared<Tensor>(*op.mOutput)) { // cpy-ctor - assert(NUM > 0 && "Add should have at least one input"); - for (std::size_t i = 0; i<NUM; ++i) { - mInputs[i] = std::make_shared<Tensor>(); - } + assert(op.nbInputs() > 0 && "Add should have at least one input"); setDatatype(op.mOutput->dataType()); - mImpl = op.mImpl ? Registrar<Add_Op<NUM>>::create(mOutput->getImpl()->backend())(*this) : nullptr; + mImpl = op.mImpl ? Registrar<Add_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; } /** @@ -82,7 +77,7 @@ public: // } void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { - assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator."); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); @@ -92,10 +87,10 @@ public: if (!mInputs[0]->empty()) { const auto expectedDims = mInputs[0]->dims(); std::size_t nonEmptyInputTensor = 1; - for (; nonEmptyInputTensor<NUM && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) { + for (; nonEmptyInputTensor < nbInputs() && (!mInputs[nonEmptyInputTensor]->empty()); ++nonEmptyInputTensor) { assert(expectedDims == mInputs[nonEmptyInputTensor]->dims()); } - if (nonEmptyInputTensor == NUM) { + if (nonEmptyInputTensor == nbInputs()) { mOutput->resize(expectedDims); } } @@ -103,8 +98,8 @@ public: bool outputDimsForwarded() const override final { std::size_t forwarded = 0; - for (; forwarded < NUM && (!mInputs[forwarded]->empty()); ++forwarded) {} - return ((forwarded==NUM) && !(mOutput->empty())); + for (; forwarded < nbInputs() && (!mInputs[forwarded]->empty()); ++forwarded) {} + return ((forwarded==nbInputs()) && !(mOutput->empty())); } // void checkDims() const override final { @@ -114,13 +109,13 @@ public: // } // } inline Tensor& input(const IOIndex_t inputIdx) const override final { - assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator."); return *(mInputs[inputIdx].get()); } inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { - assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator."); return mInputs[inputIdx]; } inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { @@ -130,7 +125,7 @@ public: } std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { - assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Add operator."); return std::static_pointer_cast<Data>(mInputs[inputIdx]); } std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { @@ -141,11 +136,11 @@ public: void setBackend(const std::string& name) override { - mImpl = Registrar<Add_Op<NUM>>::create(name)(*this); + mImpl = Registrar<Add_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround - for (std::size_t i = 0; i < NUM; ++i) { + for (std::size_t i = 0; i < nbInputs(); ++i) { mInputs[i]->setBackend(name); } } @@ -154,15 +149,16 @@ public: mOutput->setDatatype(datatype); // FIXME: temporary workaround - for (std::size_t i = 0; i < NUM; ++i) { + for (std::size_t i = 0; i < nbInputs(); ++i) { mInputs[i]->setDatatype(datatype); } } - inline IOIndex_t nbInputs() const noexcept override final { return NUM; } - inline IOIndex_t nbDataInputs() const noexcept override final { return NUM; } + inline IOIndex_t nbInputs() const noexcept override final { return mInputs.size(); } + inline IOIndex_t nbDataInputs() const noexcept override final { return mInputs.size(); } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } - static const std::vector<std::string> getInputsName(){ + + static const std::vector<std::string> getInputsName(){ return {"data_input_0", "data_input_n"}; } static const std::vector<std::string> getOutputsName(){ @@ -170,9 +166,8 @@ public: } }; -template <std::size_t NUM> -inline std::shared_ptr<Node> Add(const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<Add_Op<NUM>>(), name); +inline std::shared_ptr<Node> Add(const IOIndex_t nbIn, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Add_Op>(nbIn), name); } } diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index 490782331335a5510257e1f55ec48bfe57334ede..994239bc14c477691a39724c2778578023967da0 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -94,6 +94,42 @@ public: } + // std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveField(const std::size_t firstIdx, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const override { + // if (outputIdx != 0) { + // AIDGE_THROW_OR_ABORT(std::runtime_error, "Conv_Op Operator has got only one output Tensor."); + // } + // if ((outputDims.size() == (DIM+2)) && outputDimsForwarded()) { + // // Offset + // const auto outputIdxDims = mOutput->getCoord(firstIdx); + // std::vector<DimSize_t> inputIdxDims = outputIdxDims; + + // for (DimIdx_t i = 0; i < (DIM+2); ++i) { + // if (((outputDims[i] + outputIdxDims[i]) > mOutput->dims<DIM+2>()[i]) || (outputDims[i] == 0)) { + // AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), outputIdxDims[i], outputDims[i]); + // } + // } + + // // padding is not a parameter of Conv_Op. It is handled in Pad_Op Operator + // // Width + // std::vector<DimSize_t> inputDims; + // inputDims.push_back(outputDims[0]); // same batch value + // inputDims.push_back(outputDims[1]); // same channel value + + // for (DimIdx_t i = 0; i < DIM; ++i) { + // inputDims.push_back((outputDims[2+static_cast<std::size_t>(i)] - 1) + // * this->template getAttr<AvgPoolingAttr::StrideDims>()[static_cast<std::size_t>(i)] + // + 1 + // + (this->template getAttr<AvgPoolingAttr::KernelDims>()[static_cast<std::size_t>(i)] - 1)); + // inputIdxDims[2+i] *= this->template getAttr<AvgPoolingAttr::StrideDims>()[static_cast<std::size_t>(i)]; + // } + // std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> res = std::vector<std::pair<std::size_t, std::vector<DimSize_t>>>(); + // res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(mInput->getIdx(inputIdxDims), inputDims)); + // return res; + // } + // AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet."); + // } + + void setBackend(const std::string &name) override { mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp new file mode 100644 index 0000000000000000000000000000000000000000..31b99370d12a16020ea7c6f9d35c08d9f616f10f --- /dev/null +++ b/include/aidge/operator/Concat.hpp @@ -0,0 +1,192 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_CONCAT_H_ +#define AIDGE_CORE_OPERATOR_CONCAT_H_ + +#include <numeric> +#include <vector> +#include <cmath> +#include <memory> +#include <vector> + +#include "aidge/utils/Registrar.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class ConcatAttr { Axis }; + +class Concat_Op : public Operator, + public Registrable<Concat_Op, std::string, std::unique_ptr<OperatorImpl>(const Concat_Op&)>, + public StaticAttributes<ConcatAttr, DimSize_t> { +private: + // FIXME: change accessibility + std::vector<std::shared_ptr<Tensor>> mInputs; + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char* Type = "Concat"; + + using Attributes_ = StaticAttributes<ConcatAttr, DimSize_t>; + template <ConcatAttr e> + using attr = typename Attributes_::template attr<e>; + + Concat_Op(const IOIndex_t nbIn, const DimSize_t axis) + : Operator(Type), + mInputs(std::vector<std::shared_ptr<Tensor>>(nbIn, std::make_shared<Tensor>())), + Attributes_(attr<ConcatAttr::Axis>(axis)) + { + assert(nbIn > 0 && "Concat should have at least one input"); + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Concat_Op(const Concat_Op& op) + : Operator(Type), + Attributes_(op), + mInputs(std::vector<std::shared_ptr<Tensor>>(op.nbInputs(), std::make_shared<Tensor>())), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + assert(op.nbInputs() > 0 && "Concat should have at least one input"); + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Concat_Op>::create(mOutput->getImpl()->backend())(*this) : nullptr; + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Concat_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Concat_Op>(*this); + } + + // Data operator[](const char* inputName) override final { + // std::shared_ptr<Tensor> in = (strcmp(inputName, "data")) ? mInputs[0] : + // (strcmp(inputName, "weight") ? mInputs[1] : + // (strcmp(inputName, "bias") ? mInputs[2] : + // nullptr)); + // assert((in!=nullptr) && "No such parameter"); + // return *in; + // } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator."); + assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); + + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + bool computable = !(mInputs[0]->empty()) && (getAttr<ConcatAttr::Axis>() < mInputs[0]->nbDims()); + for (const auto& input : mInputs) { + computable &= !(input->empty()); + computable &= (input->nbDims() == mInputs[0]->nbDims()); + } + // Every input is non-empty with the same number of dimensions + if (computable) { + auto outputDims = mInputs[0]->dims(); + + for (std::size_t i = 1; i < nbInputs(); ++i) { + outputDims[getAttr<ConcatAttr::Axis>()] += mInputs[i]->dims()[getAttr<ConcatAttr::Axis>()]; + } + mOutput->resize(outputDims); + } + } + + bool outputDimsForwarded() const override final { + return !(mOutput->empty()); + } + + // void checkDims() const override final { + // assert(outputDimsForwarded()); + // for (const auto& in : mInputs) { + // assert(in->dims() == mOutput->dims()); + // } + // } + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator."); + return *(mInputs[inputIdx].get()); + } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator."); + return mInputs[inputIdx]; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "Concat Operators has only 1 outputs"); + (void) outputIdx; // avoid unused warning + return mOutput; + } + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(static_cast<std::size_t>(inputIdx) < nbInputs() && "wrong inputIdx for Concat operator."); + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void) outputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string& name) override { + mImpl = Registrar<Concat_Op>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + for (std::size_t i = 0; i < nbInputs(); ++i) { + mInputs[i]->setBackend(name); + } + } + + void setDatatype(const DataType& datatype) override { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + for (std::size_t i = 0; i < nbInputs(); ++i) { + mInputs[i]->setDatatype(datatype); + } + } + + inline IOIndex_t nbInputs() const noexcept override final { return mInputs.size(); } + inline IOIndex_t nbDataInputs() const noexcept override final { return mInputs.size(); } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + + static const std::vector<std::string> getInputsName(){ + return {"data_input_0", "data_input_n"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +inline std::shared_ptr<Node> Concat(const IOIndex_t nbIn, const DimIdx_t axis = 0, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Concat_Op>(nbIn, axis), name); +} +} + +namespace { + template <> + const char* const EnumStrings<Aidge::ConcatAttr>::data[] = { + "Axis" + }; +} + +#endif /* AIDGE_CORE_OPERATOR_CONCAT_H_ */ diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index 62f2446f33cf1fa2b5e4051a9a9798b234fb03ab..5bce05413f344a7b69f9fcb1b31e989e6fbbe73a 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -14,6 +14,7 @@ #include <array> #include <cmath> +#include <cstddef> #include <numeric> #include <vector> @@ -113,6 +114,57 @@ public: } } + +// std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveField(const std::size_t firstIdx, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const override { + // if (outputIdx != 0) { + // AIDGE_THROW_OR_ABORT(std::runtime_error, "Conv_Op Operator has got only one output Tensor."); + // } + // if ((outputDims.size() == (DIM+2)) && outputDimsForwarded()) { + // // Offset + // const auto outputIdxDims = mOutput->getCoord(firstIdx); + // auto inputIdxDims = outputIdxDims; // batch idx is the same + // inputIdxDims[1] = 0; // each channel is used so start with the first one + + // for (DimIdx_t i = 0; i < (DIM+2); ++i) { + // if (((outputDims[i] + outputIdxDims[i]) > mOutput->dims<DIM+2>()[i]) || (outputDims[i] == 0)) { + // AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), outputIdxDims[i], outputDims[i]); + // } + // } + + // // padding is not a parameter of Conv_Op. It is handled in Pad_Op Operator + // // Input + // // same batch value, every input channel is used + // std::vector<DimSize_t> inputDims{outputDims[0], mInputs[0]->dims()[1]}; + // for (DimIdx_t i = 0; i < DIM; ++i) { + // inputDims.push_back((outputDims[2+static_cast<std::size_t>(i)] - 1) + // * this->template getAttr<ConvAttr::StrideDims>()[static_cast<std::size_t>(i)] + // + 1 + // + (this->template getAttr<ConvAttr::KernelDims>()[static_cast<std::size_t>(i)] - 1) + // * this->template getAttr<ConvAttr::DilationDims>()[static_cast<std::size_t>(i)]); + // inputIdxDims[2+i] *= this->template getAttr<ConvAttr::StrideDims>()[static_cast<std::size_t>(i)]; + // } + + // // Weight + // // same output value, every input channel is used + // std::vector<DimSize_t> weightDims{outputDims[0], mInputs[0]->dims()[1]}; + // weightDims.insert(weightDims.end(), this->template getAttr<ConvAttr::KernelDims>()[0], this->template getAttr<ConvAttr::KernelDims>()[static_cast<std::size_t>(DIM)]); + // std::vector<DimSize_t> weightIdxDims = std::vector<DimSize_t>(DIM+2, 0); + // weightIdxDims[0] = outputIdxDims[1]; + + // // Bias + // const std::vector<DimSize_t> biasDims{outputDims[0]}; + // const std::vector<DimSize_t> biasIdxDims{outputIdxDims[1]}; + + // // Result + // std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> res; + // res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(mInputs[0]->getIdx(inputIdxDims), inputDims)); + // res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(mInputs[1]->getIdx(weightIdxDims), weightDims)); + // res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(mInputs[2]->getIdx(biasIdxDims), biasDims)); + // return res; + // } + // AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet."); + // } + void setBackend(const std::string &name) override { mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index 4caec2032a3c61529d452ae855f00c1da411af10..f58f435ac0a3aebfecb069955c6769d2f4dd353a 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -124,6 +124,41 @@ class ConvDepthWise_Op : public Operator, bool outputDimsForwarded() const override final { return !(mOutput->empty()); } + // std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveField(const std::size_t firstIdx, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const override { + // if (outputIdx != 0) { + // AIDGE_THROW_OR_ABORT(std::runtime_error, "Conv_Op Operator has got only one output Tensor."); + // } + // if ((outputDims.size() == (DIM+2)) && outputDimsForwarded()) { + // // Offset + // const auto outputIdxDims = mOutput->getCoord(firstIdx); + // auto inputIdxDims = outputIdxDims; // batch idx is the same + + // for (DimIdx_t i = 0; i < (DIM+2); ++i) { + // if (((outputDims[i] + outputIdxDims[i]) > mOutput->dims<DIM+2>()[i]) || (outputDims[i] == 0)) { + // AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), outputIdxDims[i], outputDims[i]); + // } + // } + + // // padding is not a parameter of Conv_Op. It is handled in Pad_Op Operator + // // Width + // std::vector<DimSize_t> inputDims; + // inputDims.push_back(outputDims[0]); // same batch value + // inputDims.push_back(outputDims[1]); // same channel value + + // for (DimIdx_t i = 0; i < DIM; ++i) { + // inputDims.push_back((outputDims[2+static_cast<std::size_t>(i)] - 1) + // * this->template getAttr<ConvDepthWiseAttr::StrideDims>()[static_cast<std::size_t>(i)] + // + 1 + // + (this->template getAttr<ConvDepthWiseAttr::KernelDims>()[static_cast<std::size_t>(i)] - 1) + // * this->template getAttr<ConvDepthWiseAttr::DilationDims>()[static_cast<std::size_t>(i)]); + // inputIdxDims[2+i] *= this->template getAttr<ConvDepthWiseAttr::StrideDims>()[static_cast<std::size_t>(i)]; + // } + // std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> res = std::vector<std::pair<std::size_t, std::vector<DimSize_t>>>(); + // res.push_back(std::pair<std::size_t, std::vector<DimSize_t>>(mInputs[0]->getIdx(inputIdxDims), inputDims)); + // return res; + // } + // AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet."); + // } inline Tensor& input(const IOIndex_t inputIdx) const override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index aa10dea195bb231ed701318cef6e2bfd04a3b7ff..e6e61ff7e20562069332c1a1f572c43a3ff2f14d 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -15,6 +15,8 @@ #include <memory> #include <string> #include <vector> +#include <utility> +#include <cstddef> #include "aidge/backend/OperatorImpl.hpp" #include "aidge/data/Data.hpp" @@ -64,7 +66,15 @@ public: public: virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>* data) = 0; - + /** + * @brief For a given output feature area, compute the associated receptive + * field for each data input. + * @param firstIdx First index of the output feature. + * @param outputDims Size of output feature. + * @param outputIdx Index of the output. Default 0. + * @return std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> For each dataInput Tensor of the Operator, the first index and dimensions of the feature area. + */ + // virtual std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveField(const std::size_t firstIdx, const std::vector<DimSize_t>& outputDims, const IOIndex_t outputIdx = 0) const; virtual std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const = 0; virtual std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const = 0; diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index 353666fb3950d034a7dbe8ec1d3ebdb312679f95..40bc397fa71ca6802c7a5a00394af7850fc73a12 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -28,12 +28,12 @@ namespace Aidge { enum class ScalingAttr { - scalingFactor + scalingFactor, quantizedNbBits, isOutputUnsigned }; class Scaling_Op : public Operator, public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>, - public StaticAttributes<ScalingAttr, float> { + public StaticAttributes<ScalingAttr, float, size_t, bool> { public: // FIXME: change accessibility std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); @@ -44,16 +44,18 @@ public: Scaling_Op() = delete; - using Attributes_ = StaticAttributes<ScalingAttr, float>; + using Attributes_ = StaticAttributes<ScalingAttr, float, std::size_t, bool>; template <ScalingAttr e> using attr = typename Attributes_::template attr<e>; - Scaling_Op(float scalingFactor) + Scaling_Op(float scalingFactor, std::size_t nbBits, bool isOutputUnsigned) : Operator(Type), Attributes_( - attr<ScalingAttr::scalingFactor>(scalingFactor)) - { - setDatatype(DataType::Float32); - } + attr<ScalingAttr::scalingFactor>(scalingFactor), + attr<ScalingAttr::quantizedNbBits>(nbBits), + attr<ScalingAttr::isOutputUnsigned>(isOutputUnsigned)) { + + setDatatype(DataType::Float32); + } /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). @@ -154,15 +156,21 @@ public: } }; +/* inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::string& name = "") { return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name); } +*/ +inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, std::size_t quantizedNbBits=8, bool isOutputUnsigned=true, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor,quantizedNbBits, isOutputUnsigned), name); +} + } namespace { template <> const char* const EnumStrings<Aidge::ScalingAttr>::data[] - = {"scalingFactor"}; + = {"scalingFactor", "quantizedNbBits", "isOutputUnsigned"}; } #endif /* __AIDGE_CORE_OPERATOR_RELU_H__ */ diff --git a/include/aidge/operator/Slice.hpp b/include/aidge/operator/Slice.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6b1077c8d383b1fd4c1ebe6f64a4cad8135c594e --- /dev/null +++ b/include/aidge/operator/Slice.hpp @@ -0,0 +1,185 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_SLICE_H_ +#define AIDGE_CORE_OPERATOR_SLICE_H_ + +#include <memory> +#include <vector> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Data.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class SliceAttr { Beginning, SliceDims }; + +template <DimIdx_t DIM> +class Slice_Op + : public Operator, + public Registrable<Slice_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Slice_Op<DIM> &)>, + public StaticAttributes<SliceAttr, std::size_t, std::array<DimSize_t, DIM>> { +public: + // FIXME: change accessibility + std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char *Type = "Slice"; + + Slice_Op() = delete; + + using Attributes_ = StaticAttributes<SliceAttr, std::size_t, std::array<DimSize_t, DIM>>; + template <SliceAttr e> + using attr = typename Attributes_::template attr<e>; + + Slice_Op(std::size_t beginningPos, std::array<DimSize_t, DIM> sliceDims) + : Operator(Type), + Attributes_(attr<SliceAttr::Beginning>(beginningPos), + attr<SliceAttr::SliceDims>(sliceDims)) + { + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its + * input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Slice_Op(const Slice_Op &op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + mImpl = op.mImpl ? Registrar<Slice_Op<DIM>>::create(mOutput->getImpl()->backend())(*this) + : nullptr; + } + +public: + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Slice_Op + */ + std::shared_ptr<Operator> clone() const override { return std::make_shared<Slice_Op>(*this); } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx == 0 && "operator supports only 1 input"); + (void)inputIdx; // avoid unused warning + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + mInput = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInput->empty()) { + // Check input dimensions is compatible with slice dimensions + if (mInput->nbDims() != DIM) { + printf("Error: input and slice dimensions are not the same size.\n"); + exit(-1); + } + std::array<DimSize_t, DIM> outputDims; + + // Check that the sliced Tensor is actually part of the input Tensor + // For a 5*5 tensor ('x') and a 3*3 slice kernel ('o'): + // xxxxx xxxxx + // xxxxx xxxxx + // xxooo --> ok xxxoo --> out of bound + // xxooo xxxoo + // xxooo xxxoo + std::vector<std::size_t> beginningCoords = mInput->getCoord(this->template getAttr<SliceAttr::Beginning>()); + for (std::size_t i = 0; i < DIM; ++i) { + if (beginningCoords[i] + this->template getAttr<SliceAttr::SliceDims>()[i] > mInput->dims()[i]) { + printf("ROI of Slice operator out of bounds"); + exit(-1); + } else { + outputDims[i] = this->template getAttr<SliceAttr::SliceDims>()[i]; + } + } + + mOutput->resize(outputDims); + } + } + + bool outputDimsForwarded() const override final { return !(mOutput->empty()); } + + inline Tensor &input(const IOIndex_t /*inputIdx*/) const override final { + return *(mInput.get()); + } + inline Tensor &output(const IOIndex_t /*outputIdx*/) const override final { + return *(mOutput.get()); + } + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert((inputIdx == 0) && "Slice Operator has only 1 input"); + (void)inputIdx; // avoid unused warning + return mInput; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert((outputIdx == 0) && "Slice Operator has only 1 output"); + (void)outputIdx; // avoid unused warning + return mOutput; + } + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operator supports only 1 input"); + (void)inputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mInput); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void)outputIdx; // avoid unused warning + return mOutput; + } + + void setBackend(const std::string &name) { + mImpl = Registrar<Slice_Op>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInput->setBackend(name); + } + void setDatatype(const DataType &datatype) { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInput->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return 1; } + inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } +}; + +template <std::size_t DIM> +inline std::shared_ptr<Node> Slice(std::size_t beginningPos, std::array<DimSize_t, DIM> sliceDims, + const std::string &name = "") { + // FIXME: properly handle default w&b initialization in every cases + return std::make_shared<Node>(std::make_shared<Slice_Op<DIM>>( beginningPos, sliceDims), name); +} + +template <DimIdx_t DIM> +inline std::shared_ptr<Node> Slice(std::size_t beginningPos, DimSize_t const (&sliceDims)[DIM], const std::string& name = "") { + return Slice(beginningPos, to_array(sliceDims), name); +} +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::SliceAttr>::data[] = { "Beginning", "SliceDims" }; +} + +#endif /* AIDGE_CORE_OPERATOR_RELU_H_ */ diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/recipies/Recipies.hpp similarity index 67% rename from include/aidge/utils/Recipies.hpp rename to include/aidge/recipies/Recipies.hpp index c110c9cf8e2ccc84112f7ac48b438f470ee21465..97544937e312c89636fd43098a54a1c63c50dd38 100644 --- a/include/aidge/utils/Recipies.hpp +++ b/include/aidge/recipies/Recipies.hpp @@ -17,8 +17,10 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graphRegex/matchFsm/MatchResult.hpp" -namespace Aidge{ + +namespace Aidge { // FUSE MATMUL + ADD -> FC @@ -27,7 +29,12 @@ namespace Aidge{ * * @param nodes Strict set of Node to merge. */ -void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); +//void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); + +void fuseMulAdd(std::shared_ptr<MatchSolution> solution); + +void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add); + /** * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. * @@ -43,7 +50,11 @@ void fuseMulAdd(std::shared_ptr<GraphView> graphView); * * @param nodes Strict set of Node to merge. */ -void removeFlatten(std::set<std::shared_ptr<Node>> nodes); +void removeFlatten(std::shared_ptr<Node> flatten); + + +void removeFlatten(std::shared_ptr<MatchSolution> solution); + /** * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * @@ -59,7 +70,12 @@ void removeFlatten(std::shared_ptr<GraphView> graphView); * * @param nodes Strict set of Node to merge. */ -void fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes); +void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm); + + + +void fuseBatchNorm(std::shared_ptr<MatchSolution> solution); + /** * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ @@ -68,6 +84,11 @@ void fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes); */ void fuseBatchNorm(std::shared_ptr<GraphView> graphView); +// std::set<std::shared_ptr<Node>> getHorizontalTiling(const std::shared_ptr<Node>& node, const DimIdx_t axis, const std::size_t nbSlices); +// void horizontalTiling(std::shared_ptr<Node> node, DimIdx_t dim, std::size_t nbSlices); +// std::set<std::shared_ptr<Node>> getHorizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices); +// void horizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices); + } #endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */ diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 1896894ee8690cedaef696394da0829604e36211..faf6c49bdbe28e7214f06a4d116cf23a1739154f 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -64,6 +64,9 @@ public: std::vector<std::shared_ptr<Node>> getStaticScheduling(){ return mStaticSchedule; } + std::shared_ptr<GraphView> getGraphView(){ + return mGraphView; + } private: /** diff --git a/include/aidge/utils/StaticAttributes.hpp b/include/aidge/utils/StaticAttributes.hpp index b67f69ae7afc2c22f3b424812ec994b10974b668..50ed0895e82bb468dee57264534f0ec3a486a815 100644 --- a/include/aidge/utils/StaticAttributes.hpp +++ b/include/aidge/utils/StaticAttributes.hpp @@ -22,8 +22,8 @@ namespace Aidge { /** - * @brief This class is designed to handle static attributes (i.e. known at compile-time) - * with named accessors, with minimal overhead (the name strings are not stored in each object + * @brief This class is designed to handle static attributes (i.e. known at compile-time) + * with named accessors, with minimal overhead (the name strings are not stored in each object * instance and it remains possible to access attribute without overhead at compile-time). */ template <class ATTRS_ENUM, class ...T> @@ -97,6 +97,17 @@ public: AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute \"%s\" not found", name); } + template <typename R> + const R& getAttr(const char* name) const { + for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { + if (strcmp(EnumStrings<ATTRS_ENUM>::data[i], name) == 0) { + return getAttr<R>(i); + } + } + + AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute \"%s\" not found", name); + } + template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> typename std::enable_if<(SIZE > 0), R&>::type getAttr(std::size_t i) { if (i == SIZE-1) { @@ -117,6 +128,26 @@ public: AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute not found"); } + template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> + typename std::enable_if<(SIZE > 0), const R&>::type getAttr(std::size_t i) const { + if (i == SIZE-1) { + if (std::is_same<R, typename std::tuple_element<SIZE-1,std::tuple<T...>>::type>::value) { + return reinterpret_cast<const R&>(std::get<SIZE-1>(mAttrs)); + } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "wrong type for attribute with index %lu", i); + } + } + else { + return getAttr<R, SIZE-1>(i); + } + } + + template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> + [[noreturn]] typename std::enable_if<(SIZE == 0), const R&>::type getAttr(std::size_t /*i*/) const { + AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute not found"); + } + template <std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> constexpr typename std::enable_if<(SIZE > 0), const std::type_info&>::type getAttrType(std::size_t i) const { if (i == SIZE-1) { diff --git a/python_binding/graphRegex/pybind_GraphRegex.cpp b/python_binding/graphRegex/pybind_GraphRegex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..be3cd9e9124ba1306226dcbdc13ee39748cf0606 --- /dev/null +++ b/python_binding/graphRegex/pybind_GraphRegex.cpp @@ -0,0 +1,69 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include "aidge/graphRegex/GraphRegex.hpp" + +namespace py = pybind11; +namespace Aidge { +void init_GraphRegex(py::module& m){ + + + py::class_<GraphRegex, std::shared_ptr<GraphRegex>>(m, "GraphRegex", "GraphRegex class describes a regex to test a graph.") + .def(py::init<>()) + + .def("add_query", &GraphRegex::addQuery, R"mydelimiter( + :rtype: str + )mydelimiter") + + .def("set_key_from_graph", &GraphRegex::setKeyFromGraph, R"mydelimiter( + :param ref: The graph use to define type of Node. + :type ref: :py:class:`aidge_core.GraphView` + )mydelimiter") + +// void setNodeKey(const std::string key, const std::string conditionalExpressions ); +// void setNodeKey(const std::string key,std::function<bool(NodePtr)> f); + + .def("match", &GraphRegex::match, R"mydelimiter( + :param graphToMatch: The graph to perform the matching algorithm on. + :type graphToMatch: :py:class:`aidge_core.GraphView` + )mydelimiter") + + + + .def("set_node_key", + (void (GraphRegex::*)(const std::string, const std::string )) & + GraphRegex::setNodeKey, + py::arg("key"), py::arg("conditionalExpressions"), + R"mydelimiter( + Add a node test + :param key: the key of the node test to use in the query. + :param conditionalExpressions: the test to do . + + )mydelimiter") + + + .def("set_node_key", + (void (GraphRegex::*)(const std::string, std::function<bool(NodePtr)>)) & + GraphRegex::setNodeKey, + py::arg("key"), py::arg("f"), + R"mydelimiter( + Add a node test + :param key: the key of the lambda test to use in the conditional expressions. + :param f: bool lambda (nodePtr) . + + )mydelimiter") + + + + ; +} +} diff --git a/python_binding/graphmatching/pybind_GRegex.cpp b/python_binding/graphmatching/pybind_GRegex.cpp deleted file mode 100644 index 48d0e19ff22c1480636b67b5bde70bf1caa1f1b5..0000000000000000000000000000000000000000 --- a/python_binding/graphmatching/pybind_GRegex.cpp +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <pybind11/pybind11.h> -#include <pybind11/stl.h> -#include "aidge/graph/GraphView.hpp" -#include "aidge/graphmatching/GRegex.hpp" - -namespace py = pybind11; -namespace Aidge { -void init_GRegex(py::module& m){ - py::class_<GRegex, std::shared_ptr<GRegex>>(m, "GRegex", "GRegex class combines a Node Regex and a list of Graph Regex that together describes a graph pattern as a graph regular expression. GRegex find patterns in a given graph that matches the graph regular expression.") - .def(py::init<const std::map<std::string,NodeRegex*>&, std::vector<std::string>&>(), py::arg("nodesRegex"), py::arg("seqRegexps"), R"mydelimiter( - Constructor of GRegex - - :param nodesRegex: Describe the conditions an operator has to fulfill. - :type nodesRegex: Dict[str,:py:class:`aidge_core.NodeRegex`] - :param seqRegexps: Describe the graph topological pattern. List of Graph Regex as strings. - :type seqRegexps: List[str] - - )mydelimiter") - .def("match", &GRegex::match, py::arg("graphToMatch"), R"mydelimiter( - Launch the graph matching algorithm on a given graph. - - :param graphToMatch: The graph to perform the matching algorithm on. - :type graphToMatch: :py:class:`aidge_core.GraphView` - - :returns: Matched graph patterns. - :rtype: :py:class:`aidge_core.Match` - - )mydelimiter") - ; -} -} diff --git a/python_binding/graphmatching/pybind_Match.cpp b/python_binding/graphmatching/pybind_Match.cpp deleted file mode 100644 index a2d2654f40ed50e20e8761be57e2c8bb98ce4e3b..0000000000000000000000000000000000000000 --- a/python_binding/graphmatching/pybind_Match.cpp +++ /dev/null @@ -1,34 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <pybind11/pybind11.h> -#include <pybind11/stl.h> -#include "aidge/graphmatching/Match.hpp" - -namespace py = pybind11; -namespace Aidge { -void init_Match(py::module& m){ - py::class_<Match, std::shared_ptr<Match>>(m, "Match", "Match class stores the matched patterns resulting from a graph matching query. A matched pattern is the combinaison of the graph pattern start nodes and the set of all the nodes in the matched pattern (including the start nodes)") - .def(py::init<>()) - .def("get_nb_match", &Match::getNbMatch, R"mydelimiter( - :returns: The number of graph patterns matched - :rtype: int - )mydelimiter") - .def("get_start_nodes", &Match::getStartNodes, R"mydelimiter( - :returns: All matched graph patterns start nodes - :rtype: List[List[:py:class:`aidge_core.Nodes`]] - )mydelimiter") - .def("get_match_nodes", &Match::getMatchNodes, R"mydelimiter( - :returns: All matched graph patterns sets of matched nodes - :rtype: List[Set[:py:class:`aidge_core.Nodes`]] - )mydelimiter"); -} -} diff --git a/python_binding/graphmatching/pybind_NodeRegex.cpp b/python_binding/graphmatching/pybind_NodeRegex.cpp deleted file mode 100644 index 034987f9ccae200a1b8877ecd8b3e878c84e8fc3..0000000000000000000000000000000000000000 --- a/python_binding/graphmatching/pybind_NodeRegex.cpp +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <pybind11/pybind11.h> -#include "aidge/graphmatching/NodeRegex.hpp" - -namespace py = pybind11; -namespace Aidge { -void init_NodeRegex(py::module& m){ - py::class_<NodeRegex, std::shared_ptr<NodeRegex>>(m, "NodeRegex", "NodeRegex class describes a condition to test on any operator. Current version only supports testing the type of the operator.") - .def(py::init<const std::string>(), py::arg("condition"), R"mydelimiter( - Constructor of NodeRegex - - :param condition: Condition to be fulfilled by an operator. - :type condition: str - - )mydelimiter") - ; -} -} diff --git a/python_binding/operator/pybind_Add.cpp b/python_binding/operator/pybind_Add.cpp index 0b2323c5cfb660415ec3ae009beaa7aa78afca0b..bff795a7326ce9c07c2550aba4396502306d6d8b 100644 --- a/python_binding/operator/pybind_Add.cpp +++ b/python_binding/operator/pybind_Add.cpp @@ -19,15 +19,15 @@ namespace py = pybind11; namespace Aidge { -template <std::size_t NUM> void declare_Add(py::module &m) { - py::class_<Add_Op<NUM>, std::shared_ptr<Add_Op<NUM>>, Operator>(m, "AddOp", py::multiple_inheritance()) - .def("get_inputs_name", &Add_Op<NUM>::getInputsName) - .def("get_outputs_name", &Add_Op<NUM>::getOutputsName); +void declare_Add(py::module &m) { + py::class_<Add_Op, std::shared_ptr<Add_Op>, Operator>(m, "AddOp", py::multiple_inheritance()) + .def("get_inputs_name", &Add_Op::getInputsName) + .def("get_outputs_name", &Add_Op::getOutputsName); - m.def("Add", &Add<NUM>, py::arg("name") = ""); + m.def("Add", &Add, py::arg("nbIn"), py::arg("name") = ""); } void init_Add(py::module &m) { - declare_Add<2>(m); + declare_Add(m); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index 6b535e8cf3293b26aaa64f95ca2f9a394768935f..ef02b8aaef9f4ea3bd97559ad9e94c38c5b1d29e 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -20,13 +20,17 @@ void init_Operator(py::module& m){ py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") .def("output", &Operator::output, py::arg("outputIdx")) .def("input", &Operator::input, py::arg("inputIdx")) + .def("nb_inputs", &Operator::nbInputs) .def("nb_data_inputs", &Operator::nbDataInputs) + .def("nb_outputs", &Operator::nbOutputs) .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDatatype, py::arg("datatype")) .def("set_backend", &Operator::setBackend, py::arg("name")) .def("forward", &Operator::forward) // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) + .def("get_hook", &Operator::getHook) + .def("add_hook", &Operator::addHook) ; } } diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index a482191c78ff56b000e043cd7350ca1c150d1d6e..6cc597b5ee934e4a3b849d45e92e5cb62be1b312 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -45,9 +45,7 @@ void init_GraphView(py::module&); void init_OpArgs(py::module&); void init_Connector(py::module&); -void init_Match(py::module&); -void init_NodeRegex(py::module&); -void init_GRegex(py::module&); +void init_GraphRegex(py::module&); void init_Recipies(py::module&); @@ -87,9 +85,8 @@ void init_Aidge(py::module& m){ init_Sub(m); init_Producer(m); - init_Match(m); - init_NodeRegex(m); - init_GRegex(m); + init_GraphRegex(m); + init_Recipies(m); init_Scheduler(m); init_TensorUtils(m); diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index 93c131ef7417135bfdbc657c5c809339430616ed..820b6e12b11116b874170bd25a6dc75675894257 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -14,7 +14,7 @@ #include <string> -#include "aidge/utils/Recipies.hpp" +#include "aidge/recipies/Recipies.hpp" namespace py = pybind11; @@ -28,12 +28,13 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( - Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + + // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( + // Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. - :param nodes: The MatMul and Add nodes to fuse. - :type nodes: list of :py:class:`aidge_core.Node` - )mydelimiter"); + // :param nodes: The MatMul and Add nodes to fuse. + // :type nodes: list of :py:class:`aidge_core.Node` + // )mydelimiter"); m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( Recipie to remove a flatten operator. @@ -41,18 +42,20 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter( - Recipie to remove a flatten operator. - :param nodes: The flatten operator to remove. - :type nodes: list of :py:class:`aidge_core.Node` - )mydelimiter"); - m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( - Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + // m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter( + // Recipie to remove a flatten operator. - :param nodes: The MatMul and Add nodes to fuse. - :type nodes: list of :py:class:`aidge_core.Node` - )mydelimiter"); + // :param nodes: The flatten operator to remove. + // :type nodes: list of :py:class:`aidge_core.Node` + // )mydelimiter"); + + // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( + // Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + + // :param nodes: The MatMul and Add nodes to fuse. + // :type nodes: list of :py:class:`aidge_core.Node` + // )mydelimiter"); m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter( Recipie to remove a flatten operator. @@ -60,11 +63,12 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter( - Recipie to remove a flatten operator. + + // m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter( + // Recipie to remove a flatten operator. - :param nodes: The flatten operator to remove. - :type nodes: list of :py:class:`aidge_core.Node` - )mydelimiter"); + // :param nodes: The flatten operator to remove. + // :type nodes: list of :py:class:`aidge_core.Node` + // )mydelimiter"); } } // namespace Aidge diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 13b9d63dca3960911f49fb2f76ca3e13b6af3d56..af3e24c20bc9832f31e73091c156820a2e0f6584 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -165,6 +165,19 @@ Aidge::GraphView::inputs(std::string name) const { return mNodeRegistry.at(name)->inputs(); } +void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType datatype) { + // Backend + // TODO: add Backend attribute to Operator + setBackend(backend); + // Data type + // TODO: manage Datatype attribute in OperatorImpl + setDatatype(datatype); + // Data Format + // TODO: check actual parent output data format and the needed one. Add a Transpose Operator if necessary + // Forward dimensions + forwardDims(); +} + void Aidge::GraphView::forwardDims() { // setInputs // Link every tensor to the right pointer @@ -226,7 +239,7 @@ void Aidge::GraphView::setBackend(const std::string &backend) { } } -void Aidge::GraphView::setDatatype(const DataType &datatype) { +void Aidge::GraphView::setDatatype(const Aidge::DataType &datatype) { for (auto node : getNodes()) { node->getOperator()->setDatatype(datatype); } diff --git a/src/graphRegex/GraphFsmInterpreter.cpp b/src/graphRegex/GraphFsmInterpreter.cpp index 2984ab4fb3864244c9e32dbfcda9ef2ae080acf0..03e86487513065af47d91fc5265335bba456e64e 100644 --- a/src/graphRegex/GraphFsmInterpreter.cpp +++ b/src/graphRegex/GraphFsmInterpreter.cpp @@ -3,15 +3,24 @@ using namespace Aidge; -GraphFsmInterpreter::GraphFsmInterpreter(const std::string graphMatchExpr,std::map<std::string,std::shared_ptr<ConditionalInterpreter>> nodesCondition):mParser(graphMatchExpr){ +GraphFsmInterpreter::GraphFsmInterpreter(const std::string graphMatchExpr,std::vector<std::shared_ptr<ConditionalInterpreter>>&nodesCondition):mParser(graphMatchExpr){ mActGroupe = 0; - mNodesCondition = nodesCondition; + + for (const auto &obj : nodesCondition) { + if(mNodesCondition.find(obj->getKey()) ==mNodesCondition.end()){ + mNodesCondition[obj->getKey()] = obj; + }else{ + throw std::logic_error("GraphFsmInterpreter Bad Key" ); + } + } } std::shared_ptr<FsmGraph> GraphFsmInterpreter::interpret(void){ mActGroupe = 0; std::shared_ptr<AstNode<gRegexTokenTypes>> tree = mParser.parse(); - return visit(tree); + std::shared_ptr<FsmGraph> out = visit(tree); + return out; } + std::shared_ptr<FsmGraph> GraphFsmInterpreter::visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree){ std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>> nextAstNodes = AstTree->getChilds(); @@ -44,7 +53,7 @@ std::shared_ptr<FsmGraph> GraphFsmInterpreter::keyF(std::shared_ptr<AstNode<gReg std::shared_ptr<FsmNode> start = std::make_shared<FsmNode>(false,true); std::shared_ptr<FsmNode> valid = std::make_shared<FsmNode>(true,false); - std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(mParser.getQuery()); std::shared_ptr<FsmEdge> edge; @@ -66,7 +75,7 @@ std::shared_ptr<FsmGraph> GraphFsmInterpreter::keyF(std::shared_ptr<AstNode<gReg std::shared_ptr<FsmGraph> GraphFsmInterpreter::sepF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm){ size_t idxLeft = leftFsm->getNbSubFsm(); - rigthFsm->incOrigineAllNodeBy(idxLeft); + rigthFsm->incOriginAllNodeBy(idxLeft); leftFsm->unionG(rigthFsm); //the rigthFsm is no longer usfull return leftFsm; diff --git a/src/graphRegex/GraphLexer.cpp b/src/graphRegex/GraphLexer.cpp index 61214f96a090fef5d28cb0ce1a009644d9570880..f504ad025940c88058ce5949259c464ae2cedfb6 100644 --- a/src/graphRegex/GraphLexer.cpp +++ b/src/graphRegex/GraphLexer.cpp @@ -133,6 +133,11 @@ bool GraphLexer::isEnd(void){ return mPosition >= mRegularExpressions.length(); } + +const std::string GraphLexer::getQuery(){ + return mRegularExpressions; +} + std::runtime_error GraphLexer::badTokenError(const std::string& currentChars,std::size_t position){ std::ostringstream errorMessage; errorMessage << "\nBad syntax " << currentChars << " :\n" << mRegularExpressions << "\n"; diff --git a/src/graphRegex/GraphParser.cpp b/src/graphRegex/GraphParser.cpp index 5aa653c482dae82c2e9fa02bfc36b2ffc821785f..9c3d10114d777cf7755432a5723a3b70b81d37a1 100644 --- a/src/graphRegex/GraphParser.cpp +++ b/src/graphRegex/GraphParser.cpp @@ -9,6 +9,10 @@ mLexer(gRegexExpressions) } +const std::string GraphParser::getQuery(){ + return mLexer.getQuery(); +} + std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::parse(void){ std::shared_ptr<AstNode<gRegexTokenTypes>> astTree = constructAstAllExpr(); diff --git a/src/graphRegex/GraphRegex.cpp b/src/graphRegex/GraphRegex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ef0db8c88f3e753f9b9633b1ffb05bbec6d00424 --- /dev/null +++ b/src/graphRegex/GraphRegex.cpp @@ -0,0 +1,140 @@ +#include "aidge/graphRegex/GraphRegex.hpp" +using namespace Aidge; + + +void GraphRegex::setKeyFromGraph(std::shared_ptr<GraphView> ref){ + + for (const NodePtr& node : ref->getNodes()) { + std::string type = node->type(); + bool isIn = false; + for(const auto &test:mAllTest){ + if(test->getKey() == type){ + isIn = true; + break; + } + } + if(!isIn){ + mAllTest.push_back(std::make_shared<ConditionalInterpreter>(type,"getType($) =='" + type + "'")); + } + // auto it = mAllTest.find(type); + // if (it == mAllTest.end()) { + // mAllTest[type] = std::make_shared<ConditionalInterpreter>(type,"getType($) =='" + type + "'"); + // } + // //if the key exist it's ok, but not make 2 ConditionalInterpreter + } +} + + + +void GraphRegex::addQuery(const std::string query){ + mQuery.push_back(query); +} + + + +// Function to generate all combinations of n elements from a set +void GraphRegex::_generateCombinationsStart(const std::set<NodePtr>& elements, std::size_t n, std::size_t index, std::vector<NodePtr>& current, std::set<std::vector<NodePtr>>& combinations) { + if (n == 0) { + combinations.insert(current); + return; + } + for (auto it = elements.begin(); it != elements.end(); ++it) { + current.push_back(*it); + _generateCombinationsStart(elements, n - 1, index + 1, current, combinations); + current.pop_back(); + } +} + + +void GraphRegex::_findLargestCompatibleSet( + const std::vector<std::shared_ptr<MatchSolution>>& solutions, + std::set<std::shared_ptr<MatchSolution>>& currentSet, + std::set<std::shared_ptr<MatchSolution>>& largestSet, + size_t currentIndex +) { + if (currentIndex >= solutions.size()) { + if (currentSet.size() > largestSet.size()) { + largestSet = currentSet; + } + return; + } + + for (size_t i = currentIndex; i < solutions.size(); ++i) { + if (std::all_of(currentSet.begin(), currentSet.end(), + [&](const std::shared_ptr<MatchSolution>& solution) { + return solution->areCompatible(solutions[i]); + } + )) { + currentSet.insert(solutions[i]); + _findLargestCompatibleSet(solutions, currentSet, largestSet, i + 1); + currentSet.erase(solutions[i]); + } + } +} + +std::set<std::shared_ptr<MatchSolution>> GraphRegex::_findLargestCompatibleSet( + const std::vector<std::shared_ptr<MatchSolution>>& solutions +) { + std::set<std::shared_ptr<MatchSolution>> largestSet; + std::set<std::shared_ptr<MatchSolution>> currentSet; + _findLargestCompatibleSet(solutions, currentSet, largestSet, 0); + return largestSet; +} + + + +std::set<std::shared_ptr<MatchSolution>> GraphRegex::match(std::shared_ptr<GraphView> ref){ + + std::vector<std::shared_ptr<MatchSolution>> solutions = {}; + + for (const std::string& query : mQuery) { + + std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>(query,mAllTest); + std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); + + // generate all the start possibility + std::size_t nb_startSt = fsm->getNbStart(); + std::set<std::vector<NodePtr>> combinations; + std::vector<NodePtr> current; + _generateCombinationsStart(ref->getNodes(), nb_startSt, 0, current, combinations); + + + // all start + for (const auto& combination : combinations) { + std::vector<std::shared_ptr<MatchSolution>> solution = fsm->test(combination); + solutions.insert(solutions.end(), solution.begin(), solution.end()); + } + } + return _findLargestCompatibleSet(solutions); +} + +void GraphRegex::setNodeKey(const std::string key, const std::string conditionalExpressions ){ + mAllTest.push_back(std::make_shared<ConditionalInterpreter>(key,conditionalExpressions)); + _majConditionalInterpreterLambda(); +} + + +void GraphRegex::setNodeKey(const std::string key,std::function<bool(NodePtr)> f){ + //we can applied to all key but it's not efficient + if(mAllLambda.find(key) != mAllLambda.end()){ + throw std::runtime_error(key + " is define"); + } + mAllLambda[key] = f; + _majConditionalInterpreterLambda(); +} + +void GraphRegex::_majConditionalInterpreterLambda(){ + + for (const auto& test : mAllTest) { + for (const auto& pair : mAllLambda) { + const std::string& key = pair.first; + const std::function<bool(NodePtr)>& lambda = pair.second; + + if(!test->isLambdaRegister(key)){ + test->insertLambda(key,lambda); + } + + } + } +} + diff --git a/src/graphRegex/matchFsm/FsmEdge.cpp b/src/graphRegex/matchFsm/FsmEdge.cpp index 593da06abe18576d435ae55718d379aa5b682d60..ab307e023209ab770fc63f0550811279bd42eb46 100644 --- a/src/graphRegex/matchFsm/FsmEdge.cpp +++ b/src/graphRegex/matchFsm/FsmEdge.cpp @@ -24,7 +24,7 @@ void FsmEdge::updateRelative( const std::map<size_t,int>& relativePos ){ std::shared_ptr<FsmNode> FsmEdge::getSourceNode(void){ return mNodeSource; } -void FsmEdge::reSetSouceNode(const std::shared_ptr<FsmNode>& newSource){ +void FsmEdge::reSetSourceNode(const std::shared_ptr<FsmNode>& newSource){ mNodeSource->rmEdge(shared_from_this()); mNodeSource = newSource; mNodeSource->addEdge(shared_from_this()); @@ -258,6 +258,11 @@ const std::string lexeme) std::string commonId = m[2]; size_t commonIdx = commonId.empty() ? 0 : std::stoi(commonId) + 1; std::string commonKey = edgeType + std::to_string(commonIdx); + + if(allTest.find(edgeType) == allTest.end()){ + throw std::invalid_argument("Bad Node Test " + edgeType ); + } + return std::make_shared<FsmEdgeCommon> (source, dest, allTest.at(edgeType), commonKey); } else { throw std::invalid_argument("error lexem COMMON " + lexeme); @@ -267,6 +272,11 @@ const std::string lexeme) std::smatch m; if (std::regex_match(lexeme, m, uniqueRegex)) { std::string edgeType = m[1]; + + if(allTest.find(edgeType) == allTest.end()){ + throw std::invalid_argument("Bad Node Test " + edgeType ); + } + return std::make_shared<FsmEdgeUnique>(source, dest, allTest.at(edgeType)); } else { throw std::invalid_argument("error lexem UNIQUE \"" + std::string(lexeme) +" eee\""); diff --git a/src/graphRegex/matchFsm/FsmGraph.cpp b/src/graphRegex/matchFsm/FsmGraph.cpp index 5a9f00d728cd2cd9f58c2228361f8393de2a3d9d..a56474e042cc44a68938b1d19e19a0c6841cb8cb 100644 --- a/src/graphRegex/matchFsm/FsmGraph.cpp +++ b/src/graphRegex/matchFsm/FsmGraph.cpp @@ -4,12 +4,13 @@ using namespace Aidge; -FsmGraph::FsmGraph(/* args */){ +FsmGraph::FsmGraph(const std::string query):mQuery(query){ } //TODO - std::shared_ptr<MatchResult> FsmGraph::test(std::vector<NodePtr>& startNodes){ + std::vector<std::shared_ptr<MatchSolution>> FsmGraph::test(const std::vector<NodePtr>& startNodes){ + std::vector<std::shared_ptr<Aidge::FsmNode>> startNodesFsm = getStartNodes(); if(startNodes.size() != startNodesFsm.size()){ throw std::runtime_error("bad number of Start nodes"); @@ -60,9 +61,9 @@ FsmGraph::FsmGraph(/* args */){ walks.swap(nextWalks); nextWalks.clear(); } - - - return std::make_shared<MatchResult>(allValidContext,getNbSubFsm()); + + MatchResult allMatch(allValidContext,getNbSubFsm(),mQuery,startNodes); + return allMatch.getSolutions(); } @@ -77,8 +78,8 @@ const std::set<std::shared_ptr<FsmEdge>>& FsmGraph::getEdge(void){ void FsmGraph::addEdge(std::shared_ptr<FsmEdge>& edge){ edge->updateWeak(); mEdges.insert(edge); - mAllOrigine.insert(edge->getDestNode()->getOrigine()); - mAllOrigine.insert(edge->getSourceNode()->getOrigine()); + mAllOrigin.insert(edge->getDestNode()->getOrigin()); + mAllOrigin.insert(edge->getSourceNode()->getOrigin()); } const std::vector<std::shared_ptr<FsmNode>> FsmGraph::getStartNodes(void){ @@ -151,19 +152,23 @@ void FsmGraph::mergeOneStartOneValid(const std::shared_ptr<FsmGraph> fsmGraph){ } std::size_t FsmGraph::getNbSubFsm(void){ - return mAllOrigine.size(); + return mAllOrigin.size(); +} + +std::size_t FsmGraph::getNbStart(void){ + return getStartNodes().size(); } -void FsmGraph::incOrigineAllNodeBy(std::size_t incr){ +void FsmGraph::incOriginAllNodeBy(std::size_t incr){ std::set<std::shared_ptr<FsmNode>> nodes = getNodes(); for(auto node :nodes){ - node->incOrigine(incr); + node->incOrigin(incr); } std::set<std::size_t> updatedOrigin; - for(auto origin : mAllOrigine){ + for(auto origin : mAllOrigin){ updatedOrigin.insert(origin + incr); } - mAllOrigine.swap(updatedOrigin); + mAllOrigin.swap(updatedOrigin); } void FsmGraph::_mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest){ @@ -187,7 +192,7 @@ void FsmGraph::_mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNod if(edge->getDestNode() == source ){ edge->reSetDestNode(dest); }else if(edge->getSourceNode() == source ){ - edge->reSetSouceNode(dest); + edge->reSetSourceNode(dest); } } diff --git a/src/graphRegex/matchFsm/FsmNode.cpp b/src/graphRegex/matchFsm/FsmNode.cpp index 84b4a0c3fdbe0730a12a2a62db9158e2538d646f..7bc4cf105b43a540bd0e9c686af35dd220611a09 100644 --- a/src/graphRegex/matchFsm/FsmNode.cpp +++ b/src/graphRegex/matchFsm/FsmNode.cpp @@ -53,11 +53,11 @@ const std::vector<std::shared_ptr<FsmRunTimeContext>> FsmNode::test( std::shared -std::size_t FsmNode::getOrigine(void){ - return mOrigineStm; +std::size_t FsmNode::getOrigin(void){ + return mOriginFsm; } -void FsmNode::incOrigine(std::size_t inc){ - mOrigineStm += inc; +void FsmNode::incOrigin(std::size_t inc){ + mOriginFsm += inc; } void FsmNode::rmEdge(std::shared_ptr<FsmEdge> edge){ mEdges.erase(edge); @@ -93,7 +93,7 @@ const std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>>& FsmNode::getEdges(v } void FsmNode::setGroupe(std::size_t groupeIdx){ - mGroupeStm = groupeIdx; + mGroupeFsm = groupeIdx; } diff --git a/src/graphRegex/matchFsm/FsmRunTimeContext.cpp b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp index 787cf2322a5b8e7001cdc59325345000dbb61553..ddf6a46cc7c75dc853d71ba98b051b4263a31164 100644 --- a/src/graphRegex/matchFsm/FsmRunTimeContext.cpp +++ b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp @@ -155,7 +155,7 @@ void FsmRunTimeContext::setValid(NodePtr node,std::shared_ptr<ConditionalInterpr } std::size_t FsmRunTimeContext::getSubStmId(void){ - return mActState->getOrigine(); + return mActState->getOrigin(); } NodePtr FsmRunTimeContext::getCommonNodeFromIdx(std::size_t commonIdx){ @@ -207,7 +207,7 @@ std::set<NodePtr> FsmRunTimeContext::getValidNodesNoCommon(void){ return differenceSet; } -std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> FsmRunTimeContext::getValid(void){ +std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>>& FsmRunTimeContext::getValid(void){ return mValidNodes; } diff --git a/src/graphRegex/matchFsm/MatchResult.cpp b/src/graphRegex/matchFsm/MatchResult.cpp index c35f1a7348e365baa8a27854ee6b0a833e342ee7..c871b3d0e22f3fa1f28b7bcea46ee8b9f61a3178 100644 --- a/src/graphRegex/matchFsm/MatchResult.cpp +++ b/src/graphRegex/matchFsm/MatchResult.cpp @@ -2,10 +2,63 @@ using namespace Aidge; + MatchSolution::MatchSolution(std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence,const std::string query,const std::vector<NodePtr> startNode):mQueryFrom(query),mStartNode(startNode){ + //reformat the solution + for (const auto& context : precedence) { + for (const auto& pair : context->getValid()) { + + if(mSolution.find(pair.first->getKey()) == mSolution.end()){ + mSolution[pair.first->getKey()] = pair.second; + }else{ + mSolution[pair.first->getKey()].insert(pair.second.begin(), pair.second.end()); + } + } + } + } + + + const std::set<NodePtr> & MatchSolution::at(const std::string key){ + + return mSolution[key]; -MatchResult::MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm):mIdToRunTime(nbSubStm){ + } + + const std::set<NodePtr> MatchSolution::getAll(){ + + // Create a unique set to store all the elements + std::set<NodePtr> uniqueSet; + + // Iterate through the map and insert elements from each set into the unique set + for (const auto& pair : mSolution) { + const std::set<NodePtr>& nodeSet = pair.second; + + // Insert elements from the current set into the unique set + uniqueSet.insert(nodeSet.begin(), nodeSet.end()); + } + + return uniqueSet; + + } + + bool MatchSolution::areCompatible(std::shared_ptr<MatchSolution> solution){ + std::set<NodePtr> set1 = solution->getAll(); + std::set<NodePtr> set2 = getAll(); + std::set<NodePtr> intersection ; + std::set_intersection(set1.begin(), set1.end(), set2.begin(), set2.end(), std::inserter(intersection, intersection.begin())); + if (intersection.empty()) { + return true; + } + return false; + } + + + +//////////////////////////////// +// +//////////////////////////////// +MatchResult::MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm, +const std::string& query,const std::vector<NodePtr>& startNodes):mIdToRunTime(nbSubStm),mNbSubStm(nbSubStm){ mAllValid = allValid; - mNbSubStm = nbSubStm; //mIdToRunTimm for (const auto& contextPtr : allValid) { @@ -13,25 +66,26 @@ MatchResult::MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allVali } std::vector<std::shared_ptr<FsmRunTimeContext>> precedence; - //make all solution posible - _generateCombinationd(0,precedence); + //make all solution possible + _generateCombination(0,precedence,query,startNodes); //sort by solution number of elements - std::sort(mSolve.begin(), mSolve.end(), [](const std::set<NodePtr>& set1, const std::set<NodePtr>& set2) { - return set1.size() < set2.size(); + std::sort(mSolve.begin(), mSolve.end(), [](std::shared_ptr<MatchSolution>& set1, std::shared_ptr<MatchSolution>& set2) { + return set1->getAll().size() < set2->getAll().size(); }); } -void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence){ +void MatchResult::_generateCombination( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence, +const std::string& query,const std::vector<NodePtr>& startNodes){ //it's end , we are below the number of stm if (idxSubStm == mNbSubStm) { - //precedence containe a liste of FSM compatible, we just need to - //check if all the node have been valide by at least one contetext + //precedence contain a list of FSM compatible, we just need to + //check if all the nodes have been validated by at least one context - //1) make the set of all node for the comput graph that are valide in all the FsmRunTimeContext + //1) make the set of all node for the compute graph that are valid in all the FsmRunTimeContext std::set<NodePtr> validNode; std::set<NodePtr> rejectNode; for (const auto& contextPtr : precedence) { @@ -40,11 +94,11 @@ void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std: std::set<NodePtr> tmpR = contextPtr->getRejectedNodes(); rejectNode.insert(tmpR.begin(),tmpR.end()); } - // 2) all RejectedNodes need to be valide by an others stm + // 2) all RejectedNodes need to be valid by an others stm // if it's not the case the match is not valid if(std::includes(validNode.begin(), validNode.end(), rejectNode.begin(), rejectNode.end())){ //we can save the solution - mSolve.push_back(validNode); + mSolve.push_back(std::make_shared<MatchSolution>(precedence,query,startNodes)); } precedence.pop_back(); return; @@ -55,10 +109,10 @@ void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std: { if(idxSubStm == 0){ precedence.push_back(contextPtrOneFsm); - _generateCombinationd(idxSubStm+1,precedence); + _generateCombination(idxSubStm+1,precedence,query,startNodes); }else{ - //test if the new context is compatible whith all the context in the precedence + //test if the new context is compatible with all the context in the precedence // bool compatibleSolutionFsm = true; for (const auto& contextPtrOfOtherFsm : precedence) { @@ -70,7 +124,7 @@ void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std: if(compatibleSolutionFsm){ precedence.push_back(contextPtrOneFsm); - _generateCombinationd(idxSubStm+1,precedence); + _generateCombination(idxSubStm+1,precedence,query,startNodes); } } @@ -83,11 +137,16 @@ void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std: } -std::set<NodePtr> MatchResult::getBiggerSolution(void){ +std::shared_ptr<MatchSolution> MatchResult::getBiggerSolution(void){ + if(mSolve.empty()){ - return std::set<NodePtr>(); + return nullptr; }else{ return mSolve[0]; } +} + +std::vector<std::shared_ptr<MatchSolution>> MatchResult::getSolutions(void){ + return mSolve; } \ No newline at end of file diff --git a/src/graphmatching/GRegex.cpp b/src/graphmatching/GRegex.cpp deleted file mode 100644 index 6b54c5a476e0319c3fab0751c0528a2084ebc0a7..0000000000000000000000000000000000000000 --- a/src/graphmatching/GRegex.cpp +++ /dev/null @@ -1,301 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graph/GraphView.hpp" - -using namespace Aidge; - -GRegex::GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps ):mStmFab(nodesRegex){ - - - //setup all the STM - for (const std::string& sequRegex : seqRegexps) { - mStmInit.push_back(mStmFab.makeNewStm(sequRegex)); - } - -} - -bool GRegex::walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm){ - //test if all stm type are in a valid state - std::vector<int> number_of_valid; - number_of_valid.resize(all_stm.size()); - - for (std::size_t i = 0; i < all_stm.size(); ++i) { - number_of_valid[i] = 0; - for (auto it = all_stm[i].begin(); it != all_stm[i].end(); ++it) { - SeqStm* stm = *it; - if (stm->isValid()){ - number_of_valid[i] +=1; - } - } - } - - for (std::size_t i = 0; i < number_of_valid.size(); ++i) { - if (number_of_valid[i] == 0) { - //std::cout << "NO MATCH at least one stm are not valid" << std::endl; - return false; - } - if (number_of_valid[i] > 1) { - //std::cout << "NO MATCH multiple brach match of stm (// quantification)" << std::endl; - return false; - } - } - return true; -} - -bool GRegex::walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm){ - std::set<NodeTmp> all_stm_node_tested; - std::set<NodeTmp> all_stm_node_validated; - - for (std::size_t i = 0; i < all_stm.size(); ++i) { - //std::cout << "all stm index " << i << " on dimension 1 of size " << all_stm.size() <<std::endl; - for (std::size_t j = 0; j < all_stm[i].size(); ++j) { - //std::cout << "all stm index " << j << " on dimension 2 of size " << all_stm[i].size() <<std::endl; - - std::set<NodeTmp> stm_node_tested = all_stm[i][j]->getAllNodeTested(); - std::set<NodeTmp> stm_node_validated = all_stm[i][j]->getAllNodeValidated(); - - all_stm_node_tested.insert(stm_node_tested.begin(), stm_node_tested.end()); - all_stm_node_validated.insert(stm_node_validated.begin(), stm_node_validated.end()); - } - } - - - std::set<NodeTmp> test_but_not_valid; - for (const auto& x : all_stm_node_tested) { - if (all_stm_node_validated.find(x) == all_stm_node_validated.end()) { - test_but_not_valid.insert(x); - } - } - - - if (!test_but_not_valid.empty()) { - std::cout << "NO MATCH. The node(s) "; - for (const auto& x : test_but_not_valid) { - std::cout << x.get() << ", "; - } - std::cout << " have been tested but not validated." << std::endl; - return false; - } - return true; - -} - -bool GRegex::walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm){ - std::map<NodeTmp, std::pair<std::string,int>> node_to_common_tag; - for (std::size_t i = 0; i < all_stm.size(); ++i) { - for (auto it = all_stm[i].begin(); it != all_stm[i].end(); ++it) { - SeqStm* stm = *it; - - if (!stm->isValid()){ - continue; - } - - for (const auto& pair : stm->getAllCommonNode()) { - const NodeTmp node = pair.first; - const std::string common_tag = pair.second; - - if (node_to_common_tag.find(node) != node_to_common_tag.end()) { - std::string tag = node_to_common_tag[node].first; - int& occurence = node_to_common_tag[node].second; - if (tag!=common_tag){ - std::cout << "NO MATCH. The node " << node << " have two different tags "<< tag << " and " << common_tag << std::endl; - return false; - } else { - occurence += 1; - } - } else { - node_to_common_tag.insert(std::make_pair(node, std::make_pair(common_tag, 1))); - } - } - } - } - /*std::cout << "Node to common tag "; - for (const auto& x : node_to_common_tag) { - std::cout << "(" << x.first << ", " << "[" << x.second.first << ", " << x.second.second << "]" << ") ; "; - } - std::cout << std::endl;*/ - - - for (const auto& pair : node_to_common_tag) { - const std::pair<std::string, int> tag_occurence_pair = pair.second; - if (tag_occurence_pair.second < 1){ - //std::cout << "NO MATCH. The common tag " << tag_occurence_pair.first << " did not match " << std::endl; - return false; - } - } - - return true; -} - -std::set<NodeTmp> GRegex::get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm){ - std::set<NodeTmp> all_stm_node_validated; - - for (std::size_t i = 0; i < all_stm.size(); ++i) { - for (std::size_t j = 0; j < all_stm[i].size(); ++j) { - std::set<NodeTmp> stm_node_validated = all_stm[i][j]->getAllNodeValidated(); - all_stm_node_validated.insert(stm_node_validated.begin(), stm_node_validated.end()); - } - } - return all_stm_node_validated; -} - - -std::set<NodeTmp> GRegex::matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch){ - std::set<NodeTmp> empty_set_return; - //ASSERT - if(startNodes.size() != mStmInit.size()){ - throw std::runtime_error ("bad GRegex start nodes"); - } - - //init the walk - std::vector<std::vector<SeqStm*>> allStm; - std::vector<std::pair<NodeTmp,SeqStm*>> currentWalk; - - for (SeqStm* seqStmPtr : mStmInit) { - SeqStm* newStm = mStmFab.duplicateStm(seqStmPtr); - std::size_t idxStart = newStm->getStmIdx(); - currentWalk.push_back(std::make_pair(startNodes[idxStart],newStm)); - allStm.push_back(std::vector<SeqStm*>()); - } - - //walk - while (currentWalk.size()!=0) - { - std::vector<std::pair<NodeTmp,SeqStm*>> newWalk; - for (const auto& pair : currentWalk) { - const NodeTmp node = pair.first; - SeqStm* stmPtr = pair.second; - - std::pair<int,std::string> test = stmPtr->testNode(node); - int res = test.first; - std::string commonTag = test.second; - - std::set<NodeTmp> next_nodes = graphToMatch->getChildren(node); - - /*std::cout << "Next nodes : " ; - for (const auto& x : next_nodes) { - std::cout << x->name() << ", "; - } - std::cout << std::endl;*/ - - // Test Match - if (commonTag == "" && next_nodes.size() > 1) { - std::cout << "NO MATCH. The node " << node.get() << " is not common and has more than one child" << std::endl; - return empty_set_return; - } - - // If there is no more nodes --> Archive the branch - if (res == -1 || next_nodes.empty()) { - int indexToInsert = stmPtr->getStmIdx(); - allStm[indexToInsert].push_back(stmPtr); - //std::cout << "No more nodes --> STM archived : " << indexToInsert << std::endl; - continue; // TODEV : replace this with 'else' that encapsulate the rest of the function ? - } - - bool first = true; - - // Use an iterator to read through the next_nodes - std::set<NodeTmp>::iterator it; - for (it = next_nodes.begin(); it != next_nodes.end(); ++it) { - // Access the current element using the iterator - std::shared_ptr<Aidge::Node> next_node = *it; - if (first){ - newWalk.push_back(std::make_pair(next_node, stmPtr)); - first = false; - } else { - SeqStm* new_stmPtr = mStmFab.duplicateStm(stmPtr); - newWalk.push_back(std::make_pair(next_node, new_stmPtr)); - } - } - } - currentWalk = newWalk; - } - - //std::cout << "Walk finished" << std::endl; - - if (!walk_validation_all_stm_are_valid(allStm)){ - return empty_set_return; - } - //std::cout << "walk_validation_all_stm_are_valid finished" << std::endl; - - - if (!walk_validation_all_node_read_validate_by_one_stm(allStm)){ - return empty_set_return; - } - //std::cout << "walk_validation_all_node_read_validate_by_one_stm finished" << std::endl; - - - if (!walk_validation_common_nodes_same_tag_for_all_stm(allStm)){ - return empty_set_return; - } - //std::cout << "walk_validation_common_nodes_same_tag_for_all_stm finished" << std::endl; - - //std::cout << "MATCH" << std::endl; - - return get_all_validate_nodes(allStm); - -} - - - -Match GRegex::match(const std::shared_ptr<GraphView> graphToMatch){ - - //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matches; - //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matches; - Match matches; - std::size_t nbStartNodes = mStmInit.size(); - std::set<NodeTmp> allNodes = graphToMatch->getNodes(); - std::size_t nbAllNodes = allNodes.size(); - - std::vector<std::size_t> indices(nbStartNodes, 0); - - while (true) { - // Generate all permutations of the current combination - do { - std::vector<NodeTmp> startNodes; - //std::cout <<"start nodes :"; - for (std::size_t i = 0; i < nbStartNodes; ++i) { - auto it = std::begin(allNodes); - std::advance(it, indices[i]); - //std::cout << (*it).get() << " "; - startNodes.push_back(*it); - } - //std::cout <<"\n"; - - std::set<NodeTmp> match = matchFromStartNodes(startNodes, graphToMatch); - //std::cout << "match size : " << match.size() << " "; - if(match.size() != 0){ - //matches.push_back(std::make_pair(startNodes,match)); - //matches.insert(std::make_pair(startNodes,match)); - matches.insert(startNodes,match); - } - - } while (std::next_permutation(indices.begin(), indices.end())); - - // Generate the next combination with replacement - std::size_t i = nbStartNodes - 1; - while (true) { - if (indices[i] < nbAllNodes - 1) { - ++indices[i]; - break; - } - if (i == 0) { - return matches; - } - --i; - } - std::fill(indices.begin() + i + 1, indices.end(), indices[i]); - } - - return matches; -} \ No newline at end of file diff --git a/src/graphmatching/Match.cpp b/src/graphmatching/Match.cpp deleted file mode 100644 index 6c08b30b11ab220310b476bab2c6d17ed86e4fd1..0000000000000000000000000000000000000000 --- a/src/graphmatching/Match.cpp +++ /dev/null @@ -1,37 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/Match.hpp" - -using namespace Aidge; - -Match::Match(){ - //ctr -} - -size_t Match::getNbMatch(){ - assert(mStartNodes.size() == mMatchNodes.size() && "Match corrupted"); - return mStartNodes.size(); -} - -void Match::insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes){ - assert(mStartNodes.size() == mMatchNodes.size() && "Match corrupted"); - mStartNodes.push_back(startnodes); - mMatchNodes.push_back(matchnodes); -} - -std::vector<std::vector<NodeTmp>> Match::getStartNodes(){ - return mStartNodes; -} - -std::vector<std::set<NodeTmp>> Match::getMatchNodes(){ - return mMatchNodes; -} \ No newline at end of file diff --git a/src/graphmatching/NodeRegex.cpp b/src/graphmatching/NodeRegex.cpp deleted file mode 100644 index 9bf164f60255c17492e528b0f27dec8c53f74979..0000000000000000000000000000000000000000 --- a/src/graphmatching/NodeRegex.cpp +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/NodeRegex.hpp" - - -// Verification done by the Attribute system - - -// Version 1 - Only test the type of the node (no need for a lexer) -// Input : Node_op -// Output : bool -// return mCondition == Node_op.type -bool Aidge::NodeRegex::_is(std::shared_ptr<Node> &Node_op){ - - std::string NodeType = Node_op->type(); - - return strcmp(NodeType.c_str(), mCondition.c_str()) == 0; -} - - -bool Aidge::NodeRegex::isA(std::string NodeType){ - - return strcmp(NodeType.c_str(), mCondition.c_str()) == 0; -} - -// Version 2 - Test the node to an advanced condition -// Input : Node_op -// Output : bool -// return mCondition applied on Node -/**bool NodeRegex::_is(string &Node_op){ - // Parsing the condition is done in the initialization of the NodeRegex - - // assert attributes exist in the node with the attribute function hasAttr() - - // get the attributes - -}*/ diff --git a/src/graphmatching/SeqStm.cpp b/src/graphmatching/SeqStm.cpp deleted file mode 100755 index 84553cb44cb898535943b31b8c955378e73ccbd5..0000000000000000000000000000000000000000 --- a/src/graphmatching/SeqStm.cpp +++ /dev/null @@ -1,247 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/SeqStm.hpp" - -using namespace Aidge; - - - - - /////////////////////////////////////////////////////// - - SeqStm::SeqStm( - const int stmIdx, - const std::vector<std::vector<int>>& transitionMatrix, - const std::map<std::string,NodeRegex*>& nodesRegex, - const std::map<NodeTypeKey,int>& typeToIdxTransition, - int actSt, - std::set<NodeTmp> allNodeValidated, - std::set<NodeTmp> allNodeTested, - std::set<std::pair<NodeTmp,std::string>> allCommonNode, - bool stmIsValid):mStmIdx(stmIdx), - mTransitionMatrix(transitionMatrix), - mNodesRegex(nodesRegex), - mTypeToIdxTransition(typeToIdxTransition) - { - - //assert - if (transitionMatrix.size() == 0){ - throw std::runtime_error ("no transitionMatrix"); - } - if(transitionMatrix[0].size() == 0 || transitionMatrix[0].size() != typeToIdxTransition.size()){ - throw std::runtime_error ("bad transitionMatrix"); - } - int size = static_cast<int>(transitionMatrix.size()); - if (actSt >= size){ - throw std::runtime_error ("bad actSt"); - } - - - mActSt = actSt; - mAllNodeValidated = allNodeValidated; - mAllNodeTested = allNodeTested; - mAllCommonNode = allCommonNode; - mStmIsValid = stmIsValid; - - } - - SeqStm* SeqStm::duplicateStm(){ - - //deep copy of the set - // std::set<Node> cAllNodeValidated(mAllNodeValidated.begin(), mAllNodeValidated.end()); - // std::set<Node> cAllNodeTested(mAllNodeTested.begin(), mAllNodeTested.end()); - - // std::set<std::pair<Node,std::string>> cAllCommonNode; - // for (const auto& p : mAllCommonNode) { - // cAllCommonNode.insert(p); - // } - - auto newStm = new SeqStm( - mStmIdx, - mTransitionMatrix, - mNodesRegex, - mTypeToIdxTransition, - mActSt, - mAllNodeValidated, - mAllNodeTested, - mAllCommonNode, - mStmIsValid - ); - - return newStm; - } - - - std::pair<NodeRegex*,std::string> SeqStm::getNodeRegexAndCommonAt(int idxType) - { - //std::cout << "!" << idxType << "\n"; - for (auto const& x : mTypeToIdxTransition) - { - //x.second is the value : idx in mTransitionMatrix for the type - //x.first pair of the node regex class and a string that is the common tag '',#,#n - if (x.second == idxType ){ - - if (mNodesRegex.find(x.first.first) != mNodesRegex.end()){ - return std::make_pair(mNodesRegex.find(x.first.first)->second, x.first.second); - }else{ - throw std::runtime_error ("a type is not define in NodesRegex"); - } - } - } - throw std::runtime_error ("bad idx in mNodesRegex"); - return std::make_pair(nullptr,nullptr); - } - - - NodeType SeqStm::getTheNodeType(NodeTmp node) - { - //the node is a str of '{type}{idx}' and we juste want type - // // std::regex re("([a-zA-Z]+)[0-9]+"); - // // std::smatch match; - // // if (std::regex_search(node, match, re) == true) { - // // return match.str(1); - // // } - // // throw std::runtime_error ("Type node not found"); - // // return ""; - - //return node->name(); - return node->type(); - } - - - std::string SeqStm::transitionOnNodeType(NodeType nodeType){ - - if (!isStmBlocked()){ - int idxType = 0; - for (auto & nextSt : mTransitionMatrix[mActSt]) { - // There are a next step for this type - //std::cout << "transition matrix next state -> "<< nextSt<<"\n" ; - if (nextSt != -1){ - //std::cout << "next -> "<< nextSt<< " "<< isAValidSt(nextSt) <<"\n" ; - auto nodeRegex = getNodeRegexAndCommonAt(idxType); - //std::cout << "-> "<< nodeRegex.second<<"\n" ; - if (nodeRegex.first->isA(nodeType)){ - //std::cout << "nodetype tested !"<<"\n" ; - if(isAValidSt(nextSt)){ - //std::cout << "Valid state !"<<"\n" ; - mStmIsValid = true; - } - mActSt = nextSt; - return nodeRegex.second; - } - - } - idxType += 1; - } - - mActSt =-1; - } - - return ""; - } - - - std::pair<int,std::string> SeqStm::testNode(const NodeTmp node){ - - std::string commonTag = ""; - //std::cout << "0\n" ; - if (!isStmBlocked()){ - bool isNextStEnd = std::all_of(mTransitionMatrix[mActSt].begin(), mTransitionMatrix[mActSt].end(), [&](int x){ return x == -1; }); - //std::cout << "1:"<< isNextStEnd <<"\n" ; - //if the next state if full of -1 can we relay add the node test to all node tested - // oker y test it but it sure that not be valid - if(!isNextStEnd){ - mAllNodeTested.insert(node); - } - //std::cout << "2\n" ; - //recurtion avoidance - if(mAllNodeValidated.find(node) == mAllNodeValidated.end()){ - - NodeType nodeType = getTheNodeType(node); - //std::cout << "3 " << nodeType << "\n" ; - commonTag = transitionOnNodeType(nodeType); - //after the transition test, if the node is != -1 the node is valid for the stm - //std::cout << " mActSt = " << mActSt << "\n" ; - if( mActSt != -1 ){ - mAllNodeValidated.insert(node); - } - }else{ - mActSt = -1; - } - } - - if(commonTag != ""){ - mAllCommonNode.insert(std::make_pair(node,commonTag)); - } - return std::make_pair(mActSt,commonTag); - } - - -void SeqStm::drawStm(){ - - //mTransitionMatrix - // Find the maximum width of each column - std::vector<std::size_t> max_widths(mTransitionMatrix[0].size(), 0); - for (std::size_t i = 0; i < mTransitionMatrix.size(); ++i) - { - for (std::size_t j = 0; j < mTransitionMatrix[i].size(); ++j) - { - std::size_t width = std::to_string(mTransitionMatrix[i][j]).length(); - if (width > max_widths[j]) - { - max_widths[j] = width; - } - } - } - - // Print the vector with aligned columns - for (std::size_t i = 0; i < mTransitionMatrix.size(); ++i) - { - for (std::size_t j = 0; j < mTransitionMatrix[i].size(); ++j) - { - int i_int = static_cast<int>(i); - if (mActSt == -1 ){ - if(mStmIsValid){ - std::cout << "\033[48;5;40m"; - }else{ - std::cout << "\033[48;5;9m"; - } - } - else if (mActSt == i_int){ - std::cout << "\033[48;5;30m"; - }else{ - std::cout << "\033[48;5;27m"; - } - - // Pad the value with spaces to align it with the maximum width - std::size_t width = std::to_string(mTransitionMatrix[i][j]).length(); - std::string padding(max_widths[j] - width, ' '); - std::cout << padding << mTransitionMatrix[i][j] << " "; - std::cout << "\033[0m"; - } - std::cout << "\n"; - } - - std::cout << "mAllNodeTested : "; - for (const auto& x : mAllNodeTested) { - std::cout << x << ", "; - } - std::cout << "\n"; - - - std::cout << "mAllNodeValidated : "; - for (const auto& x : mAllNodeValidated) { - std::cout << x << ", "; - } - std::cout << "\n"; -} - diff --git a/src/graphmatching/StmFactory.cpp b/src/graphmatching/StmFactory.cpp deleted file mode 100644 index 30b1fad81fc9e7f97dab03f7e6d091a27eeec32b..0000000000000000000000000000000000000000 --- a/src/graphmatching/StmFactory.cpp +++ /dev/null @@ -1,150 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/StmFactory.hpp" - -using namespace Aidge; - -StmFactory::StmFactory(const std::map<std::string, NodeRegex *> &nodesRegex) - : mNodesRegex(nodesRegex) {} - -SeqStm *StmFactory::duplicateStm(SeqStm *stm) { return stm->duplicateStm(); } - -SeqStm *StmFactory::makeNewStm(const std::string &sequRegex) { - - ParsingReturn parsing = initParsingSequRegex(sequRegex); - std::vector<std::vector<int>> transitionMatrix = - initTransitionMatrix(parsing); - - std::set<NodeTmp> allNodeValidated; - std::set<NodeTmp> allNodeTested; - std::set<std::pair<NodeTmp, std::string>> allCommonNode; - - SeqStm *newStm = new SeqStm(static_cast<int>(mCmptStm), transitionMatrix, mNodesRegex, - parsing.typeToIdxTransition, 0, allNodeValidated, - allNodeTested, allCommonNode, false); - mCmptStm += 1; - - return newStm; -} - -ParsingReturn StmFactory::initParsingSequRegex(const std::string &sequRegex) { - - std::string toMatch; - std::regex re("\\s*([A-Za-z]+)(#\\d*)?([+*])?\\s*(->|;)"); - std::smatch matches; - - int idxType = 0; - // return - ParsingReturn parsing; - // std::map<std::pair<NodeType,std::string>,int> typeToIdxTransition; - // std::vector<std::pair<std::pair<NodeType,std::string>,std::string>> - // transition; - // assert - std::map<NodeType, std::string> assertCommonNodeTypes; - - for (std::size_t i = 0; i < sequRegex.length(); i++) { - toMatch += sequRegex[i]; - if (std::regex_match(toMatch, matches, re)) { - - std::string type = matches.str(1); - std::string commonTag = matches.str(2); - std::string quantification = matches.str(3); - - if ((commonTag != "") && (quantification != "")) { - throw std::runtime_error("bad commonTag and quantification"); - } - - // make the typeToIdxTransition - NodeTypeKey typeTag = std::make_pair(type, commonTag); - /*std::cout << " typeTag: " << type << " " << commonTag - << parsing.typeToIdxTransition.size() << std::endl;*/ - if (parsing.typeToIdxTransition.find(typeTag) == - parsing.typeToIdxTransition.end()) { - parsing.typeToIdxTransition[typeTag] = idxType; - idxType += 1; - } - //////////////////////////////////////////////////////////// - // ASSERT - // SAME Common node in the sequ - if (commonTag != "") { - if (assertCommonNodeTypes.find(type) != assertCommonNodeTypes.end()) { - if (assertCommonNodeTypes[type] == commonTag) { - throw std::runtime_error("same common node in the sequ regex"); - } - } else { - assertCommonNodeTypes[type] = commonTag; - } - } - - // save all transition - parsing.transition.push_back(std::make_pair(typeTag, quantification)); - - /*std::cout << "Match found: " << matches.str() << std::endl; - std::cout << "Type: " << matches.str(1) << std::endl; - std::cout << "Common tag: " << matches.str(2) << std::endl; - std::cout << "Quantification: " << matches.str(3) << std::endl;*/ - - toMatch = ""; - } - } - if (parsing.transition.size() == 0) { - throw std::runtime_error("Bad Parsing SequRegex "); - } - - return parsing; -} - -std::vector<std::vector<int>> -StmFactory::initTransitionMatrix(ParsingReturn &parsing) { - - // std::pair<NodeTypeKey,std::string> - std::vector<std::vector<int>> transitionMatrix; - std::size_t numberOfType = parsing.typeToIdxTransition.size(); - - if (numberOfType == 0) { - throw std::runtime_error("Bad number Of Type "); - } - // init start st - transitionMatrix.push_back(std::vector<int>(numberOfType, -1)); - - std::size_t idxTransition = 0; - int idxState = 0; - for (const auto &pair : parsing.transition) { - const NodeTypeKey &nodeTypeKey = pair.first; - const std::string &quant = pair.second; - - /*std::cout << "Key: {" << nodeTypeKey.first << ", " << nodeTypeKey.second - << "}, Value: " << quant << std::endl; - std::cout << "idxState " << idxState << " TM: " << transitionMatrix.size() - << std::endl;*/ - std::size_t idxType = parsing.typeToIdxTransition[nodeTypeKey]; - /*std::cout << "idxType " << idxType << " TM: " << transitionMatrix[0].size() - << "type" << numberOfType << std::endl;*/ - - if (quant == "*") { - transitionMatrix[idxTransition][idxType] = idxState; - } else if (quant == "+") { - idxState += 1; - transitionMatrix[idxTransition][idxType] = idxState; - transitionMatrix.push_back(std::vector<int>(numberOfType, -1)); - idxTransition += 1; - transitionMatrix[idxTransition][idxType] = idxState; - } else { - - idxState += 1; - transitionMatrix[idxTransition][idxType] = idxState; - transitionMatrix.push_back(std::vector<int>(numberOfType, -1)); - idxTransition += 1; - } - } - return transitionMatrix; -} \ No newline at end of file diff --git a/src/nodeTester/ConditionalInterpreter.cpp b/src/nodeTester/ConditionalInterpreter.cpp index e01bdd76a28576451a1a09202d5fd1e87a4856e5..f40e62305334f740057f88ef21cdab749d64bd99 100644 --- a/src/nodeTester/ConditionalInterpreter.cpp +++ b/src/nodeTester/ConditionalInterpreter.cpp @@ -8,7 +8,7 @@ using namespace Aidge; //ConditionalRegisterFunction /////////////////////////////// - ConditionalData* ConditionalRegisterFunction::run(const std::string key,std::vector<ConditionalData*> & datas){ + std::shared_ptr<ConditionalData> ConditionalRegisterFunction::run(const std::string key,std::vector< std::shared_ptr<ConditionalData>> & datas){ auto lambdaIt = mWlambda.find(key); if (lambdaIt != mWlambda.end()) { @@ -18,37 +18,46 @@ using namespace Aidge; } } + ////////////////////// //ConditionalInterpreter /////////////////////// - ConditionalInterpreter::ConditionalInterpreter(const std::string ConditionalExpressions) - :mLambdaRegiter() + ConditionalInterpreter::ConditionalInterpreter(const std::string key,const std::string ConditionalExpressions) + :mLambdaRegister(),mKey(key) { ConditionalParser conditionalParser = ConditionalParser(ConditionalExpressions); mTree = conditionalParser.parse(); + ///lambda by default - mLambdaRegiter.insert("getType",+[](NodePtr NodeOp){return NodeOp->type();}); + mLambdaRegister.insert("getType",+[](NodePtr NodeOp){return NodeOp->type();}); } + + bool ConditionalInterpreter::isLambdaRegister(const std::string &key){ + return mLambdaRegister.isLambdaRegister(key); + } + + const std::string& ConditionalInterpreter::getKey(){ + return mKey; + } bool ConditionalInterpreter::test( const NodePtr nodeOp) { - - clearRes(); + mResolution.clear(); try{ - std::vector<ConditionalData*> r = visit({mTree},nodeOp); - - if (mResolution.size() != 1){ - throw std::runtime_error("Multi-output interpretation output"); - }else{ - if (!mResolution[0]->isTypeEqualTo<bool>()){ - throw std::runtime_error("TEST OUT MUST BE A BOOL "); + std::vector< std::shared_ptr<ConditionalData>> r = visit({mTree},nodeOp); + + if (mResolution.size() != 1){ + throw std::runtime_error("Multi output interpretation output"); }else{ - return mResolution[0]->getValue<bool>(); + if (!mResolution[0]->isTypeEqualTo<bool>()){ + throw std::runtime_error("TEST OUT MUST BE A BOOL "); + }else{ + return mResolution[0]->getValue<bool>(); + } } - } }catch(const std::exception& e){ std::ostringstream errorMessage; @@ -58,12 +67,12 @@ using namespace Aidge; } void ConditionalInterpreter::insertLambda(const std::string key,std::function<bool(Aidge::NodePtr)> f){ - mLambdaRegiter.insert<std::function<bool(Aidge::NodePtr)> >(key, f); + mLambdaRegister.insert<std::function<bool(Aidge::NodePtr)> >(key, f); } ///// - std::vector<ConditionalData*> ConditionalInterpreter::visit(const ASTNodeCh& nodes, const NodePtr nodeOp ){ - std::vector<ConditionalData*> dataVector; + std::vector< std::shared_ptr<ConditionalData>> ConditionalInterpreter::visit(const ASTNodeCh& nodes, const NodePtr nodeOp ){ + std::vector< std::shared_ptr<ConditionalData>> dataVector; for ( std::shared_ptr<AstNode<ConditionalTokenTypes>> node : nodes) { try{ @@ -130,7 +139,7 @@ using namespace Aidge; case ConditionalTokenTypes::NODE: //TODO { - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); data->setValue<NodePtr>(nodeOp); mResolution.push_back(data); @@ -147,7 +156,7 @@ using namespace Aidge; case ConditionalTokenTypes::BOOL: //TODO { - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); if(node->getValue() == "true"){ data->setValue<bool>(true); @@ -169,8 +178,8 @@ using namespace Aidge; } }catch(const std::exception& e){ std::ostringstream errorMessage; - errorMessage << "Error in visiting AST for node"<< nodeOp->name() << "\n\t" << e.what() << "\n"; - throw std::runtime_error(errorMessage.str()); + errorMessage << "Error in visiting AST for node "<< nodeOp->name() << "\n\t" << e.what() << "\n"; + throw std::runtime_error(errorMessage.str()); } } @@ -185,7 +194,8 @@ using namespace Aidge; void ConditionalInterpreter::fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) { - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); + data->setValue<int>(std::stoi(node->getValue())); mResolution.push_back(data); } @@ -193,14 +203,14 @@ using namespace Aidge; void ConditionalInterpreter::fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) { - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); data->setValue<float>(std::stof(node->getValue())); mResolution.push_back(data); } void ConditionalInterpreter::fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) { - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); data->setValue<std::string>(node->getValue()); mResolution.push_back(data); } @@ -208,34 +218,37 @@ using namespace Aidge; void ConditionalInterpreter::fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) { //if the lambda have input - ConditionalData* data; + std::shared_ptr<ConditionalData> data; try { - data = mLambdaRegiter.run(node->getValue(),mResolution); + data = mLambdaRegister.run(node->getValue(),mResolution); } catch (const std::exception& e) { std::ostringstream errorMessage; errorMessage << "Error in conditional interpretation when run the "<< node->getValue() <<" Lambda\n\t" << e.what() << "\n"; throw std::runtime_error(errorMessage.str()); } - clearRes(); + //clearRes(); mResolution.push_back(data); } void ConditionalInterpreter::fEq(void) { - if (mResolution.size() != 2){ + if (mResolution.size() < 2){ throw std::runtime_error("EQ need 2 arg and get :" + std::to_string(mResolution.size())); } - auto a = mResolution[0]; - auto b = mResolution[1]; + auto a = mResolution.back(); + mResolution.pop_back(); + auto b = mResolution.back(); + mResolution.pop_back(); + if (a->getType() != b->getType()){ - throw std::runtime_error("EQ Unsuported between type :" + a->getType() +" "+ b->getType()); + throw std::runtime_error("EQ Unsupported between type :" + a->getType() +" "+ b->getType()); } - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); if (a->isTypeEqualTo<int>()) { data->setValue<bool>( a->getValue<int>() == b->getValue<int>()); @@ -249,23 +262,25 @@ using namespace Aidge; throw std::runtime_error("EQ Unknown type encountered :" + a->getType() ); } - clearRes(); + mResolution.push_back(data); } void ConditionalInterpreter::fNeq(void) { - if (mResolution.size() != 2){ + if (mResolution.size() < 2){ throw std::runtime_error("NEQ need 2 arg and get :" + std::to_string(mResolution.size())); } - auto a = mResolution[0]; - auto b = mResolution[1]; + auto a = mResolution.back(); + mResolution.pop_back(); + auto b = mResolution.back(); + mResolution.pop_back(); if (a->getType() != b->getType()){ - throw std::runtime_error("NEQ Unsuported between type :" + a->getType() +" "+ b->getType()); + throw std::runtime_error("NEQ Unsupported between type :" + a->getType() +" "+ b->getType()); } - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); if (a->isTypeEqualTo<int>()) { data->setValue<bool>( a->getValue<int>() != b->getValue<int>()); @@ -278,67 +293,72 @@ using namespace Aidge; throw std::runtime_error("NEQ Unknown type encountered :" + a->getType() ); } - clearRes(); + mResolution.push_back(data); } void ConditionalInterpreter::fAnd(void) { - if (mResolution.size() != 2){ + if (mResolution.size() < 2){ throw std::runtime_error("AND need 2 arg and get :" + std::to_string(mResolution.size())); } - auto a = mResolution[0]; - auto b = mResolution[1]; + auto a = mResolution.back(); + mResolution.pop_back(); + auto b = mResolution.back(); + mResolution.pop_back(); if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){ throw std::runtime_error("AND Unknown type encountered need bool get :" + a->getType() ); } - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); data->setValue<bool>( a->getValue<bool>() && b->getValue<bool>()); - clearRes(); + mResolution.push_back(data); } void ConditionalInterpreter::fOr(void) { - if (mResolution.size() != 2){ + if (mResolution.size() < 2){ throw std::runtime_error("OR need 2 arg and get :" + std::to_string(mResolution.size())); } - auto a = mResolution[0]; - auto b = mResolution[1]; + auto a = mResolution.back(); + mResolution.pop_back(); + auto b = mResolution.back(); + mResolution.pop_back(); if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){ throw std::runtime_error("OR Unknown type encountered need bool get :" + a->getType() ); } - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); data->setValue<bool>( a->getValue<bool>() || b->getValue<bool>()); - clearRes(); + mResolution.push_back(data); } void ConditionalInterpreter::fNot() { - if (mResolution.size() != 1){ + if (mResolution.size() < 1){ throw std::runtime_error("NOT need 1 arg and get :" + std::to_string(mResolution.size())); } - auto a = mResolution[0]; + auto a = mResolution.back(); + mResolution.pop_back(); if (a->getType() != typeid(bool).name()){ throw std::runtime_error("NOT Unknown type encountered need bool get :" + a->getType() ); } - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); data->setValue<bool>( !a->getValue<bool>() ); - clearRes(); + mResolution.push_back(data); } diff --git a/src/operator/Operator.cpp b/src/operator/Operator.cpp index a8f2fe4675a6b664e14ced447a3ab4f7a9183ff9..05a1d8ab9fe9d35f792fff045ea18362b367bd9b 100644 --- a/src/operator/Operator.cpp +++ b/src/operator/Operator.cpp @@ -10,10 +10,14 @@ ********************************************************************************/ #include <cassert> +#include <cstddef> +#include <vector> +#include <utility> #include "aidge/backend/OperatorImpl.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" // constexpr Aidge::Operator::Operator(const char* type) // : mType(type) @@ -27,6 +31,29 @@ Aidge::Operator::~Operator() = default; // IMPLEMENTATION /////////////////////////////////////////////////////// +// std::vector<std::pair<std::size_t, std::vector<Aidge::DimSize_t>>> Aidge::Operator::computeReceptiveField( +// const std::size_t firstIdx, const std::vector<Aidge::DimSize_t>& outputDims, const Aidge::IOIndex_t outputIdx) const +// { +// static_cast<void>(outputIdx); +// if (outputIdx >= nbOutputs()) { +// AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator output index out of range."); +// } +// if (nbInputs() != nbDataInputs()) { +// AIDGE_THROW_OR_ABORT(std::runtime_error, "Operator has attributes. Must be handled in an overrided function."); +// } +// if (!outputDimsForwarded() || getOutput(0)->nbDims() != outputDims.size()) { +// AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range or output dim not forwarded yet."); +// } +// const auto outputIdxDims = getOutput(0)->getCoord(firstIdx); +// for (DimIdx_t i = 0; i < outputDims.size(); ++i) { +// if (((outputDims[i] + outputIdxDims[i]) > getOutput(0)->dims()[i]) || (outputDims[i] == 0)) { +// AIDGE_THROW_OR_ABORT(std::runtime_error, "Given outputDim out of range for dimension %lu (%lu + %lu)", static_cast<std::size_t>(i), outputIdxDims[i], outputDims[i]); +// } +// } +// // return the same Tensor description as given in function parameter for each data input +// return std::vector<std::pair<std::size_t, std::vector<Aidge::DimSize_t>>>(nbDataInputs(),std::pair<std::size_t, std::vector<Aidge::DimSize_t>>(firstIdx, outputDims)); +// } + Aidge::NbElts_t Aidge::Operator::getNbRequiredData(const Aidge::IOIndex_t inputIdx) const { return mImpl->getNbRequiredData(inputIdx); } diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp index 4b2f7a811c022ee80eec98548049853d56951edb..4b02692c2338384b39ad17c7eac55bd308fb59ce 100644 --- a/src/recipies/FuseBatchNorm.cpp +++ b/src/recipies/FuseBatchNorm.cpp @@ -16,33 +16,22 @@ #include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/Conv.hpp" -#include "aidge/utils/Recipies.hpp" +#include "aidge/recipies/Recipies.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/GenericOperator.hpp" -// Graph Regex -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" + + +//Graph Regex +#include "aidge/graphRegex/GraphRegex.hpp" + using namespace Aidge; -void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){ +void Aidge::fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm){ + - assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); - // Assert the nodes types are correct to be fused - std::shared_ptr<Node> conv; - std::shared_ptr<Node> batchnorm; - for (const auto& element : nodes) { - assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace"); - if (element->type() == "Conv"){ - conv = element; - } - else if (element->type() == "BatchNorm") { - batchnorm = element; - } - } - // TODO : check if batchnorm is the only child of the Conv or FC std::shared_ptr<Tensor> scale = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second); std::shared_ptr<Tensor> shift = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second); std::shared_ptr<Tensor> b_mean = batchnorm->input(3).first->getOperator()->getOutput(batchnorm->input(3).second); @@ -127,19 +116,32 @@ void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){ } +void Aidge::fuseBatchNorm(std::shared_ptr<MatchSolution> solution){ + + assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n"); + assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n"); + + for (const auto& op : solution->at("OP")) { + for (const auto& batchNorm : solution->at("BatchNorm")) { + fuseBatchNorm(op,batchNorm); + } + } + +} + void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){ - std::map<std::string,NodeRegex*> nodesRegex ; - nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm"); - nodesRegex["Conv"] = new NodeRegex("Conv"); - nodesRegex["FC"] = new NodeRegex("FC"); - - - std::vector<std::string> seqRegex; - seqRegex.push_back("Conv -> BatchNorm;"); // TODO: Add (Conv | FC) - GRegex GReg(nodesRegex, seqRegex); - Match matches = GReg.match(graphView); - std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); - for (size_t i = 0; i < matches.getNbMatch(); ++i) { - fuseBatchNorm(matchNodes[i]); + + + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); + regex->setNodeKey("BatchNorm","getType($) =='BatchNorm'"); + regex->setNodeKey("OP","getType($) =='Conv'");// || getType($) =='FC' "); + + regex->addQuery("OP -> BatchNorm"); + + for (const auto& solution : regex->match(graphView)) { + + fuseBatchNorm(solution); + } + } diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index 528d57e31a5ecf3f5a633a20205e79f7926a1f61..a268b7fefa96b7fa8877c69341a34f659a0077d0 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -15,37 +15,24 @@ #include <string> #include "aidge/operator/FC.hpp" -#include "aidge/utils/Recipies.hpp" +#include "aidge/recipies/Recipies.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/utils/ErrorHandling.hpp" -// Graph Regex -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" +//Graph Regex +#include "aidge/graphRegex/GraphRegex.hpp" + using namespace Aidge; -void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ +void Aidge::fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add){//std::set<std::shared_ptr<Node>> nodes){ // Fuse Mulmat & Add into FC // Inputs : old nodes (pointers on mul & add) - assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); - // Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ? + assert((matmul->type() == "MatMul" && add->type() == "Add") && "Wrong type for the nodes to replace"); - // Step 0 : Assert the nodes types are correct to be fused - std::shared_ptr<Node> add; - std::shared_ptr<Node> matmul; - for (const auto& element : nodes) { - assert((element->type() == "MatMul" || element->type() == "Add") && "Wrong type for the nodes to replace"); - if (element->type() == "MatMul"){ - matmul = element; - } - else if (element->type() == "Add") { - add = element; - } - } // Step 1 : Create FC // Fetch the output dimension throught the bias size @@ -78,17 +65,35 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ } + +void Aidge::fuseMulAdd(std::shared_ptr<MatchSolution> solution){ + + assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n"); + assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n"); + + for (const auto& matmul : solution->at("MatMul")) { + for (const auto& add : solution->at("Add")) { + fuseMulAdd(matmul,add); + } + } +} + + void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){ - std::map<std::string,NodeRegex*> nodesRegex ; - nodesRegex["MatMul"] = new NodeRegex("MatMul"); - nodesRegex["Add"] = new NodeRegex("Add"); - std::vector<std::string> seqRegex; - seqRegex.push_back("MatMul -> Add;"); - GRegex GReg(nodesRegex, seqRegex); - Match matches = GReg.match(graphView); - std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); - for (size_t i = 0; i < matches.getNbMatch(); ++i) { - fuseMulAdd(matchNodes[i]); + + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); + regex->setNodeKey("Add","getType($) =='Add'"); + regex->setNodeKey("MatMul","getType($) =='MatMul'"); + regex->addQuery("MatMul -> Add ;"); + + for (const auto& solution : regex->match(graphView)) { + + fuseMulAdd(solution); + + + } + + } diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index fdfdbfd4aea7543dde31d5f5d4845e54e930feac..d571b53023b7665c25aedc869628045b3b13d509 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -13,38 +13,43 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" -#include "aidge/utils/Recipies.hpp" +#include "aidge/recipies/Recipies.hpp" -// Graph Regex -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" +//Graph Regex +#include "aidge/graphRegex/GraphRegex.hpp" -namespace Aidge { - void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { - assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); - std::shared_ptr<Node> flatten; - for (const auto& element : nodes) { - assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace"); - if (element->type() == "Flatten"){ - flatten = element; - } - } +namespace Aidge { + void removeFlatten(std::shared_ptr<Node> flatten) { + GraphView::replace({flatten}, {}); } + void removeFlatten(std::shared_ptr<MatchSolution> solution){ + + assert(solution->at("FC").size() == 1 && "Wrong number of nodes FC to replace\n"); + assert(solution->at("Flatten").size() == 1 && "Wrong number of nodes Flatten to replace\n"); + + for (const auto& flatten : solution->at("Flatten")) { + removeFlatten(flatten); + } + } + + + void removeFlatten(std::shared_ptr<GraphView> graphView){ - std::map<std::string,NodeRegex*> nodesRegex ; - nodesRegex["Flatten"] = new NodeRegex("Flatten"); - nodesRegex["FC"] = new NodeRegex("FC"); - std::vector<std::string> seqRegex; - seqRegex.push_back("Flatten->FC;"); - GRegex GReg(nodesRegex, seqRegex); - Match matches = GReg.match(graphView); - std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); - for (size_t i = 0; i < matches.getNbMatch(); ++i) { - removeFlatten(matchNodes[i]); + + + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); + regex->setNodeKey("Flatten","getType($) =='Flatten'"); + regex->setNodeKey("FC","getType($) =='FC'"); + regex->addQuery("Flatten->FC"); + + for (const auto& solution : regex->match(graphView)) { + removeFlatten(solution); } + + } } diff --git a/unit_tests/CMakeLists.txt b/unit_tests/CMakeLists.txt index 9d9f81516b0cd2611484ee9e3e06e838833200db..5ccfa3832a8ce2522f18ab07e11a78cf8b462a40 100644 --- a/unit_tests/CMakeLists.txt +++ b/unit_tests/CMakeLists.txt @@ -10,6 +10,8 @@ FetchContent_MakeAvailable(Catch2) file(GLOB_RECURSE src_files "*.cpp") +#file(GLOB_RECURSE src_files "graphRegex/Test_GraphRegex.cpp") + add_executable(tests${module_name} ${src_files}) target_link_libraries(tests${module_name} PUBLIC ${module_name}) diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index e7f0d60950c71e1eeaa0bda01ddce92bd5db8b70..3b4fb167bd559b8228f98c91f6f4ee069aee10e8 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -245,7 +245,7 @@ TEST_CASE("[core/graph] GraphView(resetConnections)") { } } -TEST_CASE("Graph Forward dims", "[GraphView]") { +TEST_CASE("[core/graph] GraphView(forwardDims)", "[GraphView][forwardDims]") { auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); auto conv1 = Conv(3, 32, {3, 3}, "conv1"); auto conv2 = Conv(32, 64, {3, 3}, "conv2"); diff --git a/unit_tests/graphMatching/Test_GRegex.cpp b/unit_tests/graphMatching/Test_GRegex.cpp deleted file mode 100644 index 2c5907d82e7c5b1d32f1fb38493c7333b68f8731..0000000000000000000000000000000000000000 --- a/unit_tests/graphMatching/Test_GRegex.cpp +++ /dev/null @@ -1,318 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <iostream> -#include <map> -#include <memory> -#include <vector> -#include <utility> -#include <cassert> - -#include <catch2/catch_test_macros.hpp> -//test -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graphmatching/StmFactory.hpp" -#include "aidge/graphmatching/SeqStm.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/graphmatching/Match.hpp" -//use -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" -#include "aidge/graph/GraphView.hpp" - -using namespace Aidge; - -TEST_CASE("Create good init GRegex", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("A->B;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - // Perform tests - REQUIRE(GReg.getStmInit().size() == 1); - REQUIRE(GReg.getStmFab().getNumberOfStm() == 1); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - - -TEST_CASE("Function matchFromStartNodes | One Match of Nodes sequence", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"Conv","BN","ReLU"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("Conv->BN->ReLU;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1); - std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1); - std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); - std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); - std::shared_ptr<Node> Random2 = GenericOperator("Random2", 1, 1, 1); - - - g1->add(Conv1); - g1->addChild(BN1, Conv1); - g1->addChild(ReLU1, BN1); - g1->addChild(Random, ReLU1); - //g1->addChild(BN1, Random2); - - std::vector<std::shared_ptr<Node>> startNodes1; - std::set<std::shared_ptr<Node>> result; - - startNodes1.push_back(Conv1); - result = GReg.matchFromStartNodes(startNodes1, g1); - - std::set<std::shared_ptr<Node>> true_result; - true_result.insert(Conv1); - true_result.insert(BN1); - true_result.insert(ReLU1); - - // Perform tests - REQUIRE(result == true_result); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"Add","FC","Conv"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("Add#->Conv;"); - seqRegex.push_back("Add#->FC;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - // Instanciate a graphView - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); - std::shared_ptr<Node> Add1 = GenericOperator("Add", 1, 1, 1); - std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1); - std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1); - std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); - std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1); - std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); - - g1->add(Random0); - g1->addChild(Add1, Random0); - g1->addChild(Conv1, Add1); - g1->addChild(BN1, Conv1); - g1->addChild(ReLU1, BN1); - g1->addChild(FC1, Add1); - g1->addChild(Random, FC1); - - // Test 1 : Find the match - std::vector<std::shared_ptr<Node>> startNodes; - std::set<std::shared_ptr<Node>> result; - - startNodes.push_back(Add1); - startNodes.push_back(Add1); - result = GReg.matchFromStartNodes(startNodes, g1); - - std::set<std::shared_ptr<Node>> true_result; - true_result.insert(Add1); - true_result.insert(Conv1); - true_result.insert(FC1); - - // Test 2 : Return an empty set when the start nodes are wrong - std::vector<std::shared_ptr<Node>> wrong_startNodes; - std::set<std::shared_ptr<Node>> wrong_start_result; - std::set<std::shared_ptr<Node>> empty_result; - - wrong_startNodes.push_back(Random0); - wrong_startNodes.push_back(Random0); - wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1); - - // Perform tests - REQUIRE(result == true_result); - REQUIRE(wrong_start_result == empty_result); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -/* -TEST_CASE("Function matchFromStartNodes | Match a sequence with quantifier ", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"FC"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("FC+;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - - // Instanciate a graphView - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); - std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1); - std::shared_ptr<Node> FC2 = GenericOperator("FC", 1, 1, 1); - std::shared_ptr<Node> FC3 = GenericOperator("FC", 1, 1, 1); - std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); - - g1->add(Random0); - g1->addChild(FC1, Random0); - g1->addChild(FC2, FC1); - g1->addChild(FC3, FC2); - g1->addChild(ReLU1, FC3); - - // Test 1 : Find the match - std::vector<std::shared_ptr<Node>> startNodes; - std::set<std::shared_ptr<Node>> result; - - startNodes.push_back(FC1); - result = GReg.matchFromStartNodes(startNodes, g1); - - std::set<std::shared_ptr<Node>> true_result; - true_result.insert(FC1); - true_result.insert(FC2); - true_result.insert(FC3); - - // Test 2 : Return an empty set when the start nodes are wrong - std::vector<std::shared_ptr<Node>> wrong_startNodes; - std::set<std::shared_ptr<Node>> wrong_start_result; - std::set<std::shared_ptr<Node>> empty_result; - - wrong_startNodes.push_back(Random0); - wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1); - - // Perform tests - REQUIRE(result == true_result); - REQUIRE(wrong_start_result == empty_result); -} -*/ - -TEST_CASE("Function match | ALL matches of Nodes sequence", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"GEMM"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("GEMM;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - //init the input graph - std::shared_ptr<GraphView> graphToMatch = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); - std::shared_ptr<Node> GEMM1 = GenericOperator("GEMM", 1, 1, 1); - std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); - std::shared_ptr<Node> GEMM2 = GenericOperator("GEMM", 1, 1, 1); - std::shared_ptr<Node> GEMM3 = GenericOperator("GEMM", 1, 1, 1); - std::shared_ptr<Node> ReLU2 = GenericOperator("ReLU", 1, 1, 1); - std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); - - graphToMatch->add(Random0); - graphToMatch->addChild(GEMM1, Random0); - graphToMatch->addChild(ReLU1, GEMM1); - graphToMatch->addChild(GEMM2, ReLU1); - graphToMatch->addChild(GEMM3, GEMM2); - graphToMatch->addChild(ReLU2, GEMM3); - graphToMatch->addChild(Random, ReLU2); - - - //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch); - //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch); - Match matches = GReg.match(graphToMatch); - - size_t nb = matches.getNbMatch(); - std::vector<std::vector<NodeTmp>> gm_startnodes = matches.getStartNodes(); - std::vector<std::set<NodeTmp>> gm_matchnodes = matches.getMatchNodes(); - - std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs; - - for (size_t i = 0; i < nb; ++i) { - matchs.insert(std::make_pair(gm_startnodes[i], gm_matchnodes[i])); - } - - //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ; - std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ; - // Carefull : as the assert is on a vector, the Order of match matters - std::vector<NodeTmp> startNode = {GEMM1}; - std::set<NodeTmp> matchNode = {GEMM1}; - //toMatchs.push_back(std::make_pair(startNode,matchNode)); - toMatchs.insert(std::make_pair(startNode,matchNode)); - - std::vector<NodeTmp> startNode2 = {GEMM2}; - std::set<NodeTmp> matchNode2 = {GEMM2}; - //toMatchs.push_back(std::make_pair(startNode2,matchNode2)); - toMatchs.insert(std::make_pair(startNode2,matchNode2)); - - std::vector<NodeTmp> startNode3 = {GEMM3}; - std::set<NodeTmp> matchNode3 = {GEMM3}; - //toMatchs.push_back(std::make_pair(startNode3,matchNode3)); - toMatchs.insert(std::make_pair(startNode3,matchNode3)); - - REQUIRE(matchs == toMatchs); - REQUIRE(nb == 3); -} - - diff --git a/unit_tests/graphMatching/Test_NodeRegex.cpp b/unit_tests/graphMatching/Test_NodeRegex.cpp deleted file mode 100644 index 2866642bf1355f49a451edffec9e1b62c802ae1f..0000000000000000000000000000000000000000 --- a/unit_tests/graphMatching/Test_NodeRegex.cpp +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <iostream> -#include <map> -#include <memory> -#include <cassert> - -#include <catch2/catch_test_macros.hpp> - -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/operator/GenericOperator.hpp" - - -using namespace Aidge; - -TEST_CASE("Create Noderegex", "[Noderegex]") { - std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("conv"); -} - -TEST_CASE("Test _is function", "[Noderegex]") { - // Create Noderegex with only condition on the name of the Node - // Create several operators to pass into Noderegex _is function - // Assert Noderegex._is(operators) are correct - std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("Conv"); - - std::shared_ptr<Node> Conv = GenericOperator("Conv", 1, 1, 1); - std::shared_ptr<Node> FC = GenericOperator("FC", 1, 1, 1); - - REQUIRE(nr->_is(Conv) == true); - REQUIRE(nr->_is(FC) == false); - REQUIRE(nr->isA("Conv") == true); - REQUIRE(nr->isA("FC") == false); - -} \ No newline at end of file diff --git a/unit_tests/graphMatching/Test_SeqStm.cpp b/unit_tests/graphMatching/Test_SeqStm.cpp deleted file mode 100644 index db8662e3329abe153d4a0fb2b3c46b950208d6bc..0000000000000000000000000000000000000000 --- a/unit_tests/graphMatching/Test_SeqStm.cpp +++ /dev/null @@ -1,167 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <iostream> -#include <map> -#include <memory> -#include <vector> -#include <utility> -#include <cassert> - -#include <catch2/catch_test_macros.hpp> -//test -#include "aidge/graphmatching/SeqStm.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -//use -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" - -using namespace Aidge; - -TEST_CASE("Create good init SeqStm", "[SeqStm]") { - //init all iniput for SeqStm - - - int stmIdx = 0; - //matrix that in B->C - std::vector<std::vector<int>> transitionMatrix { - { -1, 1, -1 }, - { -1, -1, 2 }, - { -1, -1, -1 } }; - - //std::cout << transitionMatrix.size() << "\n"; - // init the nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - // - - std::map<NodeTypeKey,int> typeToIdxTransition; - std::vector<NodeTypeKey> nodeTypeCommonTag {{"A",""},{"B",""},{"C",""}}; - //init nodeTypeCommonTag - int idx = 0; - for (const NodeTypeKey& key : nodeTypeCommonTag) { - typeToIdxTransition[key] = idx; - idx += 1; - } - - int actSt = 0; - std::set<NodeTmp> allNodeValidated; - std::set<NodeTmp> allNodeTested; - std::set<std::pair<NodeTmp,std::string>> allCommonNode; - bool stmIsValid =false; - - - SeqStm stm( - stmIdx, - transitionMatrix, - nodesRegex, - typeToIdxTransition, - actSt, - allNodeValidated, - allNodeTested, - allCommonNode, - stmIsValid); - - REQUIRE(stm.getStmIdx() == 0); - REQUIRE(stm.isValid() == false); - REQUIRE(stm.getAllCommonNode().size() == 0); - REQUIRE(stm.getAllNodeTested().size() == 0); - REQUIRE(stm.getAllNodeValidated().size() == 0); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -TEST_CASE("Test testNode function", "[SeqStm]") { - - int stmIdx = 0; - std::map<NodeTypeKey,int> typeToIdxTransition; - std::vector<NodeTypeKey> nodeTypeCommonTag {{"A",""},{"B",""},{"C",""}}; - //init nodeTypeCommonTag - int idx = 0; - for (const NodeTypeKey& key : nodeTypeCommonTag) { - typeToIdxTransition[key] = idx; - idx += 1; - } - //matrix that in B->C - std::vector<std::vector<int>> transitionMatrix { - { -1, 1, -1 }, - { -1, -1, 2 }, - { -1, -1, -1 } }; - - //std::cout << transitionMatrix.size() << "\n"; - // init the nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - // - int actSt = 0; - std::set<NodeTmp> allNodeValidated; - std::set<NodeTmp> allNodeTested; - std::set<std::pair<NodeTmp,std::string>> allCommonNode; - bool stmIsValid =false; - - SeqStm stm( - stmIdx, - transitionMatrix, - nodesRegex, - typeToIdxTransition, - actSt, - allNodeValidated, - allNodeTested, - allCommonNode, - stmIsValid); - REQUIRE(stm.getStmIdx() == 0); - //test a node - std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); - std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); - - //set use to test the state of the smt - std::set<NodeTmp> testAllNodeTested; - std::set<NodeTmp> testAllNodeValidated; - - stm.testNode(nodeB); - REQUIRE(stm.isValid() == false); - REQUIRE(stm.getState() == 1); - REQUIRE(stm.isStmBlocked() == false); - testAllNodeTested.insert(nodeB); - testAllNodeValidated.insert(nodeB); - REQUIRE(stm.getAllNodeTested() == testAllNodeTested); - REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); - - - stm.testNode(nodeC); - REQUIRE(stm.isValid() == true); - REQUIRE(stm.getState() == 2); - REQUIRE(stm.isStmBlocked() == false); - testAllNodeTested.insert(nodeC); - testAllNodeValidated.insert(nodeC); - REQUIRE(stm.getAllNodeTested() == testAllNodeTested); - REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); - - stm.testNode(nodeC); - REQUIRE(stm.isValid() == true); - REQUIRE(stm.getState() == -1); - REQUIRE(stm.isStmBlocked() == true); - REQUIRE(stm.getAllNodeTested() == testAllNodeTested); - REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} \ No newline at end of file diff --git a/unit_tests/graphMatching/Test_StmFactory.cpp b/unit_tests/graphMatching/Test_StmFactory.cpp deleted file mode 100644 index 3c66d0fa817cea674de5ab849091290c976e5735..0000000000000000000000000000000000000000 --- a/unit_tests/graphMatching/Test_StmFactory.cpp +++ /dev/null @@ -1,204 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <iostream> -#include <map> -#include <memory> -#include <vector> -#include <utility> -#include <cassert> - -#include <catch2/catch_test_macros.hpp> -//test -#include "aidge/graphmatching/StmFactory.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -//use -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" - -using namespace Aidge; - -TEST_CASE("Create good init StmFactory", "[StmFactory]") { - // init the nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - StmFactory stmF(nodesRegex); - REQUIRE(stmF.getNumberOfStm() == 0); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -TEST_CASE("Test in makeNewStm the getStmIdx StmFactory", "[SeqStm]") { - - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - StmFactory stmF(nodesRegex); - - std::string seq1 = "A->B+->A#;"; - SeqStm* stm = stmF.makeNewStm(seq1); - REQUIRE(stm->getStmIdx() == 0); - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getAllCommonNode().size() == 0); - REQUIRE(stm->getAllNodeTested().size() == 0); - REQUIRE(stm->getAllNodeValidated().size() == 0); - - std::string seq2 = "A->B;"; - SeqStm* stm2 = stmF.makeNewStm(seq2); - REQUIRE(stm2->getStmIdx() == 1); - REQUIRE(stm2->isValid() == false); - REQUIRE(stm2->getAllCommonNode().size() == 0); - REQUIRE(stm2->getAllNodeTested().size() == 0); - REQUIRE(stm2->getAllNodeValidated().size() == 0); - - //test the number of stm - REQUIRE(stmF.getNumberOfStm() == 2); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -TEST_CASE("Test in makeNewStm the stm StmFactory", "[SeqStm]") { - - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - - StmFactory stmF(nodesRegex); - std::string seq1 = "B->C;"; - SeqStm* stm = stmF.makeNewStm(seq1); - //test the number of stm - REQUIRE(stmF.getNumberOfStm() == 1); - - //std::shared_ptr<Node> nodeB = GenericOperator("B",1,1,1); - //std::shared_ptr<Node> nodeC = GenericiOperator("C",1,1,1); - std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); - std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); - //set use to test the state of the smt - std::set<NodeTmp> testAllNodeTested; - std::set<NodeTmp> testAllNodeValidated; - - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getState() == 0); - REQUIRE(stm->isStmBlocked() == false); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - stm->testNode(nodeB); - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getState() == 1); - REQUIRE(stm->isStmBlocked() == false); - testAllNodeTested.insert(nodeB); - testAllNodeValidated.insert(nodeB); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - - stm->testNode(nodeC); - REQUIRE(stm->isValid() == true); - REQUIRE(stm->getState() == 2); - REQUIRE(stm->isStmBlocked() == false); - testAllNodeTested.insert(nodeC); - testAllNodeValidated.insert(nodeC); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - stm->testNode(nodeC); - REQUIRE(stm->isValid() == true); - REQUIRE(stm->getState() == -1); - REQUIRE(stm->isStmBlocked() == true); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } - -} - -TEST_CASE("Test in duplicateStm StmFactory", "[SeqStm]") { - - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - - StmFactory stmF(nodesRegex); - std::string seq1 = "B->C;"; - SeqStm* stm = stmF.makeNewStm(seq1); - SeqStm* stmD = stmF.duplicateStm(stm); - - std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); - std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); - //set use to test the state of the smt - std::set<NodeTmp> testAllNodeTested; - std::set<NodeTmp> testAllNodeValidated; - - //run the stm - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getState() == 0); - REQUIRE(stm->isStmBlocked() == false); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - stm->testNode(nodeB); - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getState() == 1); - REQUIRE(stm->isStmBlocked() == false); - testAllNodeTested.insert(nodeB); - testAllNodeValidated.insert(nodeB); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - - stm->testNode(nodeC); - REQUIRE(stm->isValid() == true); - REQUIRE(stm->getState() == 2); - REQUIRE(stm->isStmBlocked() == false); - testAllNodeTested.insert(nodeC); - testAllNodeValidated.insert(nodeC); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - stm->testNode(nodeC); - REQUIRE(stm->isValid() == true); - REQUIRE(stm->getState() == -1); - REQUIRE(stm->isStmBlocked() == true); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - //check if stmD not move - REQUIRE(stmD->isValid() == false); - REQUIRE(stmD->getState() == 0); - REQUIRE(stmD->isStmBlocked() == false); - REQUIRE(stmD->getAllNodeTested().size() == 0); - REQUIRE(stmD->getAllNodeValidated().size() == 0); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - diff --git a/unit_tests/graphRegex/Test_Fsm.cpp b/unit_tests/graphRegex/Test_Fsm.cpp index e5950f21b323f07b380ae95f70637ca48a173481..c011a50455e9e21f3df66c3ed46a835bed5346b3 100644 --- a/unit_tests/graphRegex/Test_Fsm.cpp +++ b/unit_tests/graphRegex/Test_Fsm.cpp @@ -14,10 +14,10 @@ using namespace Aidge; TEST_CASE("matchFSM", "FsmEdge") { - std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); - std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); - std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); - FsmEdgeUnique EdgeToTest(nodeA,nodeB,toTest); + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("A","true==true"); + FsmEdgeUnique EdgeToTest(nodeA,nodeB,toTest); SECTION("FsmEdgeUnique constructor") { REQUIRE(EdgeToTest.getSourceNode() == nodeA); @@ -28,7 +28,7 @@ TEST_CASE("matchFSM", "FsmEdge") { SECTION("FsmEdgeCommon constructor") { std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); - std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("A","true==true"); FsmEdgeCommon EdgeToTest(nodeA,nodeB,toTest,"A"); @@ -40,7 +40,7 @@ TEST_CASE("matchFSM", "FsmEdge") { SECTION("FsmEdgeRef constructor") { std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); - std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("A","true==true"); FsmEdgeRef EdgeToTest(nodeA,nodeB,0,-1); @@ -52,7 +52,7 @@ TEST_CASE("matchFSM", "FsmEdge") { SECTION("FsmEdgeEmpty constructor") { std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); - std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("A","true==true"); FsmEdgeEmpty EdgeToTest(nodeA,nodeB); @@ -65,9 +65,9 @@ TEST_CASE("matchFSM", "FsmEdge") { SECTION("FsmEdgeFactory"){ std::map<std::string, std::shared_ptr<ConditionalInterpreter>> allTest = { - {"A",std::make_shared<ConditionalInterpreter>("true==true")}, - {"B",std::make_shared<ConditionalInterpreter>("true==true")}, - {"C",std::make_shared<ConditionalInterpreter>("true==true")} + {"A",std::make_shared<ConditionalInterpreter>("A","true==true")}, + {"B",std::make_shared<ConditionalInterpreter>("B","true==true")}, + {"C",std::make_shared<ConditionalInterpreter>("C","true==true")} }; // make(std::shared_ptr<FsmNode> source, std::shared_ptr<FsmNode> dest, @@ -103,11 +103,11 @@ TEST_CASE("matchFSM", "FsmEdge") { std::shared_ptr<FsmNode> nodeC = std::make_shared<FsmNode>(false,true); //make the edges - std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("A","true==true"); std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); - std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(""); graph->addEdge(edgeAB); graph->addEdge(edgeBC); @@ -120,7 +120,7 @@ TEST_CASE("matchFSM", "FsmEdge") { SECTION("graph merge") { - std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("A","true==true"); //make the nodes std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(false,true); @@ -132,7 +132,7 @@ TEST_CASE("matchFSM", "FsmEdge") { std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); - std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(""); graph->addEdge(edgeAB); graph->addEdge(edgeBC); @@ -149,7 +149,7 @@ TEST_CASE("matchFSM", "FsmEdge") { std::shared_ptr<FsmEdge> edge2AB = std::make_shared<FsmEdgeUnique>(node2A,node2B,toTest); std::shared_ptr<FsmEdge> edge2BC = std::make_shared<FsmEdgeUnique>(node2B,node2C,toTest); - std::shared_ptr<FsmGraph> graph2 = std::make_shared<FsmGraph>(); + std::shared_ptr<FsmGraph> graph2 = std::make_shared<FsmGraph>(""); graph2->addEdge(edge2AB); @@ -184,7 +184,7 @@ TEST_CASE("matchFSM", "FsmEdge") { // std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); // std::shared_ptr<FsmEdgeUnique> edge = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); -// std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); +// std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(""); // graph->addEdge(edge); diff --git a/unit_tests/graphRegex/Test_FsmMatch.cpp b/unit_tests/graphRegex/Test_FsmMatch.cpp index 1fe75be1a47033f75af7ccc4dc5202774444cd10..4b0a009a4b142f56334b133919025e5e83b7435a 100644 --- a/unit_tests/graphRegex/Test_FsmMatch.cpp +++ b/unit_tests/graphRegex/Test_FsmMatch.cpp @@ -14,14 +14,14 @@ using namespace Aidge; TEST_CASE("FsmMatch") { SECTION("Construction") { - std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { - {"A",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, - {"B",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, - {"C",std::make_shared<ConditionalInterpreter>("true==true")} + std::vector<std::shared_ptr<ConditionalInterpreter>> allTest = { + std::make_shared<ConditionalInterpreter>("A","isConv($)==true"), + std::make_shared<ConditionalInterpreter>("B","isConv($)==true"), + std::make_shared<ConditionalInterpreter>("C","true==true") }; - allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); - allTest["B"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest[0]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest[1]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A->A",allTest); std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); @@ -41,14 +41,14 @@ TEST_CASE("FsmMatch") { g1->addChild(conv1, "c"); - REQUIRE(allTest["A"]->test(conv) == true); - REQUIRE(allTest["B"]->test(conv) == true); + REQUIRE(allTest[0]->test(conv) == true); + REQUIRE(allTest[1]->test(conv) == true); std::vector<std::shared_ptr<Node>> startNodes = {conv}; auto result = fsm->test(startNodes); - REQUIRE( result->getBiggerSolution() == std::set<NodePtr>{conv,conv1}); + REQUIRE( result[0]->getAll() == std::set<NodePtr>{conv,conv1}); } @@ -70,19 +70,20 @@ TEST_CASE("FsmMatch") { ///////////// - std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { - {"A",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, - {"B",std::make_shared<ConditionalInterpreter>("isFc($)==true")} + std::vector<std::shared_ptr<ConditionalInterpreter>> allTest = { + std::make_shared<ConditionalInterpreter>("A","isConv($)==true"), + std::make_shared<ConditionalInterpreter>("B","isFc($)==true") }; - allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); - allTest["B"]->insertLambda("isFc",+[](NodePtr NodeOp){return NodeOp->type() == "Fc";}); + allTest[0]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest[1]->insertLambda("isFc",+[](NodePtr NodeOp){return NodeOp->type() == "Fc";}); std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A#->A; A#->B",allTest); std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); std::vector<std::shared_ptr<Node>> startNodes = {conv,conv}; auto result = fsm->test(startNodes); - REQUIRE( result->getBiggerSolution() == std::set<NodePtr>{conv,conv1,conv2}); + + REQUIRE( result[0]->getAll() == std::set<NodePtr>{conv,conv1,conv2}); } diff --git a/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp b/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp index 9ce090506c9a61abd928b3ae590ee838afb05999..e789677d44efa68071017a9832fa01b5ed340f75 100644 --- a/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp +++ b/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp @@ -8,10 +8,10 @@ using namespace Aidge; TEST_CASE("GraphFsmInterpreter", "GraphFsmInterpreter") { SECTION("Construction") { - std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { - {"A",std::make_shared<ConditionalInterpreter>("true==true")}, - {"B",std::make_shared<ConditionalInterpreter>("true==true")}, - {"C",std::make_shared<ConditionalInterpreter>("true==true")} + std::vector<std::shared_ptr<ConditionalInterpreter>> allTest = { + std::make_shared<ConditionalInterpreter>("A","true==true"), + std::make_shared<ConditionalInterpreter>("B","true==true"), + std::make_shared<ConditionalInterpreter>("C","true==true") }; //GraphFsmInterpreter("A->B",allTest); diff --git a/unit_tests/graphRegex/Test_GraphRegex.cpp b/unit_tests/graphRegex/Test_GraphRegex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b30560ea3ea696821d2422bf760a11973a104e85 --- /dev/null +++ b/unit_tests/graphRegex/Test_GraphRegex.cpp @@ -0,0 +1,84 @@ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/graphRegex/GraphRegex.hpp" + +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" + +using namespace Aidge; + +TEST_CASE("GraphRegexUser") { + + SECTION("INIT") { + + const std::string query = "Conv->FC"; + + std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> fc = GenericOperator("FC", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> fc2 = GenericOperator("FC", 1, 1, 1, "c3"); + + g1->add(conv); + g1->addChild(fc, "c"); + g1->addChild(conv2, "c1"); + g1->addChild(fc2, "c2"); + + + sut->setKeyFromGraph(g1); + sut->addQuery(query); + + for (const auto& solution : sut->match(g1)) { + + REQUIRE(solution->getQuery() == query); + if(solution->getStartNode() == std::vector<NodePtr>{conv}){ + REQUIRE(solution->at("Conv") == std::set<NodePtr>{conv} ); + REQUIRE(solution->at("FC") == std::set<NodePtr>{fc} ); + }else if (solution->getStartNode() == std::vector<NodePtr>{conv2}) + { + REQUIRE(solution->at("Conv") == std::set<NodePtr>{conv2} ); + REQUIRE(solution->at("FC") == std::set<NodePtr>{fc2} ); + } + } + //REQUIRE( sut->match(g1)[1]->getAll() == std::set<NodePtr>{conv,fc}); + + } + + SECTION("CC") { + + + + std::shared_ptr<GraphRegex> sut = std::make_shared<GraphRegex>(); + + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + + g1->add(conv); + g1->addChild(conv1, "c"); + g1->addChild(conv2, "c1"); + g1->addChild(conv3, "c2"); + + + sut->setKeyFromGraph(g1); + + const std::string query = "Conv->Conv"; + const std::string query2 = "Conv->FC"; + + sut->setNodeKey("FC","getType($) =='FC'"); + + sut->addQuery(query); + sut->addQuery(query2); + + + for (const auto& solution : sut->match(g1)) { + REQUIRE(solution->getQuery() == query); + } + + } +} \ No newline at end of file diff --git a/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp b/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp index 8b502fb546e2f1396b629ebc78bc1bd4d67842e2..ec068358a34567e57c417a664284bd1db76d7a69 100644 --- a/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp +++ b/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp @@ -12,19 +12,44 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { SECTION("custom Lambda") { - const std::string test = " !toto($) == true " ; - ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); - conditionalParser.insertLambda("toto",+[](NodePtr NodeOp){return false;}); + + ConditionalInterpreter conditionalParserB = ConditionalInterpreter("A"," bad($) == false "); + ConditionalInterpreter conditionalParserG = ConditionalInterpreter("A"," good($) == true "); + + + conditionalParserB.insertLambda("bad",+[](NodePtr NodeOp){return NodeOp->name() == "ZZ";}); + conditionalParserG.insertLambda("good",+[](NodePtr NodeOp){return NodeOp->name() == "Gop1";}); std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); - bool result = conditionalParser.test(nodeOp); - REQUIRE(result == true); + REQUIRE(conditionalParserB.test(nodeOp) == true); + REQUIRE(conditionalParserG.test(nodeOp) == true); + } + + + ConditionalInterpreter conditionalParserT = ConditionalInterpreter("A","isConv($)==true"); + conditionalParserT.insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + std::shared_ptr<Node> zz = GenericOperator("conv", 0, 0, 0, "Gop1"); + conditionalParserT.test(zz); + + SECTION("Lambdas") { + ConditionalInterpreter conditionalParser = ConditionalInterpreter("OP_test","getType($) =='Conv' || getType($) =='FC' "); + + std::shared_ptr<Node> A = GenericOperator("Conv", 0, 0, 0, "A"); + REQUIRE(conditionalParser.test(A) == true); + + std::shared_ptr<Node> B = GenericOperator("FC", 0, 0, 0, "B"); + REQUIRE(conditionalParser.test(B) == true); + + + std::shared_ptr<Node> C = GenericOperator("A", 0, 0, 0, "C"); + conditionalParser.test(C); + REQUIRE(conditionalParser.test(C) == false); } SECTION("syntax error") { const std::string test = "'A' == 'A' ,&& "; - REQUIRE_THROWS_AS( ConditionalInterpreter(test), std::runtime_error); + REQUIRE_THROWS_AS( ConditionalInterpreter("A",test), std::runtime_error); } @@ -32,7 +57,7 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { SECTION("test false int ") { const std::string test = " 10 == 11 " ; - ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + ConditionalInterpreter conditionalParser = ConditionalInterpreter("A",test); std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); bool result = conditionalParser.test(nodeOp); REQUIRE(result == false); @@ -40,7 +65,7 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { SECTION("test true int ") { const std::string test = " 42 == 42 " ; - ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + ConditionalInterpreter conditionalParser = ConditionalInterpreter("A",test); std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); bool result = conditionalParser.test(nodeOp); REQUIRE(result == true); @@ -48,7 +73,7 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { SECTION("test false str ") { const std::string test = " 'toto' == 'Corgi' " ; - ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + ConditionalInterpreter conditionalParser = ConditionalInterpreter("A",test); std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); bool result = conditionalParser.test(nodeOp); REQUIRE(result == false); @@ -57,7 +82,7 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { SECTION("test true str ") { const std::string test = " 'Corgi' == 'Corgi' " ; - ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + ConditionalInterpreter conditionalParser = ConditionalInterpreter("A",test); std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); bool result = conditionalParser.test(nodeOp); REQUIRE(result == true); diff --git a/unit_tests/operator/Test_ConvDepthWise_Op.cpp b/unit_tests/operator/Test_ConvDepthWise_Op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ef68c439d3a3cdf95b7122c1b41bc9fc97311f2d --- /dev/null +++ b/unit_tests/operator/Test_ConvDepthWise_Op.cpp @@ -0,0 +1,68 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <cstddef> +#include <memory> +#include <string> +#include <vector> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/ConvDepthWise.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +// TEST_CASE("[core/operator] ConvDepthWise_Op(computeReceptiveField)", "[Operator][computeReceptiveFiled][ConvDepthWise]") { +// auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); +// auto conv1 = ConvDepthWise({5, 5}, "conv1"); // output dims: {16, 3, 220, 220} +// auto conv2 = ConvDepthWise({3, 3}, "conv2"); // output dims: {16, 3, 218, 218} +// auto conv3 = ConvDepthWise({2, 2}, "conv3", {2,2}); // output dims: {16, 3, 109, 109} +// auto conv4 = ConvDepthWise({1, 1}, "conv4"); // output dims: {16, 3, 109, 109} + +// auto g = std::make_shared<GraphView>("TestGraph"); + +// dataProvider->addChild(conv1, 0); +// g->add(conv1); +// g->addChild(conv2, conv1, 0); +// g->addChild(conv3, conv2, 0); +// g->addChild(conv4, conv3, 0); + +// g->forwardDims(); + +// SECTION("Check individual receptive fields") { +// auto res1 = conv1->getOperator()->computeReceptiveField(0, {16,3,10,10}); +// auto res2 = conv2->getOperator()->computeReceptiveField(conv2->getOperator()->output(0).getIdx({3,1,100,28}), {4,2,30,40}); +// auto res3 = conv3->getOperator()->computeReceptiveField(0, {1,1,109,109}); +// auto res4 = conv4->getOperator()->computeReceptiveField(conv4->getOperator()->input(0).getIdx({5,0,108,108}), {10,1,1,1}); + +// REQUIRE(((res1[0].first == 0) && (res1[0].second == std::vector<DimSize_t>({16, 3, 14, 14})))); +// REQUIRE(((res2[0].first == conv2->getOperator()->input(0).getIdx({3,1,100,28})) && (res2[0].second == std::vector<DimSize_t>({4, 2, 32, 42})))); +// REQUIRE(((res3[0].first == 0) && (res3[0].second == std::vector<DimSize_t>({1, 1, 218, 218})))); +// REQUIRE(((res4[0].first == conv4->getOperator()->input(0).getIdx({5, 0, 108, 108})) && (res4[0].second == std::vector<DimSize_t>({10, 1, 1, 1})))); +// } + +// SECTION("Check receptive field propagation") { +// // input: first-{5, 0, 50, 50} dims-{1, 1, 1, 1} +// auto res4 = conv4->getOperator()->computeReceptiveField(conv4->getOperator()->input(0).getIdx({5,0,50,50}), {1,1,1,1}); +// // conv4 RF: first-{5, 0, 50, 50} dims-{1, 1, 1, 1} +// auto res3 = conv3->getOperator()->computeReceptiveField(res4[0].first, res4[0].second); +// // conv3 RF: first-{5, 0, 100, 100} dims-{1, 1, 2, 2} +// auto res2 = conv2->getOperator()->computeReceptiveField(res3[0].first, res3[0].second); +// // conv2 RF: first-{5, 0, 100, 100} dims-{1, 1, 4, 4} +// auto res1 = conv1->getOperator()->computeReceptiveField(res2[0].first, res2[0].second); +// // conv1 RF: first-{5, 0, 100, 100} dims-{1, 1, 8, 8} + +// REQUIRE(((res1[0].first == conv1->getOperator()->input(0).getIdx({5, 0, 100, 100})) && (res1[0].second == std::vector<DimSize_t>({1, 1, 8, 8})))); +// } +// } +} // namespace Aidge \ No newline at end of file diff --git a/unit_tests/operator/Test_Conv_Op.cpp b/unit_tests/operator/Test_Conv_Op.cpp new file mode 100644 index 0000000000000000000000000000000000000000..ac667ec5af69dccc3e421530a17aca88018aab09 --- /dev/null +++ b/unit_tests/operator/Test_Conv_Op.cpp @@ -0,0 +1,79 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <cstddef> +#include <memory> +#include <string> +#include <vector> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +// TEST_CASE("[core/operator] Conv_Op(computeReceptiveField)", "[Operator][computeReceptiveField][Conv]") { +// auto dataProvider = Producer({16, 3, 224, 224}, "dataProvider"); +// auto conv1 = Conv(3, 32, {5, 5}, "conv1"); // output dims: {16, 32, 220, 220} +// auto conv2 = Conv(32, 64, {3, 3}, "conv2"); // output dims: {16, 64, 218, 218} +// auto conv3 = Conv(64, 10, {2, 2}, "conv3", {2,2}); // output dims: {16, 10, 109, 109} +// auto conv4 = Conv(10, 10, {1, 1}, "conv4"); // output dims: {16, 10, 109, 109} + +// auto g = std::make_shared<GraphView>("TestGraph"); + +// dataProvider->addChild(conv1, 0); +// g->add(conv1); +// g->addChild(conv2, conv1, 0); +// g->addChild(conv3, conv2, 0); +// g->addChild(conv4, conv3, 0); + +// g->forwardDims(); + +// SECTION("Check individual receptive fields") { +// auto res1 = conv1->getOperator()->computeReceptiveField(0, {16,32,10,10}); +// auto res2 = conv2->getOperator()->computeReceptiveField(conv2->getOperator()->output(0).getIdx({3,20,100,28}), {4,20,30,40}); +// auto res3 = conv3->getOperator()->computeReceptiveField(0, {1,1,109,109}); +// auto res4 = conv4->getOperator()->computeReceptiveField(conv4->getOperator()->output(0).getIdx({5,0,108,108}), {10,10,1,1}); + +// REQUIRE(((res1[0].first == 0) && (res1[0].second == std::vector<DimSize_t>({16, 3, 14, 14})))); +// REQUIRE(((res2[0].first == conv2->getOperator()->input(0).getIdx({3,0,100,28})) && (res2[0].second == std::vector<DimSize_t>({4, 32, 32, 42})))); +// REQUIRE(((res3[0].first == 0) && (res3[0].second == std::vector<DimSize_t>({1, 64, 218, 218})))); +// REQUIRE(((res4[0].first == conv4->getOperator()->input(0).getIdx({5, 0, 108, 108})) && (res4[0].second == std::vector<DimSize_t>({10, 10, 1, 1})))); +// } + +// SECTION("Check receptive field propagation") { +// // input: first-{5, 0, 50, 50} dims-{1, 1, 1, 1} +// auto res4 = conv4->getOperator()->computeReceptiveField(conv4->getOperator()->output(0).getIdx({5,0,50,50}), {1,1,1,1}); +// // conv4 RF: first-{5, 0, 50, 50} dims-{1, 10, 1, 1} +// auto res3 = conv3->getOperator()->computeReceptiveField(res4[0].first, res4[0].second); +// // conv3 RF: first-{5, 0, 100, 100} dims-{1, 64, 2, 2} +// auto res2 = conv2->getOperator()->computeReceptiveField(res3[0].first, res3[0].second); +// // conv2 RF: first-{5, 0, 100, 100} dims-{1, 32, 4, 4} +// auto res1 = conv1->getOperator()->computeReceptiveField(res2[0].first, res2[0].second); +// // conv1 RF: first-{5, 0, 100, 100} dims-{1, 3, 8, 8} + +// REQUIRE(((res1[0].first == conv1->getOperator()->input(0).getIdx({5, 0, 100, 100})) && (res1[0].second == std::vector<DimSize_t>({1, 3, 8, 8})))); + + +// // std::cout << "conv1: {"; +// // std::cout << conv1->getOperator()->input(0).getCoord(res1[0].first)[0] << ", " +// // << conv1->getOperator()->input(0).getCoord(res1[0].first)[1] << ", " +// // << conv1->getOperator()->input(0).getCoord(res1[0].first)[2] << ", " +// // << conv1->getOperator()->input(0).getCoord(res1[0].first)[3] << "} - {"; +// // std::cout << res1[0].second[0] << ", " +// // << res1[0].second[1] << ", " +// // << res1[0].second[2] << ", " +// // << res1[0].second[3] << "}" << std::endl; +// } +// } +} // namespace Aidge \ No newline at end of file diff --git a/unit_tests/operator/Test_Operator.cpp b/unit_tests/operator/Test_Operator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a050bbc4021b0c70a0d8faf6478eb2bd13ebdb58 --- /dev/null +++ b/unit_tests/operator/Test_Operator.cpp @@ -0,0 +1,50 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include <cstddef> +#include <iostream> +#include <memory> +#include <string> +#include <vector> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/operator/Producer.hpp" + +namespace Aidge { +// TEST_CASE("[core/operator] Operator(computeReceptiveField)", "[Operator][computeReceptiveFiled]") { +// auto dataProvider1 = Producer({16, 3, 224, 224}, "dataProvider1"); +// auto dataProvider2 = Producer({16, 3, 224, 224}, "dataProvider2"); +// auto gen1 = Add(2); +// auto gen2 = ReLU(); + +// auto g = std::make_shared<GraphView>("TestGraph"); + +// dataProvider1->addChild(gen1, 0); +// dataProvider2->addChild(gen1, 0); +// g->add(gen1); +// g->addChild(gen2, gen1, 0); + +// g->forwardDims(); + +// SECTION("Check individual receptive fields") { +// auto res1 = gen1->getOperator()->computeReceptiveField(0, {16,3,10,10}); +// auto res2 = gen2->getOperator()->computeReceptiveField(gen2->getOperator()->output(0).getIdx({3,2,100,28}), {1,1,30,40}); + +// REQUIRE(((res1[0].first == 0) && (res1[0].second == std::vector<DimSize_t>({16, 3, 10, 10})))); +// REQUIRE(((res1[1].first == 0) && (res1[1].second == std::vector<DimSize_t>({16, 3, 10, 10})))); +// REQUIRE(((res2[0].first == gen2->getOperator()->input(0).getIdx({3,2,100,28})) && (res2[0].second == std::vector<DimSize_t>({1, 1, 30, 40})))); +// } +// } +} // namespace Aidge \ No newline at end of file diff --git a/unit_tests/recipies/Test_FuseBatchNorm.cpp b/unit_tests/recipies/Test_FuseBatchNorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13facefd2979a9b0ca4409ead6972013cb1bc0a8 --- /dev/null +++ b/unit_tests/recipies/Test_FuseBatchNorm.cpp @@ -0,0 +1,70 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +/* +#include <catch2/catch_test_macros.hpp> +#include <set> + + +//#include "aidge/backend/cpu/operator/BatchNormImpl.hpp" +//#include "aidge/backend/cpu/operator/ConvImpl.hpp" + + + +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/operator/BatchNorm.hpp" +#include "aidge/utils/Recipies.hpp" + +//#include "aidge/backend/TensorImpl.hpp" +//#include "aidge/backend/cpu.hpp" +//#include "aidge/" + +#include <cstddef> + + +namespace Aidge { + + + TEST_CASE("[FuseBatchNorm] conv") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + Conv(3, 32, {3, 3}, "conv1"), + BatchNorm<2>() + }); + + g1->setDatatype(DataType::Float32); + g1->setBackend("cpu"); + g1->forwardDims(); + + // std::set<std::string> availableBackends = Tensor::getAvailableBackends(); + // if (availableBackends.find("cpu") != availableBackends.end()){ + // g1->setBackend("cpu"); + // newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr)); + // }else{ + // printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n"); + // } + + fuseBatchNorm(g1); + + SECTION("Check resulting nodes") { + // REQUIRE(g1->getNodes().size() == 2); + // REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); + // REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + // REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); + // REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + // REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } + } + +} +*/ \ No newline at end of file diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp index da53642055a3146c71a211ad7816f21c9b92d6cd..39a4f794944e7889bac17eaab94eabf44587903c 100644 --- a/unit_tests/recipies/Test_FuseMulAdd.cpp +++ b/unit_tests/recipies/Test_FuseMulAdd.cpp @@ -12,26 +12,23 @@ #include <catch2/catch_test_macros.hpp> #include <set> -// #include "aidge/backend/cpu/operator/AddImpl.hpp" -// #include "aidge/backend/cpu/operator/ConvImpl.hpp" -// #include "aidge/backend/cpu/operator/FCImpl.hpp" -// #include "aidge/backend/cpu/operator/MatMulImpl.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/operator/Add.hpp" #include "aidge/operator/FC.hpp" #include "aidge/operator/MatMul.hpp" #include "aidge/operator/Producer.hpp" -#include "aidge/utils/Recipies.hpp" +#include "aidge/recipies/Recipies.hpp" namespace Aidge { + TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { // generate the original GraphView auto matmul0 = MatMul(5, "matmul0"); - auto add0 = Add<2>("add0"); + auto add0 = Add(2, "add0"); auto matmul1 = MatMul(5, "matmul1"); - auto add1 = Add<2>("add1"); + auto add1 = Add(2, "add1"); auto b0 = Producer({5}, "B0"); auto w0 = Producer({5, 5}, "W0"); @@ -74,4 +71,5 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { REQUIRE(((node->type() == "Producer") || (node->type() == "FC"))); } } + } // namespace Aidge \ No newline at end of file diff --git a/unit_tests/recipies/Test_HorizontalTiling.cpp b/unit_tests/recipies/Test_HorizontalTiling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c9fb5ed6dc8a5d994ce2d3434a8176c29e418f95 --- /dev/null +++ b/unit_tests/recipies/Test_HorizontalTiling.cpp @@ -0,0 +1,200 @@ +// /******************************************************************************** +// * Copyright (c) 2023 CEA-List +// * +// * This program and the accompanying materials are made available under the +// * terms of the Eclipse Public License 2.0 which is available at +// * http://www.eclipse.org/legal/epl-2.0. +// * +// * SPDX-License-Identifier: EPL-2.0 +// * +// ********************************************************************************/ + +// #include <catch2/catch_test_macros.hpp> +// #include <set> + +// #include "aidge/graph/GraphView.hpp" +// #include "aidge/graph/OpArgs.hpp" +// #include "aidge/operator/Conv.hpp" +// #include "aidge/operator/ReLU.hpp" +// #include "aidge/recipies/Recipies.hpp" + + +// namespace Aidge { + +// TEST_CASE("[core/recipies] Tiling(transformation)", "[Tiling][Recipies]") { + +// SECTION("Transform a pre-generated GraphView") { + +// SECTION("Simple Node: Conv") { +// std::shared_ptr<Node> myConv = Conv(3,4,{3,3}, "myconv"); +// myConv->getOperator()->setDatatype(DataType::Int32); +// myConv->getOperator()->setBackend("cpu"); +// std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<int,4,3,3,3> { +// { +// { +// {{ 0, 1, 2}, +// { 3, 4, 5}, +// { 6, 7, 8}}, +// {{ 9, 10, 11}, +// { 12, 13, 14}, +// { 15, 16, 17}}, +// {{ 18, 19, 20}, +// { 21, 22, 23}, +// { 24, 25, 26}} +// }, +// { +// {{ 27, 28, 29}, +// { 30, 31, 32}, +// { 33, 34, 35}}, +// {{ 36, 37, 38}, +// { 39, 40, 41}, +// { 42, 43, 44}}, +// {{ 45, 46, 47}, +// { 48, 49, 50}, +// { 51, 52, 53}} +// }, +// { +// {{ 54, 55, 56}, +// { 57, 58, 59}, +// { 60, 61, 62}}, +// {{ 63, 64, 65}, +// { 66, 67, 68}, +// { 69, 70, 71}}, +// {{ 72, 73, 74}, +// { 75, 76, 77}, +// { 78, 79, 80}} +// }, +// { +// {{ 81, 82, 83}, +// { 84, 85, 86}, +// { 87, 88, 89}}, +// {{ 90, 91, 92}, +// { 93, 94, 95}, +// { 96, 97, 98}}, +// {{ 99, 100, 101}, +// {102, 103, 104}, +// {105, 106, 107}} +// } +// } +// }); +// std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<int,4> {{7,0,9,0}}); +// std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<int,2,3,5,5> { //NCHW +// { +// { +// {{ 0, 1, 2, 3, 4}, +// { 5, 6, 7, 8, 9}, +// { 10, 11, 12, 13, 14}, +// { 15, 16, 17, 18, 19}, +// { 20, 21, 22, 23, 24}}, + +// {{ 25, 26, 27, 28, 29}, +// { 30, 31, 32, 33, 34}, +// { 35, 36, 37, 38, 39}, +// { 40, 41, 42, 43, 44}, +// { 45, 46, 47, 48, 49}}, + +// {{ 50, 51, 52, 53, 54}, +// { 55, 56, 57, 58, 59}, +// { 60, 61, 62, 63, 64}, +// { 65, 66, 67, 68, 69}, +// { 70, 71, 72, 73, 74}} +// }, +// { +// {{ 75, 76, 77, 78, 79}, +// { 80, 81, 82, 83, 84}, +// { 85, 86, 87, 88, 89}, +// { 90, 91, 92, 93, 94}, +// { 95, 96, 97, 98, 99}}, + +// {{100, 101, 102, 103, 104}, +// {105, 106, 107, 108, 109}, +// {110, 111, 112, 113, 114}, +// {115, 116, 117, 118, 119}, +// {120, 121, 122, 123, 124}}, + +// {{125, 126, 127, 128, 129}, +// {130, 131, 132, 133, 134}, +// {135, 136, 137, 138, 139}, +// {140, 141, 142, 143, 144}, +// {145, 146, 147, 148, 149}} +// } +// } +// }); +// std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<int,2,4,3,3> { +// { +// { +// {{ 15226, 15577, 15928}, +// { 16981, 17332, 17683}, +// { 18736, 19087, 19438}}, + +// {{ 37818, 38898, 39978}, +// { 43218, 44298, 45378}, +// { 48618, 49698, 50778}}, + +// {{ 60426, 62235, 64044}, +// { 69471, 71280, 73089}, +// { 78516, 80325, 82134}}, + +// {{ 83016, 85554, 88092}, +// { 95706, 98244, 100782}, +// {108396, 110934, 113472}} +// }, +// { +// {{ 41551, 41902, 42253}, +// { 43306, 43657, 44008}, +// { 45061, 45412, 45763}}, + +// {{118818, 119898, 120978}, +// {124218, 125298, 126378}, +// {129618, 130698, 131778}}, + +// {{196101, 197910, 199719}, +// {205146, 206955, 208764}, +// {214191, 216000, 217809}}, + +// {{273366, 275904, 278442}, +// {286056, 288594, 291132}, +// {298746, 301284, 303822}} +// } +// } +// }); +// myConv->getOperator()->associateInput(0,myInput); +// myConv->getOperator()->associateInput(1,myWeights); +// myConv->getOperator()->associateInput(2,myBias); +// myConv->getOperator()->computeOutputDims(); + +// std::shared_ptr<GraphView> g; +// g->add(myConv); +// horizontalTiling({myConv}, 3); + +// SequentialScheduler s(g); +// s->forward(); + +// // myConv->getOperator()->getOutput(0)->print(); +// REQUIRE(*(myConv->getOperator()->getOutput(0)) == *myOutput); +// } +// } +// } +// } +// // std::shared_ptr<GraphView> g = Sequential({ +// // Conv(3, 16, {3,3}, "conv1"), +// // ReLU("relu1"), +// // Conv(16, 32, {1,1}, "conv2"), +// // Conv(32, 16, {1,1}, "conv3"), +// // Conv(16, 10, {3,3}, "conv4"), +// // ReLU("relu2") +// // }); + +// // for (auto& individualConv : g->match("Conv")) { +// // auto tiledConv = horizontalTiling(individualConv); +// // g->replace(individualConv, tiledConv); +// // } +// // } + +// // SECTION("Create the GraphView with tiled layers") { +// // std::shared_ptr<GraphView> g; +// // g->addChild(horizontalTiling(Conv())) +// // } + +// // } +// // } // namespace Aidge \ No newline at end of file