diff --git a/.gitlab/ci/build.gitlab-ci.yml b/.gitlab/ci/build.gitlab-ci.yml index cd56a55fa7e9cbcefba4715188fd270462e81976..6bfae0be1e31a89d27413677fa4cdc4612561333 100644 --- a/.gitlab/ci/build.gitlab-ci.yml +++ b/.gitlab/ci/build.gitlab-ci.yml @@ -17,6 +17,66 @@ build:ubuntu_cpp: - build_cpp/ - install_cpp/ +build:ubuntu_cpp_g++10: + stage: build + needs: [] + tags: + - docker + + script: + - apt install -y g++-10 + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - export CXX=/usr/bin/g++-10 + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON -DCOVERAGE=ON .. + - make -j4 all install + +build:ubuntu_cpp_g++12: + stage: build + needs: [] + tags: + - docker + + script: + - apt install -y g++-12 + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - export CXX=/usr/bin/g++-12 + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON -DCOVERAGE=ON .. + - make -j4 all install + +build:ubuntu_cpp_clang12: + stage: build + needs: [] + tags: + - docker + + script: + - apt install -y clang-12 + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - export CXX=/usr/bin/clang++-12 + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON -DCOVERAGE=ON .. + - make -j4 all install + +build:ubuntu_cpp_clang15: + stage: build + needs: [] + tags: + - docker + + script: + - apt install -y clang-15 + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - export CXX=/usr/bin/clang++-15 + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON -DCOVERAGE=ON .. + - make -j4 all install + build:ubuntu_python: stage: build needs: [] @@ -65,3 +125,32 @@ build:windows_cpp: paths: - build_cpp/ - install_cpp/ + +build:windows_python: + stage: build + needs: [] + tags: + - windows + + image: buildtools + before_script: + # Install Chocolatey + - Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) + # Install dependencies + - choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y + - choco install git -Y + - choco install python -Y + # Update PATH + - $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") + script: + - python -m pip install virtualenv + - virtualenv venv + - venv\Scripts\Activate.ps1 + # Numpy dependancy for unit test + - python -m pip install numpy + - $env:AIDGE_INSTALL = "$pwd" + "install" + - python -m pip install . + artifacts: + expire_in: 1 week + paths: + - venv/ diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index 5532fb533dc5adb74f261bd780a20bc824b9c2d7..7aceff35aa4f59e6b5f007423d2722eeb9528fc4 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -39,7 +39,8 @@ #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/MatMul.hpp" #include "aidge/operator/MaxPooling.hpp" -//#include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" #include "aidge/operator/Operator.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/ReLU.hpp" diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 1f1eeafa859b116606613392a13a65ad398669ad..89ba148497709f0af475bbf953ff285c88036102 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -124,7 +124,7 @@ public: } /** - * @brief List dataInput connections of the GraphView object's inputNodes. + * @brief List outside dataInput connections of the GraphView object's inputNodes. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; @@ -137,7 +137,7 @@ public: inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); } /** - * @brief List input connections of the GraphView object's inputNodes. + * @brief List outside input connections of the GraphView object's inputNodes. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index dbe017fc7f8935e83ff1672392992c75a2e0658c..1d8449ac25cf8c31192da0c350c14cbfa50a48f4 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -399,6 +399,18 @@ public: return node->clone(); } + + /** + * @brief Get the set of pointers to connected node at a distance of a delta. + * @details the recution are cut + * Return a nullptr is nofing found. + * @param delta Input delta. + * @return std::shared_ptr<Node> + */ + + std::set<NodePtr> getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee); + + private: /////////////////////////////////////////////////////// // OPERATORS diff --git a/include/aidge/graphRegex/GraphFsmInterpreter.hpp b/include/aidge/graphRegex/GraphFsmInterpreter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9e92b6fe8fc9d5e44cb8051e687e33d7192e0eb7 --- /dev/null +++ b/include/aidge/graphRegex/GraphFsmInterpreter.hpp @@ -0,0 +1,73 @@ +#ifndef AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ +#define AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ + +#include <string> +#include <memory> + +#include "aidge/utilsParsing/AstNode.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" +#include "aidge/graphRegex/GraphParser.hpp" +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" + +namespace Aidge { + + class GraphFsmInterpreter + { + private: + /* data */ + GraphParser mParser; + std::size_t mActGroupe; + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> mNodesCondition; + + public: + GraphFsmInterpreter(const std::string graphMatchExpr,std::map<std::string,std::shared_ptr<ConditionalInterpreter>> nodesCondition); + virtual ~GraphFsmInterpreter() =default; + + + std::shared_ptr<FsmGraph> interpret(void); + + private: + + + std::shared_ptr<FsmGraph> visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree); + + /** + * @defgroup graphFsmInterpreterF Functions for interpreting AST nodes + * @brief For each node type in the AST, define how build the FsmGraph + */ + + + /** + * @ingroup graphFsmInterpreterF + * @brief leaf of fsm make the fsm for test one transition + */ + std::shared_ptr<FsmGraph> keyF(std::shared_ptr<AstNode<gRegexTokenTypes>> AstNode); + /** + * @ingroup graphFsmInterpreterF + * @brief combine two fsm of two expression. + */ + std::shared_ptr<FsmGraph> sepF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm); + /** + * @ingroup graphFsmInterpreterF + * @brief combine two to make a new that match leftFsm next rigthFsm + */ + std::shared_ptr<FsmGraph> nextF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm); + /** + * @ingroup graphFsmInterpreterF + * @brief make the fsm match + + */ + std::shared_ptr<FsmGraph> qomF(std::shared_ptr<FsmGraph> fsm); + /** + * @ingroup graphFsmInterpreterF + * @brief make the fsm match * + */ + std::shared_ptr<FsmGraph> qzmF(std::shared_ptr<FsmGraph> fsm); + + }; + + + +} + + +#endif // AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ diff --git a/include/aidge/graphRegex/GraphLexer.hpp b/include/aidge/graphRegex/GraphLexer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e4137ab093c466b7349007da91e032dae48eda51 --- /dev/null +++ b/include/aidge/graphRegex/GraphLexer.hpp @@ -0,0 +1,68 @@ +#ifndef AIDGE_CORE_GRAPH_LEXER_H_ +#define AIDGE_CORE_GRAPH_LEXER_H_ + +#include <string> +#include <memory> +#include <regex> +#include <stdexcept> //error +#include <sstream> + +#include "aidge/utilsParsing/ParsingToken.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" + +namespace Aidge { + + class GraphLexer + { + + public: + GraphLexer( const std::string gRegexExpressions ); + + /** + * @brief Get the next token on the gRegexExpressions + * @return ConditionalToken + */ + std::shared_ptr<ParsingToken<gRegexTokenTypes>> getNextToken(void); + /** + * @brief Restart at the start of the gRegexExpressions + * + */ + void rstPosition(void); + + /** + * @brief Test if the string is completely read + * @return bool + */ + bool isEnd(void); + + + /** + * @brief Get the representation of the class + * @return string + */ + const std::string rep(); + + private: + + /** + * @brief Constructs an error message to display the character not understood by the lexer + * @return error mesage + */ + std::runtime_error badTokenError(const std::string& currentChars,std::size_t position); + + /** + * @brief The expression of the test to be performed on the nodes + */ + const std::string mRegularExpressions; + /** + * @brief The lexer's current position in mConditionalExpressions + */ + std::size_t mPosition; + + }; +} + + + + +#endif //AIDGE_CORE_GRAPH_LEXER_H_ diff --git a/include/aidge/graphRegex/GraphParser.hpp b/include/aidge/graphRegex/GraphParser.hpp new file mode 100644 index 0000000000000000000000000000000000000000..73406203a8be87e1df75cc694ab1ff281c27fbfa --- /dev/null +++ b/include/aidge/graphRegex/GraphParser.hpp @@ -0,0 +1,98 @@ +#ifndef AIDGE_CORE_GRAPH_PARSER_H_ +#define AIDGE_CORE_GRAPH_PARSER_H_ + + +#include <memory> // for shared_ptr +#include "aidge/graphRegex/GraphLexer.hpp" +#include "aidge/utilsParsing/AstNode.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" + +namespace Aidge{ + +/** + * @brief this class uses the lexer to create an AST according to a set of gramer rules + */ +class GraphParser{ + + public: + /** + * @brief AST graph creation function + * @param gRegexExpressions String representing the logical fuction to be performed + */ + GraphParser(const std::string gRegexExpressions); + + virtual ~GraphParser() = default; + + /** + * @brief AST graph creation function + * @return The AST tree + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> parse(void); + + + private: + /** + * @brief restart at the start of the ConditionalExpressions for LEXER and restart mCurrentToken + */ + void rstParser(void); + + ////////////////// + + /** + * @defgroup ParsingFunctions Function for creating AST + * @brief Functions for recursive construction of the AST representing grammar rules + */ + + /** + * @ingroup ParsingFunctions + * @brief Token reading and verification function + * + */ + void ackToken(gRegexTokenTypes tokenType); + + //TODO TODO + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for key : KEY(QOM | QZM)? | CKEY + * @return AST node + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> constructAstExp(void); + + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for sequence : seq :exp (NEXT seq)* + * @return AST node + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> constructAstSeq(void); + + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for domain : (seq NEXT domain)? | LPAREN domain RPAREN (QOM | QZM) (NEXT domain)? + * @return AST node + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> constructAstDomain(void); + + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for multiple exepresion : allExpr: domain (SEP allExpr)* + * @return AST node + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> constructAstAllExpr(void); + + + /** + * @brief The actual token in the parce + */ + std::shared_ptr<ParsingToken<gRegexTokenTypes>> mCurrentToken; + + /** + * @brief The lexem use + */ + GraphLexer mLexer; + +}; + + +} + +#endif //AIDGE_CORE_GRAPH_PARSER_H_ diff --git a/include/aidge/graphRegex/GraphRegexTypes.hpp b/include/aidge/graphRegex/GraphRegexTypes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9e35f8f027eb363b71358922ffe0caa4a55fff1d --- /dev/null +++ b/include/aidge/graphRegex/GraphRegexTypes.hpp @@ -0,0 +1,29 @@ + +#ifndef AIDGE_CORE_GREGEX_TOKEN_TYPES_H_ +#define AIDGE_CORE_GREGEX_TOKEN_TYPES_H_ + + +namespace Aidge { + /** + * @brief enum for all types of token use in the of the regex + * 7-5 type + * 4-0 id + */ + enum class gRegexTokenTypes + { + STOP, + NEXT, /**< -> */ + + QOM, /**< + */ + QZM, /**< * */ + + KEY, /**< [A-Za-z_0-9]+ */ + CKEY, /**< [A-Za-z_0-9]+#[0-9]* */ + + SEP, /**< \( */ + LPAREN, /**< \( */ + RPAREN, /**< \) */ + }; + +} +#endif //AIDGE_CORE_GREGEX_TOKEN_TYPES_H_ diff --git a/include/aidge/graphRegex/GraphStrInterpreter.hpp b/include/aidge/graphRegex/GraphStrInterpreter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..98dca0e9f84de0be2614aed0e47c9d86ae552674 --- /dev/null +++ b/include/aidge/graphRegex/GraphStrInterpreter.hpp @@ -0,0 +1,40 @@ +#ifndef AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ +#define AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ + +#include <iostream> +#include <sstream> +#include <memory> +#include <algorithm> + +#include "aidge/utilsParsing/AstNode.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" +#include "aidge/graphRegex/GraphParser.hpp" +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" + +namespace Aidge { + + class GraphStrInterpreter + { + private: + /* data */ + GraphParser mParser; + std::string mToTest; + public: + GraphStrInterpreter(const std::string graphMatchExpr); + virtual ~GraphStrInterpreter() =default; + + + std::string interpret(void); + + private: + + + std::string visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree); + }; + + + +} + + +#endif //AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ diff --git a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c3eae528808dbdb8023718c961b7c45cbf4afac9 --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp @@ -0,0 +1,228 @@ +#ifndef AIDGE_CORE_FSM_EDGE_H_ +#define AIDGE_CORE_FSM_EDGE_H_ + +#include <memory> +#include <set> +#include <string> + +#include "aidge/nodeTester/ConditionalInterpreter.hpp" + + +namespace Aidge{ + + class FsmNode; + class FsmRunTimeContext; + + struct EdgeTestResult { + bool success; + std::set<NodePtr> node; + }; + + /** + * @brief virtual class use test the node on the node to validate + */ + class FsmEdge: public std::enable_shared_from_this<FsmEdge> + { + private: + + /** + * @brief the relative position to this test relative to all the const key + * first is common id, second is the relative position + */ + std::map<size_t,int> mRelativePos; + /** + * @brief the ptr on the source node + */ + std::shared_ptr<FsmNode> mNodeSource; + /** + * @brief the ptr on the dest node + */ + std::shared_ptr<FsmNode> mNodeDest; + /** + * @brief the weak ptr + */ + std::weak_ptr<FsmEdge> weakPtr; + + public: + FsmEdge(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest); + + virtual ~FsmEdge(){}; + + FsmEdge() : weakPtr(shared_from_this()) {} + + + /** + * @brief test is the validation of the node, it must be defined for all types of edge + * it takes as argument an FSM traversal context and returns a set of next nodes + * @return set of next node or nullptr if not next + */ + + virtual const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) =0; + + /** + * @brief test is the egde test a common node + * @return true if is a common + */ + virtual bool isCommon(void); + /** + * @brief get the Common idx of the common test in this edge (if is a common edge) + * @return idx of the common + */ + virtual size_t getCommonIdx(void); + /** + * @brief get the relative postion to the common node deffine in this edge + * @return map + */ + const std::map<size_t,int>& getRelative(void); + /** + * @brief add new relative position + */ + void updateRelative( const std::map<size_t,int>& relativePos ); + /** + * @brief get source FsmNode + * @return FsmNode + */ + std::shared_ptr<FsmNode> getSourceNode(void); + /** + * @brief set a new source to the edge + * @return FsmNode + */ + void reSetSouceNode(const std::shared_ptr<FsmNode>& newSource); + /** + * @brief get dest FsmNode + * @return FsmNode + */ + std::shared_ptr<FsmNode> getDestNode(void); + /** + * @brief set a new dest to the edge + * @return FsmNode + */ + void reSetDestNode(const std::shared_ptr<FsmNode>& newDest); + /** + * @brief propagate the edge mRelativePos to the others Edge and recalcul the relative position + */ + void propagateRelativePos(void); + + /** + * @brief test to make on the node to validate + * @see ConditionalInterpreter + */ + const std::shared_ptr<ConditionalInterpreter> mToTest; + + /** + * @brief update week ptr for the node, TODO best + */ + void updateWeak(void); + }; + + /** + * @brief class spesialisation for not commun node (node that must be match one Unique) transition + */ + class FsmEdgeUnique:public FsmEdge + { + + public: + FsmEdgeUnique(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest); + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; + }; + + /** + * @brief class spesialisation for commun node transition + * @see FsmEdge + */ + class FsmEdgeCommon:public FsmEdge + { + + private: + /** + * @brief the map that defind the ralation between the commonKey find by the lexer and a unique id use to refer to the common node + */ + static std::map<std::string,int> mCommonIdxMap; + /** + * @brief the common id test in this transition + */ + int mCommonIdx; + public: + + /** + * @brief constructor commun node , + * @details during construction, + * the node key found by the lexer is converted to a unique id and the relative positions are updated. + */ + FsmEdgeCommon(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey); + // ~FsmEdgeCommon() override {} + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; + bool isCommon(void) override; + + }; + + + + /** + * @brief class spesialisation for ref transition + * @see FsmEdge + */ + class FsmEdgeRef:public FsmEdge + { + private: + /** + * @brief the id of one common node that we use as an anchor + */ + const int mRefCommonIdx; + /** + * @brief the delta in terme of child or parent refer to the anchor + */ + const int mdeltaCommonIdx; + public: + FsmEdgeRef(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const size_t refCommonIdx,const int deltaCommonIdx); + //~FsmEdgeRef() override {} + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; + + }; + + /** + * @brief class spesialisation for ref empty transition + * @see FsmEdge + */ + class FsmEdgeEmpty:public FsmEdge + { + + public: + FsmEdgeEmpty(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest); + //~FsmEdgeEmpty() override {} + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; + + }; + + + +//////////////////////// +// FACTORY +//////////////////////// + +enum class FsmEdgeTypes { + EMPTY = 0, + REF, + COMMON, + UNIQUE +}; + + +class FsmEdgeFactory { + public: + /** + * @brief factory for making edge and read the info in the lexeme of the token + * @param source source node of the edge + * @param dest Dest node of the edge + * @param type type of the edge + * @param lexeme the additional information to build the edge + * @return s prt of the edge + */ + static std::shared_ptr<FsmEdge> make(std::shared_ptr<FsmNode> source, std::shared_ptr<FsmNode> dest, + FsmEdgeTypes type,std::map<std::string, std::shared_ptr<ConditionalInterpreter>> allTest, + const std::string lexeme = ""); + }; + +} + +#endif //AIDGE_CORE_FSM_EDGE_H_ diff --git a/include/aidge/graphRegex/matchFsm/FsmGraph.hpp b/include/aidge/graphRegex/matchFsm/FsmGraph.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0a74551367dd492cb0abb820e4c5ce5a601d071e --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/FsmGraph.hpp @@ -0,0 +1,98 @@ + +#ifndef AIDGE_CORE_FSM_GRAPH_H_ +#define AIDGE_CORE_FSM_GRAPH_H_ + +#include <set> +#include <vector> +#include <memory> +#include <stdexcept> //error + +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" +#include "aidge/graphRegex/matchFsm/FsmEdge.hpp" +#include "aidge/graphRegex/matchFsm/MatchResult.hpp" +namespace Aidge{ + + + +class FsmGraph +{ +private: + /** + * @brief all node origine + */ + std::set<std::size_t> mAllOrigine; + std::set<std::shared_ptr<FsmEdge>> mEdges; +public: + FsmGraph(/* args */); + virtual ~FsmGraph() = default; + +std::shared_ptr<MatchResult> test(std::vector<NodePtr>& StartNodes); + + + +const std::set<std::shared_ptr<FsmEdge>>& getEdge(void); +/** + * @brief add edge in the graph, as FsmEdge know the source and dest FsmNode these nodes are also add to the graph +*/ +void addEdge(std::shared_ptr<FsmEdge>& edge); + +/** + * @brief get the liste of the starting states + * @details we need to use a vector because the order of the nodes is important for start node initialization \ref test() +*/ +const std::vector<std::shared_ptr<FsmNode>> getStartNodes(void); + +/** + * @brief get the set of the valide states + * @return set of valide state +*/ +const std::set<std::shared_ptr<FsmNode>> getValidNodes(void); + +/** + * @brief get the set of all the node in the graph + * @return set of all nodes +*/ +const std::set<std::shared_ptr<FsmNode>> getNodes(void); + +/** + * @brief set a groupe idx for all the nodes in the graph +*/ +void setGroupe(std::size_t groupeIdx); + +/** + * @brief make the union beteen this graph and an input graph + * @param fsmGraph graph to union +*/ +void unionG(const std::shared_ptr<FsmGraph> fsmGraph); + + +/** + * @brief make the union beteen this graph and an input graph and merge the valide state to the start state + * @param fsmGraph graph to merge +*/ +void mergeOneStartOneValid(const std::shared_ptr< FsmGraph> fsmGraph); +/** + * @brief get the number of sub FSM + * @return number of sub Fsm +*/ +std::size_t getNbSubFsm(void); + +/** + * @brief increment the origine of all node in the graph + * @param incr the incrémentation value +*/ +void incOrigineAllNodeBy(std::size_t incr); + +private: + +/** + * @brief merge tow node of the graph + * @param node +*/ +void _mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest); + +}; + + +} +#endif //AIDGE_CORE_FSM_GRAPH_H_ diff --git a/include/aidge/graphRegex/matchFsm/FsmNode.hpp b/include/aidge/graphRegex/matchFsm/FsmNode.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2776ff8eb297fd5ad9a4c425fb386adde0a25269 --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/FsmNode.hpp @@ -0,0 +1,99 @@ +#ifndef AIDGE_CORE_FSM_NODE_H_ +#define AIDGE_CORE_FSM_NODE_H_ + +#include <set> +#include <vector> +#include <memory> + +//#include "graphRegex/matchFsm/FsmEdge.hpp" +//#include "graphRegex/matchFsm/FsmRunTimeContext.hpp" + +namespace Aidge{ + // Forward declaration of the class defined in graphRegex/matchFsm/FsmEdge.hpp + class FsmEdge; + struct EdgeTestResult; + class FsmRunTimeContext; + + + //------------------------------------------------------------------------------ + + // MAY BE IN UTILE + template <typename T> + struct lex_compare { + bool operator() (const std::weak_ptr<T> &lhs, const std::weak_ptr<T> &rhs)const { + auto lptr = lhs.lock(), rptr = rhs.lock(); + if (!rptr) return false; // nothing after expired pointer + if (!lptr) return true; + return lptr < rptr; + } + }; + + /** + * @brief is a node in the FSM graph, it's a state in the FSM + * @details a state can be and/or : + * - a valide state, the match is valide if it stop on this edge + * - a start state , the match start on this state + * The state is also define by this origine (is the unique id of it's expretion ) + * and it's groupe (for inner expression TODO) + */ + class FsmNode : public std::enable_shared_from_this<FsmNode> + { + private: + /** + * @brief the edge of the node + * @details the edge have a shared ref to the node so we use weak ref + */ + std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>> mEdges; + /** + * @brief the parent of the node + */ + std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>> mParents; + + std::size_t mOrigineStm = 0; + std::size_t mGroupeStm = 0; + + bool mIsAValid; + bool mIsAStart; + + public: + FsmNode(bool isAValid,bool isAStart ); + virtual ~FsmNode() = default; + /** + * @brief use to MAG the actual context , and return all the posible new context + * @details one input context can generate a multitude of contexts because a graph node + * can have more than one child, and each traversal possibility is a new context. + * @param actContext the actual context + * @return A vector of all the new context + */ + const std::vector<std::shared_ptr<FsmRunTimeContext>> test( std::shared_ptr<FsmRunTimeContext>); + + + std::size_t getOrigine(void); + void incOrigine(std::size_t inc); + + + void rmEdge(std::shared_ptr<FsmEdge>); + void addEdge(std::shared_ptr<FsmEdge>); + + //const std::set<std::shared_ptr<FsmNode>> getChildNodes(void); + + const std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>>& getParentNodes(void); + const std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>>& getEdges(void); + + void setGroupe(std::size_t groupeIdx); + + bool isValid(void); + bool isStart(void); + void unValid(void); + void valid(void); + void unStart(void); + void start(void); + + + + void addParent(std::shared_ptr<FsmNode>); + void rmParent(std::shared_ptr<FsmNode>); + }; + +} +#endif //AIDGE_CORE_FSM_NODE_H_ diff --git a/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6f1b9fc2bfe68195f67cfc0bf17d57aed5345219 --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp @@ -0,0 +1,173 @@ +#ifndef AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_ +#define AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_ + +#include <memory> +#include <vector> +#include <set> +#include <algorithm> + +#include "aidge/nodeTester/ConditionalInterpreter.hpp" +#include "aidge/graph/Node.hpp" + + + +namespace Aidge{ + + class FsmNode; + + class FsmNode; + + /** + * @brief a class used to save the execution context of state machines, that is the actual state in the FSM, the actual node in the graph + * all node that have been Validate,Rejecte or Considered common + */ + class FsmRunTimeContext + { + private: + /** + * @brief the list of node rejected for all the context + */ + static std::vector<std::set<NodePtr>> mRejectedNodes; + /** + * @brief the actual state of this Context (where it's in the FSM graph) + */ + std::shared_ptr<FsmNode> mActState; + /** + * @brief the actual node of this Context (where it's in the graph) + */ + NodePtr mActOpNode; + /** + * @brief the map of the node consider as common and the common ID + * @details we need to store what node it's consider as common because of the end + * resolution of the matching, all node consider as common need to be the same in all context + */ + std::map<NodePtr,std::size_t> mCommonNodes; + /** + * @brief the map of the node that as been valid in this context , and the test that valide the node + */ + std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> mValidNodes; + /** + * @brief the index in the rejected node of this context + */ + std::size_t mLocalIdxRejeced; + public: + /** + * @brief constructor + * @param actState the actual state in the FSM + * @param actOpNode the actual node in the graph + * @param idxRejeced the idx in the global regected node vector init max() as sentinel value of undefind + */ + FsmRunTimeContext(std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ,std::size_t idxRejeced =std::numeric_limits<std::size_t>::max() ); + FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime); + FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime,std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ); + + virtual ~FsmRunTimeContext()=default; + + /** + * @defgroup FsmRunTimeContextRejected Function for managing rejected nodes + */ + + /** + * @ingroup FsmRunTimeContextRejected + * @brief Add a node as rejected in this context + */ + void addRejectedNode(NodePtr node); + + /** + * @ingroup FsmRunTimeContextRejected + * @brief get the rejected nodes of this context + */ + std::set<NodePtr> getRejectedNodes(void); + + + /** + * @defgroup FsmRunTimeContextTest Function for test the context + */ + + /** + * @ingroup FsmRunTimeContextTest + * @brief test if the actual state is valide + * @return bool + */ + bool isOnValidState(void); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if the node is considered as common in this context + * @param node node to test + * @return bool + */ + bool isCommonDefined(NodePtr node); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if has already validated in this context + * @param node node to test + * @return bool + */ + bool isAlreadyValid(NodePtr node); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if this context is compatible with an others + * @details to say that two contexts are compatible is to check : + * that the contexts do not validate the same nodes (other than the common ones) + * and that the common ones have the same idx + * @param fsmContext the others context + * @return bool + */ + bool areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if this context is strictly equal with an others + * @param fsmContext the others context + * @return bool + */ + bool areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext); + + /** + * @defgroup FsmRunTimeContextSet Function set context + */ + + + void setCommon(NodePtr node,std::size_t commonIdx); + + + void setValid(NodePtr node,std::shared_ptr<ConditionalInterpreter> tag); + + /** + * @defgroup FsmRunTimeContextGet Function get context + */ + + + /** + * @ingroup FsmRunTimeContextGet + * @brief get the sub idx state + * @return bool + */ + std::size_t getSubStmId(void); + + NodePtr getCommonNodeFromIdx(std::size_t commonIdx); + std::size_t getCommonNodeIdx(NodePtr node); + std::set<NodePtr> getCommonNodes(void); + + std::map<NodePtr,std::size_t> getCommon(void); + std::set<NodePtr> getValidNodes(void); + + std::set<NodePtr> getValidNodesNoCommon(void); + std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> getValid(void); + + + NodePtr getActNode(void); + std::shared_ptr<FsmNode> getActState(void); + + + /** + * @defgroup FsmRunTimeContextMem + */ + + void rst(void); + + + }; + +} + +#endif //AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_ diff --git a/include/aidge/graphRegex/matchFsm/MatchResult.hpp b/include/aidge/graphRegex/matchFsm/MatchResult.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ac2f2a627a9d88b3cabeac4b181af2f3b7566d72 --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/MatchResult.hpp @@ -0,0 +1,60 @@ +#ifndef AIDGE_CORE_MATCH_RESULT_H_ +#define AIDGE_CORE_MATCH_RESULT_H_ + +#include <memory> +#include <vector> +#include <map> + + +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" +#include "aidge/graph/Node.hpp" + +namespace Aidge{ + +/** + * @brief class that old the result of a matching + * give acess to all node ant there tag in the expression +*/ +class MatchResult +{ +private: + /* data */ + std::vector<std::shared_ptr<FsmRunTimeContext>> mAllValid; + + /* + the Run time of eatch sub FSM , to have a valide match we need a set of one run time per FSM compatible + the id must be contigue + */ + std::vector<std::vector<std::shared_ptr<FsmRunTimeContext>>> mIdToRunTime; + + std::vector<std::set<NodePtr>> mSolve; + + std::size_t mNbSubStm; + +public: + MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm); + + virtual ~MatchResult() = default; + + /** + * @brief get the set of the node match for une expression + * @return the set of node of the graph that corresponding to an expression + */ + std::set<NodePtr> getBiggerSolution(void); + +private: + +/** + * @brief recurent function use to inite mSolve in the constructor + * + **/ + +void _generateCombinationd( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence); + +}; + + +} + + +#endif //AIDGE_CORE_MATCH_RESULT_H_ diff --git a/include/aidge/hook/Hook.hpp b/include/aidge/hook/Hook.hpp index 28f7ef5cddbc649af50209ba77527b8b75d731b7..5e00db5d68f11aadd4f3b6eb8174ba61b33e4a49 100644 --- a/include/aidge/hook/Hook.hpp +++ b/include/aidge/hook/Hook.hpp @@ -31,11 +31,11 @@ protected: public: Hook(std::shared_ptr<Operator> op) : mOperator(op) {} - virtual ~Hook(); + virtual ~Hook() = default; virtual void call() = 0; }; } -#endif /* Hook_H_ */ \ No newline at end of file +#endif /* Hook_H_ */ diff --git a/include/aidge/nodeTester/ConditionalData.hpp b/include/aidge/nodeTester/ConditionalData.hpp new file mode 100644 index 0000000000000000000000000000000000000000..12df32a728571678a3885f9981e526e1d73db785 --- /dev/null +++ b/include/aidge/nodeTester/ConditionalData.hpp @@ -0,0 +1,98 @@ + +#ifndef AIDGE_CORE_CONDITIONAL_DATA_H_ +#define AIDGE_CORE_CONDITIONAL_DATA_H_ + +#include <vector> +#include <string> +#include <stdexcept> //error +#include <memory> +#include <map> +namespace Aidge{ + + + +///////////////////////// +// The data type in AST Intepretation +//////////////////////// + +class BaseConditionalValue { +public: + virtual ~BaseConditionalValue() {} +}; + +template <typename T> +class ConditionalValue : public BaseConditionalValue { +public: + ConditionalValue(const T& data) : value(data) {} + T value; +}; + + +struct ConditionalData { + /** + * @brief generic type to propagate all the different values in the AST interpretation + */ + //void* value; + std::unique_ptr<BaseConditionalValue> value; + const std::type_info* type =nullptr; + + ///////////////////////////////// + // + //////////////////////////////// + /** + * @brief set a value + */ + template <typename T> + void setValue(const T& newValue) { + //make sure that the old value is free + deleteValue(); + value = std::make_unique<ConditionalValue<T>>(newValue); + type = &typeid(T); + } + + /** + * @brief get the actual value + * @details recaste the value to the templaited type and checks that the conversion type is compatible with type + * @tparam the type of the return value + * @return the value + */ + template <typename T> + T getValue() const { + if (type && *type == typeid(T)) { + //const Value<T>* typedValue = dynamic_cast<const Value<T>*>(static_cast<const BaseValue*>(value)); + const ConditionalValue<T>* typedValue = dynamic_cast<const ConditionalValue<T>*>(value.get()); + if (typedValue) { + return typedValue->value; + } + } + throw std::runtime_error(std::string("DATA ERROR ") + type->name() + " != " + typeid(T).name()); + } + /////////////////////////////////// + // + /////////////////////////////////// + std::string getType() const { + return type ? type->name() : "nullptr"; + } + + + template <typename T> + bool isTypeEqualTo() const { + return (type && *type == typeid(T)); + } + + void deleteValue() { + if (type) { + value.reset(); + type = nullptr; + } + } + + ~ConditionalData() { // TODO best can we have a list of type supported ? + deleteValue(); + } +}; + +} + + +#endif //AIDGE_CORE_CONDITIONAL_DATA_H_ diff --git a/include/aidge/nodeTester/ConditionalInterpreter.hpp b/include/aidge/nodeTester/ConditionalInterpreter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..165fac1c2ae98bf76b73c039de9fc975e9845cc9 --- /dev/null +++ b/include/aidge/nodeTester/ConditionalInterpreter.hpp @@ -0,0 +1,339 @@ + + +#ifndef AIDGE_CORE_CONDITIONAL_INTERPRETER_H_ +#define AIDGE_CORE_CONDITIONAL_INTERPRETER_H_ + +#include "aidge/nodeTester/ConditionalParser.hpp" +#include "aidge/nodeTester/ConditionalData.hpp" + +#include <memory> // for shared_ptr +#include <unordered_map> +#include <functional> +#include "aidge/graph/Node.hpp" +#include <sstream> + + +namespace Aidge{ + + + +////////////////////////////// +// +///////////////////////////// +/** + * @brief class used to register any lambda function without context, + * it encapsulates the source lambda in a lambda which takes as argument ConditionalData* which are any type. + * @see ConditionalData + */ +class ConditionalRegisterFunction { + ////////////////////////// + //Safe recaste + ////////////////////////// + + /** + * @brief recast the ConditionalData* to the argument type of the lambda + * @tparam T type of the lambda argument + * @see ConditionalData + */ + template <typename T> + T safeCastInput(ConditionalData* data) { + //cnvertion and type cheking + if (data->isTypeEqualTo<T>()){ + return data->getValue<T>(); + }else{ + throw std::invalid_argument( "incompatible input type " + data->getType() +" "+ typeid(T).name() ); + } + + } + + + /** + * @brief recaste the output of the lambda to a ConditionalData* + * @tparam T type of the lambda return + * @see ConditionalData + */ + template <typename T> + ConditionalData* safeCastOutput(T data) { + + ConditionalData* out = new ConditionalData; + out->setValue<T>(data); + + return out; + } + + + + + ////////////////////// + // get all the type of the function + ////////////////////// + + /** + * @brief Retrieves information about a function's return type and argument types. + * @tparam T The function type. + */ + template <typename T> + struct function_traits; + + + /** + * @brief Specialization of function_traits for function pointers. + * @tparam R The return type of the function. + * @tparam Args The argument types of the function. + */ + template <typename R, typename... Args> + struct function_traits<R (*)(Args...)> { + using return_type = R; + static constexpr std::size_t arity = sizeof...(Args); + + template <std::size_t N> + struct argument { + static_assert(N < arity, "Index out of range."); + using type = typename std::tuple_element<N, std::tuple<Args...>>::type; + }; + }; + + /** + * @brief Specialization of function_traits for std::function types. + * @tparam R The return type of the function. + * @tparam Args The argument types of the function. + */ + template <typename R, typename... Args> + struct function_traits<std::function<R(Args...)>> { + using return_type = R; + static constexpr std::size_t arity = sizeof...(Args); + + template <std::size_t N> + struct argument { + static_assert(N < arity, "Index out of range."); + using type = typename std::tuple_element<N, std::tuple<Args...>>::type; + }; + }; + + ///////////////////// + //change the function to ConditionalData*(std::vector<ConditionalData*>) + ///////////////////// + + /** + * @brief Converts a function to a ConditionalData*(std::vector<ConditionalData*>). + * @tparam F The type of the function to convert. + * @tparam ParamsIdx The indices of the function parameters. + * @param f The function to convert. + * @return The pointer to the converted function. + */ + template <class F, std::size_t... ParamsIdx> + auto funcPointer(F f, std::index_sequence<ParamsIdx...>) { + //wrapp the lambda in a new one that as ConditionalData as inputs and output + return [this,f](std::vector<ConditionalData*> &args) { + if (args.size() != sizeof...(ParamsIdx)){ + std::ostringstream errorMessage; + errorMessage << "bad Number of argument: get " << args.size() << " need " << sizeof...(ParamsIdx) << "\n"; + throw std::runtime_error(errorMessage.str()); + } + //assert(args.size() == sizeof...(ParamsIdx));//the size of the vector valide + + using FuncTraits = function_traits<decltype(f)>; + using outType = typename FuncTraits::return_type; + + outType result = f(safeCastInput<typename FuncTraits::template argument<ParamsIdx>::type>(args[ParamsIdx])...); + //typename + return safeCastOutput<outType>(result); + }; + } + + /** + * @brief Converts a function pointer to a ConditionalData*(std::vector<ConditionalData*>). + * @tparam R The return type of the function. + * @tparam Params The parameter types of the function. + * @param f The function pointer to convert. + * @return The pointer to the converted function. + */ + template <class R,class... Params> + auto funcPointer(R (*f)(Params...)) { + return funcPointer(f, std::index_sequence_for<Params...>{}); + } + + /** + * @brief Converts a std::function to a ConditionalData*(std::vector<ConditionalData*>). + * @tparam R The return type of the function. + * @tparam Params The parameter types of the function. + * @param f The function pointer to convert. + * @return The pointer to the converted function. + */ + template <class R,class... Params> + auto funcPointer(std::function<R(Params...)> f) { + return funcPointer(f, std::index_sequence_for<Params...>{}); + } + + + /////////////////// + // interface + /////////////////// + + public: + + /** + * @brief Default constructor + */ + ConditionalRegisterFunction(){} + + + /** + * @brief Inserts a function into the map with the provided key. + * @tparam T The function type. + * @param key The key to associate with the function. + * @param f The function to insert. + */ + template <class T> + void insert(const std::string key,T f){ + mWlambda.insert({ key, funcPointer(f)}); + } + + + /** + * @brief Runs the function associated with the given key, using the provided vector of input data. + * @param key The key of the function to run. + * @param datas The vector of input data. + * @return A pointer to the output ConditionalData object. + */ + ConditionalData* run(const std::string key,std::vector<ConditionalData*> & datas); + + private: + /// @brief map of name and the converted function. + std::map<const std::string, std::function<ConditionalData*(std::vector<ConditionalData*> &)>> mWlambda; +}; + +/////////////////// +//AST tree node +// //////////////// +/** + * @brief this class interprets AST to generate a test on a graph node. For each AST node, + * it generates an interpretation and registers lambda functions that can be used in the test expression. + * there are two lambda control mechanisms: + * - A cpp mechanism which allows any lambda to be inserted into the constructor that use templaite + * - A user mechanism limited to lambda bool(NodePtr) + * @see ConditionalParser use to get the AST + */ +class ConditionalInterpreter +{ + private: + + /** + * @brief the AST generate by the Parser + * @see ConditionalParser + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> mTree; + /** + * @brief the registery for the lambda fuction + * @see ConditionalRegisterFunction + */ + ConditionalRegisterFunction mLambdaRegiter; + + + std::vector<ConditionalData*> mResolution ; + + void clearRes(){ + + for (std::size_t i = 0; i < mResolution.size(); ++i) { + delete mResolution[i]; + } + mResolution.clear(); + } + + public: + /** + * @brief Constructor + * @param ConditionalExpressions The expression of the test to be performed on the nodes + */ + + ConditionalInterpreter(const std::string ConditionalExpressions); + + ~ConditionalInterpreter(){clearRes();} + + /** + * @brief Test a node depending of the ConditionalExpressions + * @details the AST is visit using \ref visit() whith the $ init whit the nodeOp + * @return bool the match node has the initialized expresion + * @see visit() This function uses the visit() function to perform the evaluation. + */ + bool test( const NodePtr nodeOp); + + /** + * @brief Interface for inserting custom lambda bool(NodePtr) functions in AST interpretation, + * it will be available in the ConditionalExpressions expretion as : key($) + * @param key The key that will be used to call the function in the expression + * @param f The pointer to function + */ + void insertLambda(const std::string key,std::function<bool(Aidge::NodePtr)> f); + + + ///// + + private: + /** + * @brief Recursive AST traversal function, using the for interpreting AST nodes function, + * using \ref ASTnodeInterpreterF fuctions + * @param NodeOp The node currently being tested + * @param nodes The AST given by the parsing process + */ + std::vector<ConditionalData*> visit(const ASTNodeCh& nodes, const NodePtr NodeOp ); + + /** + * @defgroup ASTnodeInterpreterF Functions for interpreting AST nodes + * @brief For each node type in the AST, function defines the processing to be performed + * they return a std::vector<ConditionalData*> which corresponds to the value(s) obtained + */ + + /** + * @ingroup ASTnodeInterpreterF + * @brief Function that does something. + */ + void fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); + /** + * @ingroup ASTnodeInterpreterF + * @brief Converted the lexeme to a int and to ConditionalData* + */ + void fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); + /** + * @ingroup ASTnodeInterpreterF + * @brief Converted the lexeme to a float and to ConditionalData* + */ + void fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); + /** + * @ingroup ASTnodeInterpreterF + * @brief Converted the lexeme to a str and to ConditionalData* + */ + void fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); + + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the == operation between two previously converted ConditionalData* + */ + void fEq(void); + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the != operation between two previously converted ConditionalData* + */ + void fNeq(void); + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the && operation between two previously converted ConditionalData* in bool + */ + void fAnd(void); + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the || operation between two previously converted ConditionalData* in bool + */ + void fOr(void); + + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the ! operation + */ + void fNot(void); +}; + + +} + +#endif //AIDGE_CORE_CONDITIONAL_INTERPRETER_H_ diff --git a/include/aidge/nodeTester/ConditionalLexer.hpp b/include/aidge/nodeTester/ConditionalLexer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fcfb9ebe783ac719076ce675e6fc3d78caf5be07 --- /dev/null +++ b/include/aidge/nodeTester/ConditionalLexer.hpp @@ -0,0 +1,88 @@ +/** + * @file + * @brief + * @version file 1.0.0 + * @author vl241552 + * @copyright + * Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. All + * rights reserved. + */ + + + +#ifndef AIDGE_CORE_CONDITIONAL_LEXER_H_ +#define AIDGE_CORE_CONDITIONAL_LEXER_H_ + +#include <string> +#include <regex> +#include <memory> // for shared_ptr + + +#include <stdexcept> //error +#include <sstream> + +#include "aidge/nodeTester/ConditionalTypes.hpp" +#include "aidge/utilsParsing/ParsingToken.hpp" + + +namespace Aidge{ + + + +class ConditionalLexer +{ + +public: +ConditionalLexer( const std::string ConditionalExpressions ); + +/** + * @brief Get the next token on the ConditionalExpressions + * @return ParsingToken<ConditionalTokenTypes> + */ +std::shared_ptr<ParsingToken<ConditionalTokenTypes>> getNextToken(void); +/** + * @brief Restart at the start of the ConditionalExpressions + * + */ +void rstPosition(void); + +/** + * @brief Test if the string is completely read + * @return bool + */ +bool isEnd(void); + + +/** + * @brief Get the representation of the class + * @return string + */ +const std::string rep(){ + return mConditionalExpressions; +} + +private: + +/** + * @brief Constructs an error message to display the character not understood by the lexer + * @return error mesage + */ +std::runtime_error badTokenError(const std::string& currentChars,std::size_t position); + +/** + * @brief The expression of the test to be performed on the nodes + */ +const std::string mConditionalExpressions; +/** + * @brief The lexer's current position in mConditionalExpressions + */ +std::size_t mPosition; + +}; + +///////////////////////////////////// + + +} + +#endif //AIDGE_CORE_CONDITIONAL_LEXER_H_ diff --git a/include/aidge/nodeTester/ConditionalParser.hpp b/include/aidge/nodeTester/ConditionalParser.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a99f5374182f57c0adca3b4d44691ff4e37de44d --- /dev/null +++ b/include/aidge/nodeTester/ConditionalParser.hpp @@ -0,0 +1,109 @@ + + + +#ifndef AIDGE_CORE_CONDITIONAL_PARSER_H_ +#define AIDGE_CORE_CONDITIONAL_PARSER_H_ + + +#include <memory> // for shared_ptr +#include <map> +#include <vector> + +#include "aidge/nodeTester/ConditionalLexer.hpp" +#include "aidge/nodeTester/ConditionalTypes.hpp" +#include "aidge/utilsParsing/ParsingToken.hpp" +#include "aidge/utilsParsing/AstNode.hpp" + +namespace Aidge{ + +const std::map<ConditionalTokenTypes, std::size_t> ConditionalPrec{ + {ConditionalTokenTypes::AND,2}, + {ConditionalTokenTypes::OR,1} +}; + + + + +using ASTNodeCh = std::vector<std::shared_ptr<AstNode<ConditionalTokenTypes>>>; + +/** + * @brief this class uses the lexer to create an AST according to a set of gramer rules + */ +class ConditionalParser{ + + public: + /** + * @brief AST graph creation function + * @param ConditionalExpressions String representing the logical fuction to be performed + */ + ConditionalParser(const std::string ConditionalExpressions); + + virtual ~ConditionalParser() = default; + /** + * @brief AST graph creation function + * @return The AST tree + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> parse(void); + + + private: + /** + * @brief restart at the start of the ConditionalExpressions for LEXER and restart mCurrentToken + */ + void rstParser(void); + + ////////////////// + + /** + * @defgroup ParsingFunctions Function for creating AST + * @brief Functions for recursive construction of the AST representing grammar rules + */ + + /** + * @ingroup ParsingFunctions + * @brief Token reading and verification function + * + */ + void ackToken(ConditionalTokenTypes tokenType); + + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for values : (KEY|INTEGER|FOAT|STRING|LAMBDA lambda) + * @return AST node + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> constructAstVal(void); + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for comparison : val (EQ|NEQ) val | LPAREN expr RPAREN + * @return AST node + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> constructAstCmpr(void); + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for arguments of a lambda : LAMBDA val (ARGSEP val)* RPAREN + * @return AST node + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> constructAstLambda(void); + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for a expresion : cmpr ((AND | OR) cmpr)* + * @return AST node + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> constructAstExpr(std::size_t precLimit = 0); + + + /** + * @brief The actual token in the parce + */ + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> mCurrentToken; + /** + * @brief The lexem use + */ + ConditionalLexer mLexer; + +}; + + +} + +#endif //AIDGE_CORE_CONDITIONAL_PARSER_H_ diff --git a/include/aidge/nodeTester/ConditionalTypes.hpp b/include/aidge/nodeTester/ConditionalTypes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6cb2edfd78e6b43c4f2dbc89c49cdaa9ea79f7d2 --- /dev/null +++ b/include/aidge/nodeTester/ConditionalTypes.hpp @@ -0,0 +1,36 @@ + + +#ifndef AIDGE_CORE_CONDITIONAL_TYPES_H_ +#define AIDGE_CORE_CONDITIONAL_TYPES_H_ +namespace Aidge{ + /** + * @brief enum for all types of token use in the parsing + * 7-5 type + * 4-0 id + */ + enum class ConditionalTokenTypes + { + STOP, + + NOT, /**< ! */ + AND, /**< && */ + OR, /**< || */ + + EQ, /**< == */ + NEQ, /**< != */ + + KEY, /**< [A-Za-z][A-Za-z0-9_]* */ + INTEGER, /**< [0-9]+ */ + FLOAT, /**< [0-9]+\.[0-9]* */ + STRING , /**< \'.*\' */ + BOOL, /**< true|false */ + NODE, /**< \$ */ + LAMBDA , /**< [A-Za-z][A-Za-z0-9_]*\( */ + + ARGSEP, /**< , */ + LPAREN, /**< \( */ + RPAREN, /**< \) */ + + }; +} +#endif // AIDGE_CORE_CONDITIONAL_TYPES_H_ diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index 1e0f17e6db9278e7edf2a11918472c084561a308..65c7e8ce0e47bd470e2a1499a682ed2f2c8c2dbc 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -81,14 +81,14 @@ public: // return *in; // } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInputs[0]->empty()) { const auto expectedDims = mInputs[0]->dims(); std::size_t nonEmptyInputTensor = 1; @@ -140,7 +140,7 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<Add_Op<NUM>>::create(name)(*this); mOutput->setBackend(name); @@ -150,7 +150,7 @@ public: } } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -162,6 +162,12 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return NUM; } inline IOIndex_t nbDataInputs() const noexcept override final { return NUM; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input_0", "data_input_n"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::size_t NUM> diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index b29463c675eb8516e02b83ad47816e9e9aa5d147..36de6c11a50692cc53ce9a70af4bef81ab0924bd 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -26,15 +26,14 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class AvgPoolingAttr { StrideDims, KernelDims, PaddingDims }; +enum class AvgPoolingAttr { StrideDims, KernelDims }; template <DimIdx_t DIM> class AvgPooling_Op : public Operator, public Registrable<AvgPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const AvgPooling_Op<DIM> &)>, public StaticAttributes<AvgPoolingAttr, std::array<DimSize_t, DIM>, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1) >> { + std::array<DimSize_t, DIM>> { private: // FIXME: change accessibility std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); @@ -47,18 +46,15 @@ public: using Attributes_ = StaticAttributes<AvgPoolingAttr, std::array<DimSize_t, DIM>, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1)> >; + std::array<DimSize_t, DIM>>; template <AvgPoolingAttr e> using attr = typename Attributes_::template attr<e>; constexpr AvgPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims, - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) : Operator(Type), Attributes_(attr<AvgPoolingAttr::StrideDims>(stride_dims), - attr<AvgPoolingAttr::KernelDims>(kernel_dims), - attr<AvgPoolingAttr::PaddingDims>(padding_dims)) { + attr<AvgPoolingAttr::KernelDims>(kernel_dims)) { setDatatype(DataType::Float32); } @@ -84,7 +80,7 @@ public: return std::make_shared<AvgPooling_Op<DIM>>(*this); } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 1 && "operators supports only 3 inputs"); (void) inputIdx; // avoid unused warning assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); @@ -92,16 +88,14 @@ public: mInput = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInput->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; for (std::size_t dim = 0; dim < this->template getAttr<AvgPoolingAttr::KernelDims>().size() ; ++dim) { outputDims[dim+2] = 1 + static_cast<DimSize_t>( std::floor(static_cast<float>(mInput->dims()[dim+2] - - this->template getAttr<AvgPoolingAttr::KernelDims>()[dim] + - this->template getAttr<AvgPoolingAttr::PaddingDims>()[dim] + - this->template getAttr<AvgPoolingAttr::PaddingDims>()[dim+DIM]) / + this->template getAttr<AvgPoolingAttr::KernelDims>()[dim]) / static_cast<float>(this->template getAttr<AvgPoolingAttr::StrideDims>()[dim]))); } outputDims[1] = mInput->dims()[1]; @@ -145,7 +139,7 @@ public: } - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -153,7 +147,7 @@ public: mInput->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -163,16 +157,21 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> AvgPooling(const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) { + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) { // FIXME: properly handle default w&b initialization in every cases static_assert(DIM<=MaxDim,"Too many kernel dimensions required by AvgPooling, not supported"); - auto avgPool = std::make_shared<Node>(std::make_shared<AvgPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, padding_dims), name); + auto avgPool = std::make_shared<Node>(std::make_shared<AvgPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims), name); return avgPool; } @@ -180,17 +179,16 @@ template <DimSize_t DIM> inline std::shared_ptr<Node> AvgPooling( DimSize_t const (&kernel_dims)[DIM], const std::string& name = "", - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) { + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by AvgPooling, not supported"); - return AvgPooling(to_array(kernel_dims), name, stride_dims, padding_dims); + return AvgPooling(to_array(kernel_dims), name, stride_dims); } } // namespace Aidge namespace { template <> const char *const EnumStrings<Aidge::AvgPoolingAttr>::data[] = {"StrideDims", - "KernelDims", "PaddingDims"}; + "KernelDims"}; } #endif /* AIDGE_CORE_OPERATOR_AVGPOOLING_H_ */ diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index 90a6be7222ee1b3e377520f2bc612a72c2ba4ab3..da7360c8ba3816cdfe1d2d00f80b08808a80f961 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -87,14 +87,14 @@ public: // return *in; // } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 5 && "operators supports only 5 inputs"); assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInputs[0]->empty()) { for (std::size_t i = nbDataInputs(); i < nbInputs(); ++i) { if(mInputs[i]->size() != mInputs[0]->dims()[1]) { @@ -136,7 +136,7 @@ public: } - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<BatchNorm_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -147,7 +147,7 @@ public: mInputs[4]->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -160,6 +160,12 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 5; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "scale", "shift", "mean", "variance"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <DimSize_t DIM> diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index 22553080c6d4d8359149b3b34c5d040e5e900c4d..5e6374c488a34fc8b29a5f841f42b8f44d2fc7a6 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -26,13 +26,13 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelDims, PaddingDims }; +enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelDims }; template <DimIdx_t DIM> class Conv_Op : public Operator, public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>, public StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, - DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >> { + DimSize_t, std::array<DimSize_t, DIM>> { public: // FIXME: change accessibility std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), @@ -45,7 +45,7 @@ public: Conv_Op() = delete; using Attributes_ = StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, - DimSize_t, DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >>; + DimSize_t, DimSize_t, std::array<DimSize_t, DIM>>; template <ConvAttr e> using attr = typename Attributes_::template attr<e>; @@ -53,15 +53,13 @@ public: DimSize_t out_channels, const std::array<DimSize_t, DIM> &kernel_dims, const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) : Operator(Type), Attributes_(attr<ConvAttr::StrideDims>(stride_dims), - attr<ConvAttr::DilationDims>(dilation_dims), - attr<ConvAttr::InChannels>(in_channels), - attr<ConvAttr::OutChannels>(out_channels), - attr<ConvAttr::KernelDims>(kernel_dims), - attr<ConvAttr::PaddingDims>(padding_dims)) { + attr<ConvAttr::DilationDims>(dilation_dims), + attr<ConvAttr::InChannels>(in_channels), + attr<ConvAttr::OutChannels>(out_channels), + attr<ConvAttr::KernelDims>(kernel_dims)) { setDatatype(DataType::Float32); } @@ -100,14 +98,14 @@ public: // } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInputs[0]->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; @@ -117,9 +115,7 @@ public: 1; outputDims[dim+2] = 1 + static_cast<DimSize_t>( - floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent + - this->template getAttr<ConvAttr::PaddingDims>()[dim] + - this->template getAttr<ConvAttr::PaddingDims>()[dim+DIM]) / + floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent) / static_cast<float>(this->template getAttr<ConvAttr::StrideDims>()[dim]))); } @@ -160,7 +156,7 @@ public: } - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -169,7 +165,7 @@ public: mInputs[2]->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -181,6 +177,12 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 3; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "weight", "bias"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::array<DimSize_t, 1>::size_type DIM> @@ -189,14 +191,13 @@ inline std::shared_ptr<Node> Conv(DimSize_t in_channels, const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { // FIXME: properly handle default w&b initialization in every cases static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not supported"); - auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, padding_dims, dilation_dims), name); + auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), name); // addProducer(conv, 1, append(append(kernel_dims, in_channels), out_channels), "w"); addProducer(conv, 1, append(out_channels, append(in_channels, kernel_dims)), "w"); - addProducer(conv, 2, {out_channels}, "b"); + addProducer(conv, 2, std::array<DimSize_t, 1>({out_channels}), "b"); return conv; } @@ -207,10 +208,9 @@ inline std::shared_ptr<Node> Conv( DimSize_t const (&kernel_dims)[DIM], const std::string& name = "", const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not supported"); - return Conv(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); + return Conv(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, dilation_dims); } } // namespace Aidge @@ -221,8 +221,7 @@ const char *const EnumStrings<Aidge::ConvAttr>::data[] = { "DilationDims", "InChannels", "OutChannels", - "KernelDims", - "PaddingDims" + "KernelDims" }; } diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index 7a4db68bae2f42eb892dd7240463e7363753b5a7..ec8ce2b3e1e2961658bd5fce7342fe5a31b7bb5b 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -26,7 +26,7 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class ConvDepthWiseAttr { StrideDims, DilationDims, Channels, KernelDims, PaddingDims }; +enum class ConvDepthWiseAttr { StrideDims, DilationDims, Channels, KernelDims }; template <DimIdx_t DIM> class ConvDepthWise_Op : public Operator, @@ -35,8 +35,7 @@ class ConvDepthWise_Op : public Operator, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1) >> { + std::array<DimSize_t, DIM>> { public: // FIXME: change accessibility std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), @@ -52,21 +51,18 @@ class ConvDepthWise_Op : public Operator, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1) >>; + std::array<DimSize_t, DIM>>; template <ConvDepthWiseAttr e> using attr = typename Attributes_::template attr<e>; constexpr ConvDepthWise_Op(const std::array<DimSize_t, DIM> &kernel_dims, const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) : Operator(Type), Attributes_(attr<ConvDepthWiseAttr::StrideDims>(stride_dims), - attr<ConvDepthWiseAttr::DilationDims>(dilation_dims), - attr<ConvDepthWiseAttr::Channels>(0), - attr<ConvDepthWiseAttr::KernelDims>(kernel_dims), - attr<ConvDepthWiseAttr::PaddingDims>(padding_dims)) { + attr<ConvDepthWiseAttr::DilationDims>(dilation_dims), + attr<ConvDepthWiseAttr::Channels>(0), + attr<ConvDepthWiseAttr::KernelDims>(kernel_dims)) { setDatatype(DataType::Float32); } @@ -92,14 +88,14 @@ class ConvDepthWise_Op : public Operator, return std::make_shared<ConvDepthWise_Op<DIM>>(*this); } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInputs[0]->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; @@ -109,9 +105,7 @@ class ConvDepthWise_Op : public Operator, 1; outputDims[dim+2] = 1 + static_cast<DimSize_t>( - floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent + - this->template getAttr<ConvDepthWiseAttr::PaddingDims>()[dim] + - this->template getAttr<ConvDepthWiseAttr::PaddingDims>()[dim+DIM]) / + floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent) / static_cast<float>(this->template getAttr<ConvDepthWiseAttr::StrideDims>()[dim]))); } this->template getAttr<ConvDepthWiseAttr::Channels>() = mInputs[0]->dims()[1]; @@ -161,7 +155,7 @@ class ConvDepthWise_Op : public Operator, - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<ConvDepthWise_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -170,7 +164,7 @@ class ConvDepthWise_Op : public Operator, mInputs[2]->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -182,17 +176,22 @@ class ConvDepthWise_Op : public Operator, inline IOIndex_t nbInputs() const noexcept override final { return 3; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "weight", "bias"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> ConvDepthWise(const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { // FIXME: properly handle default w&b initialization in every cases static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ConvDepthWise, not supported"); - auto convDW = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, padding_dims, dilation_dims), name); + auto convDW = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), name); addProducer(convDW, 1, std::array<DimSize_t,0>({}), "w"); addProducer(convDW, 2, std::array<DimSize_t,0>({}), "b"); return convDW; @@ -203,17 +202,16 @@ inline std::shared_ptr<Node> ConvDepthWise( DimSize_t const (&kernel_dims)[DIM], const std::string& name = "", const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ConvDepthWise, not supported"); - return ConvDepthWise(to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); + return ConvDepthWise(to_array(kernel_dims), name, stride_dims, dilation_dims); } } // namespace Aidge namespace { template <> const char *const EnumStrings<Aidge::ConvDepthWiseAttr>::data[] = {"StrideDims", "DilationDims", "Channels", - "KernelDims", "PaddingDims"}; + "KernelDims"}; } #endif /* AIDGE_CORE_OPERATOR_CONVDEPTHWISE_H_ */ diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 127d39a8bdfdd233cdac9e1ca6cf0bf85f656d16..b949527c51b9330077dd3bd8f8b4bf1f1b9d719c 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -135,7 +135,7 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<FC_Op>::create(name)(*this); mOutput->setBackend(name); @@ -145,7 +145,7 @@ public: mInputs[2]->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -158,13 +158,19 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 3; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "weight", "bias"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> FC(DimSize_t out_channels, bool noBias = false, const std::string& name = "") { // FIXME: properly handle default w&b initialization in every cases auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(out_channels, noBias), name); - addProducer(fc, 1, {out_channels, 1}, "w"); - addProducer(fc, 2, {(noBias ? 0 : out_channels)}, "b"); // already sets bias dims + addProducer(fc, 1, std::array<DimSize_t, 2>({out_channels, 1}), "w"); + addProducer(fc, 2, (noBias ? std::array<DimSize_t, 1>({0}) : std::array<DimSize_t, 1>({out_channels})), "b"); // already sets bias dims return fc; } } // namespace Aidge @@ -175,4 +181,4 @@ const char *const EnumStrings<Aidge::FCAttr>::data[] = {"OutChannels", "NoBias"}; } -#endif /* AIDGE_CORE_OPERATOR_FC_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_OPERATOR_FC_H_ */ diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 1e51866177acf80441f236070aea9dee6145bc19..83b9a932633deb822ad86c24b96e6e928b5e2be2 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -24,6 +24,7 @@ #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" + namespace Aidge { class GenericOperator_Op : public Operator, @@ -165,8 +166,8 @@ class GenericOperator_Op ~GenericOperator_Op() = default; - void setBackend(const std::string & /*name*/) { printf("setBackend: not available yet.\n"); } - void setDatatype(const DataType & /*datatype*/) { printf("setDatatype: not available yet.\n"); } + void setBackend(const std::string & /*name*/) override { printf("setBackend: not available yet.\n"); } + void setDatatype(const DataType & /*datatype*/) override { printf("setDatatype: not available yet.\n"); } void forward() override final { printf("forward: not available yet.\n"); } void backward() override final { printf("backward: not available yet.\n"); } diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp index c6ee01239e1ed065587276c1891d26ba3899fe89..40d9959b3802dcbe337adeafa9643bf0682df64b 100644 --- a/include/aidge/operator/LeakyReLU.hpp +++ b/include/aidge/operator/LeakyReLU.hpp @@ -120,14 +120,14 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<LeakyReLU_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround mInput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -137,6 +137,12 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> LeakyReLU(float negativeSlope = 0.0f, const std::string& name = "") { diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index d0dadd847a59c9d2a1c0dd97f2f200437da71863..eed1ec04535aa5896aa3d01a27d8023d37a42183 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -127,7 +127,7 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<MatMul_Op>::create(name)(*this); mOutput->setBackend(name); @@ -136,7 +136,7 @@ public: mInputs[1]->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -148,12 +148,18 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 2; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "weight"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> MatMul(DimSize_t out_channels, const std::string& name = "") { // FIXME: properly handle default w initialization in every cases auto matmul = std::make_shared<Node>(std::make_shared<MatMul_Op>(out_channels), name); - addProducer(matmul, 1, {out_channels, 1}, "w"); + addProducer(matmul, 1, std::array<DimSize_t, 2>({out_channels, 1}), "w"); return matmul; } } // namespace Aidge diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp index eae7e30df039c0514443e567032427f7a6556360..e261fb4b8c6d1448cca010f6aa11214bf6597f67 100644 --- a/include/aidge/operator/MaxPooling.hpp +++ b/include/aidge/operator/MaxPooling.hpp @@ -26,15 +26,14 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class MaxPoolingAttr { StrideDims, KernelDims, PaddingDims }; +enum class MaxPoolingAttr { StrideDims, KernelDims }; template <DimIdx_t DIM> class MaxPooling_Op : public Operator, public Registrable<MaxPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const MaxPooling_Op<DIM> &)>, public StaticAttributes<MaxPoolingAttr, std::array<DimSize_t, DIM>, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1) >> { + std::array<DimSize_t, DIM>> { private: // FIXME: change accessibility std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); @@ -47,18 +46,15 @@ public: using Attributes_ = StaticAttributes<MaxPoolingAttr, std::array<DimSize_t, DIM>, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1)> >; + std::array<DimSize_t, DIM>>; template <MaxPoolingAttr e> using attr = typename Attributes_::template attr<e>; constexpr MaxPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims, - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) : Operator(Type), Attributes_(attr<MaxPoolingAttr::StrideDims>(stride_dims), - attr<MaxPoolingAttr::KernelDims>(kernel_dims), - attr<MaxPoolingAttr::PaddingDims>(padding_dims)), + attr<MaxPoolingAttr::KernelDims>(kernel_dims)), mOutput(std::make_shared<Tensor>()) { setDatatype(DataType::Float32); } @@ -85,7 +81,7 @@ public: return std::make_shared<MaxPooling_Op<DIM>>(*this); } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 1 && "operators supports only 3 inputs"); (void) inputIdx; // avoid unused warning assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); @@ -93,16 +89,14 @@ public: mInput = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInput->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; for (std::size_t dim = 0; dim < this->template getAttr<MaxPoolingAttr::KernelDims>().size() ; ++dim) { outputDims[dim+2] = 1 + static_cast<DimSize_t>( std::floor(static_cast<float>(mInput->dims()[dim+2] - - this->template getAttr<MaxPoolingAttr::KernelDims>()[dim] + - this->template getAttr<MaxPoolingAttr::PaddingDims>()[dim] + - this->template getAttr<MaxPoolingAttr::PaddingDims>()[dim+DIM]) / + this->template getAttr<MaxPoolingAttr::KernelDims>()[dim]) / static_cast<float>(this->template getAttr<MaxPoolingAttr::StrideDims>()[dim]))); } outputDims[1] = mInput->dims()[1]; @@ -146,7 +140,7 @@ public: } - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<MaxPooling_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -154,7 +148,7 @@ public: mInput->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -164,16 +158,21 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> MaxPooling(const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) { + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) { // FIXME: properly handle default w&b initialization in every cases static_assert(DIM<=MaxDim,"Too many kernel dimensions required by MaxPooling, not supported"); - auto avgPool = std::make_shared<Node>(std::make_shared<MaxPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, padding_dims), name); + auto avgPool = std::make_shared<Node>(std::make_shared<MaxPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims), name); return avgPool; } @@ -181,16 +180,15 @@ template <DimSize_t DIM> inline std::shared_ptr<Node> MaxPooling( DimSize_t const (&kernel_dims)[DIM], const std::string& name = "", - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) { + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by MaxPooling, not supported"); - return MaxPooling(to_array(kernel_dims), name, stride_dims, padding_dims); + return MaxPooling(to_array(kernel_dims), name, stride_dims); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::MaxPoolingAttr>::data[] = {"StrideDims", "KernelDims", "PaddingDims"}; +const char *const EnumStrings<Aidge::MaxPoolingAttr>::data[] = {"StrideDims", "KernelDims"}; } #endif /* AIDGE_CORE_OPERATOR_MAXPOOLING_H_ */ diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 0c77a752493d251303c036c4061823c4f8bc499d..bb34fd9c7756f103d4f31f17f815309c925306b7 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -13,21 +13,38 @@ #define AIDGE_CORE_OPERATOR_METAOPERATOR_H_ #include "aidge/operator/Operator.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/scheduler/Scheduler.hpp" namespace Aidge { -class MetaOperator : public Operator { +class MetaOperator_Op : public Operator, + public Registrable<MetaOperator_Op, std::array<std::string, 2>, std::unique_ptr<OperatorImpl>(const MetaOperator_Op &)> { public: - MetaOperator() - : Operator("MetaOp") - { - } + std::vector<std::shared_ptr<Tensor>> mInputs; + std::vector<std::shared_ptr<Tensor>> mOutputs; // These are shared with micro-graph outputs tensors + + // Micro-graph handling: + std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph + std::shared_ptr<SequentialScheduler> mScheduler; + // Need to store an ordored list of input/output operators for the micro-graph, + // because input/output nodes in a GraphView are unordered. + // TODO: refactor GraphView to handle ordered input/output? + std::vector<std::pair<std::shared_ptr<Operator>, IOIndex_t>> mInputOps; + std::vector<std::pair<std::shared_ptr<Operator>, IOIndex_t>> mOutputOps; + + public: + MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, + std::vector<NodePtr> inputNodes = std::vector<NodePtr>(), + std::vector<NodePtr> outputNodes = std::vector<NodePtr>()); /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ - MetaOperator(const MetaOperator& op) - : Operator("MetaOp") + MetaOperator_Op(const MetaOperator_Op& op) + : Operator(op.type().c_str()), + mGraph(op.mGraph->clone()) { // cpy-ctor } @@ -37,11 +54,112 @@ public: * @see Operator::MatMul_Op */ std::shared_ptr<Operator> clone() const override { - return std::make_shared<MetaOperator>(*this); + return std::make_shared<MetaOperator_Op>(*this); + } + + const std::shared_ptr<GraphView>& getMicroGraph() const { + return mGraph; + } + + const std::shared_ptr<SequentialScheduler>& getMicroGraphScheduler() const { + return mScheduler; + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + + const auto& inputOp = mInputOps[inputIdx]; + inputOp.first->associateInput(inputOp.second, data); + + // Associate inputs for custom implementation + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + // Forward dims of micro-graph + mGraph->forwardDims(); + + // Associate outputs to micro-graph outputs for custom implementation + for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) { + const auto& outputOp = mOutputOps[outputIdx]; + mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second); + } + } + + bool outputDimsForwarded() const override final { return !(mOutputs[0]->empty()); } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(inputIdx < mInputs.size() && "inputIdx out of range"); + return *(mInputs[inputIdx].get()); + } + + inline Tensor& output(const IOIndex_t outputIdx) const override final { + assert(outputIdx < mOutputs.size() && "outputIdx out of range"); + return *(mOutputs[outputIdx].get()); + } + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < mInputs.size() && "inputIdx out of range"); + return mInputs[inputIdx]; + } + + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx < mOutputs.size() && "outputIdx out of range"); + return mOutputs[outputIdx]; + } + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < mInputs.size() && "inputIdx out of range"); + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx < mOutputs.size() && "outputIdx out of range"); + return std::static_pointer_cast<Data>(mOutputs[outputIdx]); + } + + void setBackend(const std::string &name) override { + if (Registrar<MetaOperator_Op>::exists({name, type()})) { + // A custom implementation exists for this meta operator + mImpl = Registrar<MetaOperator_Op>::create({name, type()})(*this); + } + + // The micro-graph should always be set to the right backend, since it + // shares input/output tensors. + // Input/output tensors backend are updated here. + mGraph->setBackend(name); + } + + void setDatatype(const DataType &datatype) override { + // The micro-graph should always be set to the right data type, since it + // shares input/output tensors. + // Input/output tensors data type are updated here. + mGraph->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return mGraph->inputs().size(); } + inline IOIndex_t nbDataInputs() const noexcept override final { return mGraph->dataInputs().size(); } + inline IOIndex_t nbOutputs() const noexcept override final { return mGraph->outputs().size(); } + + NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override; + NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override; + NbElts_t getNbProducedData(IOIndex_t outputIdx) const override; + + void updateConsummerProducer() override; + void forward() override; + void backward() override { + assert(false && "not implemented"); } - ~MetaOperator() = default; }; + +inline std::shared_ptr<Node> MetaOperator(const char *type, + const std::shared_ptr<GraphView>& graph, + const std::string& name = "") +{ + return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph), name); } +} // namespace Aidge #endif /* MetaOperator_H_ */ diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp new file mode 100644 index 0000000000000000000000000000000000000000..df66cec7e1accfee1518378ce2e9697cdc7f91fb --- /dev/null +++ b/include/aidge/operator/MetaOperatorDefs.hpp @@ -0,0 +1,124 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ +#define AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ + +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/AvgPooling.hpp" +#include "aidge/operator/MaxPooling.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/ConvDepthWise.hpp" +#include "aidge/operator/Pad.hpp" + +namespace Aidge { +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels, + DimSize_t out_channels, + const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0}, + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + // Construct micro-graph + auto pad = std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIM)>>(padding_dims, PadBorderType::Constant, 0.0), (!name.empty()) ? name + "_pad" : ""); + auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : ""); + // Need to specify the ordered list of input operators + const std::vector<NodePtr> orderedInputNodes = {pad, conv}; + + auto metaOp = std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedConv", Sequential({pad, conv}), orderedInputNodes), name); + addProducer(metaOp, 1, append(out_channels, append(in_channels, kernel_dims)), "w"); + addProducer(metaOp, 2, {out_channels}, "b"); + return metaOp; +} + +template <DimSize_t DIM> +inline std::shared_ptr<Node> PaddedConv( + DimSize_t in_channels, + DimSize_t out_channels, + DimSize_t const (&kernel_dims)[DIM], + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0}, + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + return PaddedConv<DIM>(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); +} + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedConvDepthWise(DimSize_t in_channels, + DimSize_t out_channels, + const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0}, + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + // Construct micro-graph + auto pad = std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIM)>>(padding_dims, PadBorderType::Constant, 0.0), (!name.empty()) ? name + "_pad" : ""); + auto conv = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : ""); + // Need to specify the ordered list of input operators + const std::vector<NodePtr> orderedInputNodes = {pad, conv}; + + auto metaOp = std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedConvDepthWise", Sequential({pad, conv}), orderedInputNodes), name); + addProducer(metaOp, 1, append(out_channels, append(in_channels, kernel_dims)), "w"); + addProducer(metaOp, 2, {out_channels}, "b"); + return metaOp; +} + +template <DimSize_t DIM> +inline std::shared_ptr<Node> PaddedConvDepthWise( + DimSize_t in_channels, + DimSize_t out_channels, + DimSize_t const (&kernel_dims)[DIM], + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0}, + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + return PaddedConvDepthWise<DIM>(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); +} + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedAvgPooling(DimSize_t in_channels, + DimSize_t out_channels, + const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0}) +{ + auto graph = Sequential({ + Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""), + AvgPooling_Op<DIM>(kernel_dims, (!name.empty()) ? name + "_avgpooling" : "", stride_dims) + }); + + return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedAvgPooling", graph), name); +} + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedMaxPooling(DimSize_t in_channels, + DimSize_t out_channels, + const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<std::array<DimSize_t, 2>, DIM> &padding_dims = {0}) +{ + auto graph = Sequential({ + Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""), + MaxPooling_Op<DIM>(kernel_dims, (!name.empty()) ? name + "_maxpooling" : "", stride_dims) + }); + + return std::make_shared<Node>(std::make_shared<MetaOperator_Op>("PaddedMaxPooling", graph), name); +} +} // namespace Aidge + +#endif /* AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ */ diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 5b0c199e75f0cedd4a0d36f6d2c87d89833e0dd5..a99e4e8ed37aeaa647da1dcaaa994b070901129b 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -81,7 +81,7 @@ public: * @param inputIdx Index of the input analysed. * @return NbElts_t */ - NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; + virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; /** * @brief Amount of data from a specific input actually used in one computation pass. @@ -89,7 +89,7 @@ public: * @param inputIdx Index of the input analysed. * @return NbElts_t */ - NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; + virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; /** * @brief Amount of data ready to be used on a specific output. @@ -97,9 +97,9 @@ public: * @param outputIdx Index of the output analysed. * @return NbElts_t */ - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; + virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; - void updateConsummerProducer(); + virtual void updateConsummerProducer(); virtual void forward(); @@ -116,6 +116,12 @@ public: virtual IOIndex_t nbInputs() const noexcept = 0; virtual IOIndex_t nbDataInputs() const noexcept = 0; virtual IOIndex_t nbOutputs() const noexcept = 0; + static const std::vector<std::string> getInputsName(){ + return {}; + } + static const std::vector<std::string> getOutputsName(){ + return {}; + } }; } // namespace Aidge diff --git a/include/aidge/operator/Pad.hpp b/include/aidge/operator/Pad.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ddc611a0fc6e3604ad6cb1949142d26625ae778a --- /dev/null +++ b/include/aidge/operator/Pad.hpp @@ -0,0 +1,239 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_PAD_H_ +#define AIDGE_CORE_OPERATOR_PAD_H_ + +#include <array> +#include <numeric> +#include <vector> +#include <cmath> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class PadAttr { BeginEndBorders, BorderType, BorderValue }; +enum class PadBorderType { Constant, Replicate, Reflect, Wrap }; + +template <DimIdx_t DIM> +class Pad_Op : public Operator, + public Registrable<Pad_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Pad_Op<DIM> &)>, + public StaticAttributes<PadAttr, + std::array<std::array<DimSize_t, 2>, DIM>, + PadBorderType, + double> { +private: + // FIXME: change accessibility + std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char *Type = "Pad"; + + Pad_Op() = delete; + + using Attributes_ = StaticAttributes<PadAttr, + std::array<std::array<DimSize_t, 2>, DIM>, + PadBorderType, + double>; + template <PadAttr e> + using attr = typename Attributes_::template attr<e>; + + constexpr Pad_Op(const std::array<std::array<DimSize_t, 2>, DIM> &beginEndTuples, + const PadBorderType &borderType = PadBorderType::Constant, + double borderValue = 0.0) + : Operator(Type), + Attributes_(attr<PadAttr::BeginEndBorders>(beginEndTuples), + attr<PadAttr::BorderType>(borderType), + attr<PadAttr::BorderValue>(borderValue)) { + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Pad_Op(const Pad_Op& op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Pad_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Pad_Op<DIM>>(*this); + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 1 && "operators supports only 3 inputs"); + (void) inputIdx; // avoid unused warning + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + + mInput = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInput->empty()) { + std::array<DimSize_t, DIM + 2> outputDims = {}; + + for (std::size_t dim = 0; dim < DIM; ++dim) { + outputDims[dim+2] = this->template getAttr<PadAttr::BeginEndBorders>()[dim][0] + + mInput->dims()[dim+2] + + this->template getAttr<PadAttr::BeginEndBorders>()[dim][1]; + } + outputDims[1] = mInput->dims()[1]; + outputDims[0] = mInput->dims()[0]; + mOutput->resize(outputDims); + } + } + + bool outputDimsForwarded() const override final { return !(mOutput->empty()); } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operators supports only 1 inputs"); + (void) inputIdx; // avoid unused warning + return *(mInput.get()); + } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "Pad Operators supports only 1 inputs"); + (void) inputIdx; // avoid unused warning + return mInput; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "Pad Operators has only 1 outputs"); + (void) outputIdx; // avoid unused warning + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operators supports only 1 inputs"); + (void) inputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mInput); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void) outputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string &name) override { + mImpl = Registrar<Pad_Op<DIM>>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInput->setBackend(name); + } + + void setDatatype(const DataType &datatype) override { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInput->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return 1; } + inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> Pad(const std::array<std::array<DimSize_t, 2>, DIM> &beginEndTuples, + const std::string& name = "", + const PadBorderType &borderType = PadBorderType::Constant, + double borderValue = 0.0) +{ + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Pad, not supported"); + auto pad = std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIM)>>(beginEndTuples, borderType, borderValue), name); + return pad; +} + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> Pad(const std::array<DimSize_t, DIM> &dimBeginEnd, + const std::string& name = "", + const PadBorderType &borderType = PadBorderType::Constant, + double borderValue = 0.0) +{ + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Pad, not supported"); + std::array<std::array<DimSize_t, 2>, DIM> beginEndTuples; + for (size_t i = 0; i < DIM; ++i) { + beginEndTuples[i] = {dimBeginEnd[i], dimBeginEnd[i]}; + } + auto pad = std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIM)>>(beginEndTuples, borderType, borderValue), name); + return pad; +} + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> ZeroPad(const std::array<std::array<DimSize_t, 2>, DIM> &beginEndTuples, + const std::string& name = "") +{ + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Pad, not supported"); + auto pad = std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIM)>>(beginEndTuples, PadBorderType::Constant, 0.0), name); + return pad; +} + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> ZeroPad(const std::array<DimSize_t, DIM> &dimBeginEnd, + const std::string& name = "") +{ + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Pad, not supported"); + std::array<std::array<DimSize_t, 2>, DIM> beginEndTuples; + for (size_t i = 0; i < DIM; ++i) { + beginEndTuples[i] = {dimBeginEnd[i], dimBeginEnd[i]}; + } + auto pad = std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIM)>>(beginEndTuples, PadBorderType::Constant, 0.0), name); + return pad; +} + +template <DimSize_t DIM> +inline std::shared_ptr<Node> Pad( + std::array<DimSize_t, 2> const (&beginEndTuples)[DIM], + const std::string& name = "", + const PadBorderType &borderType = PadBorderType::Constant, + double borderValue = 0.0) +{ + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Pad, not supported"); + return Pad(to_array(beginEndTuples), name, borderType, borderValue); +} +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::PadAttr>::data[] = {"BeginEndBorders", "BorderType", "BorderValue"}; + +template <> +const char *const EnumStrings<Aidge::PadBorderType>::data[] = {"Constant", "Replicate", "Reflect", "Wrap"}; +} + +#endif /* AIDGE_CORE_OPERATOR_PAD_H_ */ diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index 593192c9f402e2646ac94cff68aa0c805f5aecd1..529a37c063567e2e09367176437f212c69b2bf40 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -79,7 +79,7 @@ public: * @brief Set the Output Tensor of the Producer operator. * This method will create a copy of the Tensor. * - * @param newOutput Tensor containing the values to copy + * @param newOutput Tensor containing the values to copy */ void setOutputTensor(const Tensor& newOutput) { *mOutput = newOutput; @@ -121,17 +121,23 @@ public: inline const std::vector<DimSize_t> dims() const noexcept { return mOutput->dims(); } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<Producer_Op>::create(name)(*this); mOutput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); } inline IOIndex_t nbInputs() const noexcept override final { return 0; }; inline IOIndex_t nbDataInputs() const noexcept override final { return 0; }; inline IOIndex_t nbOutputs() const noexcept override final { return 1; }; + static const std::vector<std::string> getInputsName(){ + return {""}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } public: void forward() override final { diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp index 433e353f05f8b4ffc3cfc0e047464e7f9257da02..07f79f64933b3eec9e98a5ad13a4afbda9aed588 100644 --- a/include/aidge/operator/ReLU.hpp +++ b/include/aidge/operator/ReLU.hpp @@ -108,14 +108,14 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<ReLU_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround mInput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -125,6 +125,12 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> ReLU(const std::string& name = "") { diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index 522fc8ad01399391791336cfc5e09fa318f864ab..20d66500081a88cc05e7333217d72a494d53d648 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -130,13 +130,13 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<Scaling_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround mInput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -146,6 +146,12 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::string& name = "") { diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index 898bae4c31bb2c41947523a86bfb9cd5c7b732b4..e6c9869a50d1d142c20525ce2ccfc4f1de5088ed 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -108,14 +108,14 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<Softmax_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround mInput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -125,6 +125,12 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> Softmax(const std::string& name = "") { diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 9916ee2004bd1aa9f33acf96d95cae4703f692df..1896894ee8690cedaef696394da0829604e36211 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -89,11 +89,6 @@ private: * */ std::vector<std::shared_ptr<Node>> mStaticSchedule; - /** - * @brief Number of computation node (i.e: nb nodes != Producer) - * - */ - std::size_t mComputationNumber = 0; // TODO: Check if not inferable from mStaticSchedule }; } // namespace Aidge diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp index 60f586edf947cef0e139049814263a29b4d01e24..af03ee2861e81d81171ccc2ea14289f2ce3aa9e3 100644 --- a/include/aidge/utils/DynamicAttributes.hpp +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -194,7 +194,7 @@ public: * generic type caster for std::any is not feasable. * The strategy here is to keep a copy of each attribute in py::object that is updated everytime. */ - py::object getAttrPy(const std::string& name) const { + py::object getAttrPy(const std::string& name) const override final { return mAttrsPy.at(name); }; #endif diff --git a/include/aidge/utils/Utils.hpp b/include/aidge/utils/ErrorHandling.hpp similarity index 54% rename from include/aidge/utils/Utils.hpp rename to include/aidge/utils/ErrorHandling.hpp index 71817dcfc9713ad36a74175affd21b03cb6ed181..8fbeff30abecfec0077786b21825b6a6f36677c6 100644 --- a/include/aidge/utils/Utils.hpp +++ b/include/aidge/utils/ErrorHandling.hpp @@ -10,17 +10,21 @@ ********************************************************************************/ -#ifndef AIDGE_UTILS_H_ -#define AIDGE_UTILS_H_ +#ifndef AIDGE_ERRORHANDLING_H_ +#define AIDGE_ERRORHANDLING_H_ #include <cstdio> #include <memory> -#ifdef NO_EXCEPTIONS +#define AIDGE_STRINGIZE_DETAIL(x) #x +#define AIDGE_STRINGIZE(x) AIDGE_STRINGIZE_DETAIL(x) + +#ifdef NO_EXCEPTION #define AIDGE_THROW_OR_ABORT(ex, ...) \ do { std::printf(__VA_ARGS__); std::abort(); } while (false) #else #include <stdexcept> +#include <memory> #define AIDGE_THROW_OR_ABORT(ex, ...) \ do { \ int n = 128; \ @@ -35,4 +39,21 @@ do { \ } while (false) #endif -#endif //AIDGE_UTILS_H_ \ No newline at end of file +/** + * Macro for specified API assertions. + * Used to check logic directly related to user's inputs. + * If it asserts, it means an user error. +*/ +#define AIDGE_ASSERT(stm, ...) \ +if (!(stm)) { printf("Assertion failed: " AIDGE_STRINGIZE(stm) " in " __FILE__ ":%d", __LINE__); \ + AIDGE_THROW_OR_ABORT(std::runtime_error, __VA_ARGS__); } + +/** + * Macro for internal assertions. + * Used to check internal logic not directly related to API user's inputs. + * If it asserts, it means a bug. +*/ +#define AIDGE_INTERNAL_ASSERT(stm) \ +assert((stm) && "Internal assertion failed: " #stm " in " __FILE__ ":" AIDGE_STRINGIZE(__LINE__)) + +#endif //AIDGE_ERRORHANDLING_H_ diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index de543e95a16475c4443164af7be5c379d6554f8d..3b29c472b3a540c9ef3b8ed46520e3e718e8cbfb 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -58,6 +58,11 @@ struct Registrar { //assert(newInsert && "registrar already exists"); } + static bool exists(const typename C::registrar_key& key) { + const auto it = C::registry().find(key); + return (it != C::registry().end()); + } + static auto create(const typename C::registrar_key& key){ const auto it = C::registry().find(key); assert(it != C::registry().end() && "invalid registrar key"); diff --git a/include/aidge/utils/StaticAttributes.hpp b/include/aidge/utils/StaticAttributes.hpp index fb800cffbcff5d4113961f8e62977417336f2cb8..b67f69ae7afc2c22f3b424812ec994b10974b668 100644 --- a/include/aidge/utils/StaticAttributes.hpp +++ b/include/aidge/utils/StaticAttributes.hpp @@ -18,7 +18,7 @@ #include <typeinfo> #include "aidge/utils/Attributes.hpp" -#include "aidge/utils/Utils.hpp" +#include "aidge/utils/ErrorHandling.hpp" namespace Aidge { /** @@ -87,7 +87,7 @@ public: // Runtime access with name template <typename R> - constexpr R& getAttr(const char* name) { + R& getAttr(const char* name) { for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { if (strcmp(EnumStrings<ATTRS_ENUM>::data[i], name) == 0) { return getAttr<R>(i); @@ -98,7 +98,7 @@ public: } template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> - constexpr typename std::enable_if<(SIZE > 0), R&>::type getAttr(std::size_t i) { + typename std::enable_if<(SIZE > 0), R&>::type getAttr(std::size_t i) { if (i == SIZE-1) { if (std::is_same<R, typename std::tuple_element<SIZE-1,std::tuple<T...>>::type>::value) { return reinterpret_cast<R&>(std::get<SIZE-1>(mAttrs)); @@ -113,7 +113,7 @@ public: } template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> - [[noreturn]] constexpr typename std::enable_if<(SIZE == 0), R&>::type getAttr(std::size_t /*i*/) { + [[noreturn]] typename std::enable_if<(SIZE == 0), R&>::type getAttr(std::size_t /*i*/) { AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute not found"); } @@ -128,7 +128,7 @@ public: } template <std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> - [[noreturn]] constexpr typename std::enable_if<(SIZE == 0), const std::type_info&>::type getAttrType(std::size_t /*i*/) const { + [[noreturn]] typename std::enable_if<(SIZE == 0), const std::type_info&>::type getAttrType(std::size_t /*i*/) const { AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute not found"); } @@ -140,7 +140,7 @@ public: /// Generic Attributes API ////////////////////////////////////// // Runtime existance check with name - constexpr bool hasAttr(const std::string& name) const override final { + bool hasAttr(const std::string& name) const override final { for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { if (name == EnumStrings<ATTRS_ENUM>::data[i]) { return true; @@ -151,7 +151,7 @@ public: } // Runtime type access with name - constexpr std::string getAttrType(const std::string& name) const override final { + std::string getAttrType(const std::string& name) const override final { for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { if (name == EnumStrings<ATTRS_ENUM>::data[i]) { return getAttrType(i).name(); @@ -170,7 +170,7 @@ public: } #ifdef PYBIND - py::object getAttrPy(const std::string& name) const { + py::object getAttrPy(const std::string& name) const override { for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { if (name == EnumStrings<ATTRS_ENUM>::data[i]) { // https://github.com/pybind/pybind11/blob/f3e0602802c7840992c97f4960515777cad6a5c7/include/pybind11/pytypes.h#L1119-L1138 diff --git a/include/aidge/utilsParsing/AstNode.hpp b/include/aidge/utilsParsing/AstNode.hpp index 1158ae148a22993476adb00ecbf8ebd24101830c..bf4f73236fb65b88da309e71ba55997b5342df41 100644 --- a/include/aidge/utilsParsing/AstNode.hpp +++ b/include/aidge/utilsParsing/AstNode.hpp @@ -1,7 +1,7 @@ -#ifndef _AIDGE_AST_NODE_H_ -#define _AIDGE_AST_NODE_H_ +#ifndef AIDGE_CORE_AST_NODE_H_ +#define AIDGE_CORE_AST_NODE_H_ #include <string> #include <type_traits> @@ -12,11 +12,11 @@ namespace Aidge{ template <typename EnumType> - class AstNode: public std::enable_shared_from_this<AstNode> + class AstNode: public std::enable_shared_from_this<AstNode<EnumType>> { static_assert(std::is_enum<EnumType>::value, "AstNode EnumType must be an enum type"); public: - AstNode(std::shared_ptr<ParsingToken<EnumType>> token,std::vector<std::shared_ptr<AstNode>> child ={}):mToken(token),mChild(child){} + AstNode(std::shared_ptr<ParsingToken<EnumType>> token,std::vector<std::shared_ptr<AstNode<EnumType>>> child ={}):mToken(token),mChild(child){} /** * @brief get the type of the token * @return the type @@ -41,7 +41,7 @@ namespace Aidge{ } /** * @brief test if the node is a leaf in the tree - * @return true if a leaf + * @return true if a leaf */ bool isLeaf() const { return mChild.size() == 0; @@ -66,4 +66,4 @@ namespace Aidge{ }; } -#endif //_AIDGE_AST_NODE_H_ +#endif //AIDGE_CORE_AST_NODE_H_ diff --git a/include/aidge/utilsParsing/ParsingToken.hpp b/include/aidge/utilsParsing/ParsingToken.hpp index 78045cf3085a18bfd0565354fd34aef02ef395bd..e303a5eabe6f7710873468f8edc8f3e844f4175f 100644 --- a/include/aidge/utilsParsing/ParsingToken.hpp +++ b/include/aidge/utilsParsing/ParsingToken.hpp @@ -1,13 +1,15 @@ -#ifndef _AIDGE_PARSING_TOKEN_H_ -#define _AIDGE_PARSING_TOKEN_H_ +#ifndef AIDGE_CORE_PARSING_TOKEN_H_ +#define AIDGE_CORE_PARSING_TOKEN_H_ #include <string> #include <type_traits> +#include <sstream> // Include the necessary header namespace Aidge{ + template <typename EnumType> - class ParsingToken: public std::enable_shared_from_this<ParsingToken> + class ParsingToken: public std::enable_shared_from_this<ParsingToken<EnumType>> { static_assert(std::is_enum<EnumType>::value, "ParsingToken EnumType must be an enum type"); public: @@ -16,11 +18,11 @@ namespace Aidge{ * @param type one of the token type * @param lexeme String representing aditional information of the token */ - ParsingToken(const EnumType type , const std::string lexeme )mLexeme(lexeme),mType(type){} + ParsingToken(const EnumType type , const std::string lexeme ):mLexeme(lexeme),mType(type){} /** * @brief get the lexeme - * @return std::string + * @return std::string */ const std::string getLexeme(void){ return mLexeme; @@ -28,8 +30,8 @@ namespace Aidge{ /** * @brief get the token type - * - * @return ParsingToken + * + * @return ParsingToken */ const EnumType getType(void){ return mType; @@ -39,7 +41,10 @@ namespace Aidge{ * @brief copy the token * @return deep copy of the token */ - std::shared_ptr<Aidge::ParsingToken> copy(); + std::shared_ptr<ParsingToken> copy(){ + auto newToken = std::make_shared<ParsingToken<EnumType>>(mType,mLexeme); + return newToken; + } //TODO std::ostringstream rep(void){ @@ -47,6 +52,7 @@ namespace Aidge{ out << " Token (" << mLexeme <<")" << "\n"; return out; } + private: /** @@ -63,4 +69,4 @@ namespace Aidge{ }; } -#endif //_AIDGE_PARSING_TOKEN_H_ \ No newline at end of file +#endif //AIDGE_CORE_PARSING_TOKEN_H_ diff --git a/python_binding/operator/pybind_Add.cpp b/python_binding/operator/pybind_Add.cpp index ab8b4cf7b91d5eea2db5245a8c5122ab004b4766..0b2323c5cfb660415ec3ae009beaa7aa78afca0b 100644 --- a/python_binding/operator/pybind_Add.cpp +++ b/python_binding/operator/pybind_Add.cpp @@ -20,7 +20,9 @@ namespace py = pybind11; namespace Aidge { template <std::size_t NUM> void declare_Add(py::module &m) { - py::class_<Add_Op<NUM>, std::shared_ptr<Add_Op<NUM>>, Operator>(m, "Add_Op", py::multiple_inheritance()); + py::class_<Add_Op<NUM>, std::shared_ptr<Add_Op<NUM>>, Operator>(m, "AddOp", py::multiple_inheritance()) + .def("get_inputs_name", &Add_Op<NUM>::getInputsName) + .def("get_outputs_name", &Add_Op<NUM>::getOutputsName); m.def("Add", &Add<NUM>, py::arg("name") = ""); } diff --git a/python_binding/operator/pybind_AvgPooling.cpp b/python_binding/operator/pybind_AvgPooling.cpp index 372afebdd3e1626cd0af88e335b78ec7fd73a5f4..a2eda1ab44f75b2f6a63d0d4d4f19cbee00b07c7 100644 --- a/python_binding/operator/pybind_AvgPooling.cpp +++ b/python_binding/operator/pybind_AvgPooling.cpp @@ -30,27 +30,23 @@ template <DimIdx_t DIM> void declare_AvgPoolingOp(py::module &m) { m, ("AvgPoolingOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) .def(py::init<const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, (DIM<<1)> &>(), + const std::array<DimSize_t, DIM> &>(), py::arg("kernel_dims"), - py::arg("stride_dims"), - py::arg("padding_dims")); - - m.def(("AvgPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + py::arg("stride_dims")) + .def("get_inputs_name", &AvgPooling_Op<DIM>::getInputsName) + .def("get_outputs_name", &AvgPooling_Op<DIM>::getOutputsName); + + m.def(("AvgPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, const std::string& name, - const std::vector<DimSize_t> &stride_dims, - const std::vector<DimSize_t> &padding_dims) { + const std::vector<DimSize_t> &stride_dims) { // Lambda function wrapper because PyBind fails to convert const array. - // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. + // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. if (kernel_dims.size() != DIM) { throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); } if (stride_dims.size() != DIM) { throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); } - if (padding_dims.size() != (DIM<<1)) { - throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); - } DimSize_t tmp_kernel_dims_array[DIM]; for (size_t i = 0; i < DIM; ++i) { tmp_kernel_dims_array[i] = kernel_dims[i]; @@ -59,19 +55,13 @@ template <DimIdx_t DIM> void declare_AvgPoolingOp(py::module &m) { for (size_t i = 0; i < DIM; ++i) { tmp_stride_dims_array[i] = stride_dims[i]; } - DimSize_t tmp_padding_dims_array[DIM<<1]; - for (size_t i = 0; i < (DIM<<1); ++i) { - tmp_padding_dims_array[i] = padding_dims[i]; - } const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; - const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; - return AvgPooling<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array)); + return AvgPooling<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array)); }, py::arg("kernel_dims"), py::arg("name") = "", - py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), - py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0)); - + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1)); + } @@ -79,10 +69,10 @@ void init_AvgPooling(py::module &m) { declare_AvgPoolingOp<1>(m); declare_AvgPoolingOp<2>(m); declare_AvgPoolingOp<3>(m); - + // FIXME: // m.def("AvgPooling1D", static_cast<NodeAPI(*)(const char*, int, int, int const // (&)[1])>(&AvgPooling)); } } // namespace Aidge -#endif \ No newline at end of file +#endif diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp index f43381fecc689a292e166c4da40ea0cb4842c9e6..cabaa2edd7053718160fa5013492d1914ee4cf16 100644 --- a/python_binding/operator/pybind_BatchNorm.cpp +++ b/python_binding/operator/pybind_BatchNorm.cpp @@ -21,7 +21,9 @@ namespace Aidge { template <DimSize_t DIM> void declare_BatchNormOp(py::module& m) { - py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, Operator, Attributes>(m, ("BatchNorm_Op" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()); + py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, Operator, Attributes>(m, ("BatchNormOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) + .def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName) + .def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName); m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); } diff --git a/python_binding/operator/pybind_Conv.cpp b/python_binding/operator/pybind_Conv.cpp index 0c09917d71e520227eed48705527adaf204857ee..aabbfd3dd9aecd44123962443458de56c0b7071c 100644 --- a/python_binding/operator/pybind_Conv.cpp +++ b/python_binding/operator/pybind_Conv.cpp @@ -32,33 +32,30 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { DimSize_t, const std::array<DimSize_t, DIM> &, const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, (DIM<<1)> &, const std::array<DimSize_t, DIM> &>(), py::arg("in_channels"), py::arg("out_channels"), py::arg("kernel_dims"), py::arg("stride_dims"), - py::arg("padding_dims"), - py::arg("dilation_dims")); - + py::arg("dilation_dims")) + .def("get_inputs_name", &Conv_Op<DIM>::getInputsName) + .def("get_outputs_name", &Conv_Op<DIM>::getOutputsName) + ; + m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels, DimSize_t out_channels, const std::vector<DimSize_t>& kernel_dims, - const std::string& name, + const std::string& name, const std::vector<DimSize_t> &stride_dims, - const std::vector<DimSize_t> &padding_dims, const std::vector<DimSize_t> &dilation_dims) { // Lambda function wrapper because PyBind fails to convert const array. - // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. + // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. if (kernel_dims.size() != DIM) { throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); } if (stride_dims.size() != DIM) { throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); } - if (padding_dims.size() != (DIM<<1)) { - throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); - } if (dilation_dims.size() != DIM) { throw std::runtime_error("dilation_dims size [" + std::to_string(dilation_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); } @@ -70,27 +67,20 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { for (size_t i = 0; i < DIM; ++i) { tmp_stride_dims_array[i] = stride_dims[i]; } - DimSize_t tmp_padding_dims_array[DIM<<1]; - for (size_t i = 0; i < (DIM<<1); ++i) { - tmp_padding_dims_array[i] = padding_dims[i]; - } DimSize_t tmp_dilation_dims_array[DIM]; for (size_t i = 0; i < DIM; ++i) { tmp_dilation_dims_array[i] = dilation_dims[i]; } const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; - const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; const DimSize_t (&dilation_dims_array)[DIM] = tmp_dilation_dims_array; - return Conv<DIM>(in_channels, out_channels, to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array), to_array(dilation_dims_array)); + return Conv<DIM>(in_channels, out_channels, to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(dilation_dims_array)); }, py::arg("in_channels"), py::arg("out_channels"), py::arg("kernel_dims"), py::arg("name") = "", py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), - py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0), py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); - } @@ -98,7 +88,7 @@ void init_Conv(py::module &m) { declare_ConvOp<1>(m); declare_ConvOp<2>(m); declare_ConvOp<3>(m); - + // FIXME: // m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const // (&)[1])>(&Conv)); diff --git a/python_binding/operator/pybind_ConvDepthWise.cpp b/python_binding/operator/pybind_ConvDepthWise.cpp index 3f48c50f7ffdb44450c0e2a155d85dcbf9f73fd9..809a7d6e797651ed8c490aa9a886c31c7e4e6651 100644 --- a/python_binding/operator/pybind_ConvDepthWise.cpp +++ b/python_binding/operator/pybind_ConvDepthWise.cpp @@ -31,29 +31,25 @@ template <DimIdx_t DIM> void declare_ConvDepthWiseOp(py::module &m) { py::multiple_inheritance()) .def(py::init<const std::array<DimSize_t, DIM> &, const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, (DIM<<1)> &, const std::array<DimSize_t, DIM> &>(), py::arg("kernel_dims"), py::arg("stride_dims"), - py::arg("padding_dims"), - py::arg("dilation_dims")); - - m.def(("ConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + py::arg("dilation_dims")) + .def("get_inputs_name", &ConvDepthWise_Op<DIM>::getInputsName) + .def("get_outputs_name", &ConvDepthWise_Op<DIM>::getOutputsName); + + m.def(("ConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, const std::string& name, const std::vector<DimSize_t> &stride_dims, - const std::vector<DimSize_t> &padding_dims, const std::vector<DimSize_t> &dilation_dims) { // Lambda function wrapper because PyBind fails to convert const array. - // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. + // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. if (kernel_dims.size() != DIM) { throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); } if (stride_dims.size() != DIM) { throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); } - if (padding_dims.size() != (DIM<<1)) { - throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); - } if (dilation_dims.size() != DIM) { throw std::runtime_error("dilation_dims size [" + std::to_string(dilation_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); } @@ -65,25 +61,19 @@ template <DimIdx_t DIM> void declare_ConvDepthWiseOp(py::module &m) { for (size_t i = 0; i < DIM; ++i) { tmp_stride_dims_array[i] = stride_dims[i]; } - DimSize_t tmp_padding_dims_array[DIM<<1]; - for (size_t i = 0; i < (DIM<<1); ++i) { - tmp_padding_dims_array[i] = padding_dims[i]; - } DimSize_t tmp_dilation_dims_array[DIM]; for (size_t i = 0; i < DIM; ++i) { tmp_dilation_dims_array[i] = dilation_dims[i]; } const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; - const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; const DimSize_t (&dilation_dims_array)[DIM] = tmp_dilation_dims_array; - return ConvDepthWise<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array), to_array(dilation_dims_array)); + return ConvDepthWise<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(dilation_dims_array)); }, py::arg("kernel_dims"), py::arg("name") = "", py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), - py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0), py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); - + } @@ -91,7 +81,7 @@ void init_ConvDepthWise(py::module &m) { declare_ConvDepthWiseOp<1>(m); declare_ConvDepthWiseOp<2>(m); declare_ConvDepthWiseOp<3>(m); - + // FIXME: // m.def("ConvDepthWise1D", static_cast<NodeAPI(*)(const char*, int, int, int const // (&)[1])>(&ConvDepthWise)); diff --git a/python_binding/operator/pybind_FC.cpp b/python_binding/operator/pybind_FC.cpp index 4b9d61d082ebed4d426b41efa071d3943f83d231..c6a1c70000e3e6d604a6652716667efa1c18e956 100644 --- a/python_binding/operator/pybind_FC.cpp +++ b/python_binding/operator/pybind_FC.cpp @@ -20,7 +20,9 @@ namespace py = pybind11; namespace Aidge { void declare_FC(py::module &m) { - py::class_<FC_Op, std::shared_ptr<FC_Op>, Operator, Attributes>(m, "FC_Op", py::multiple_inheritance()); + py::class_<FC_Op, std::shared_ptr<FC_Op>, Operator, Attributes>(m, "FCOp", py::multiple_inheritance()) + .def("get_inputs_name", &FC_Op::getInputsName) + .def("get_outputs_name", &FC_Op::getOutputsName); m.def("FC", &FC, py::arg("out_channels"), py::arg("nobias") = false, py::arg("name") = ""); } diff --git a/python_binding/operator/pybind_LeakyReLU.cpp b/python_binding/operator/pybind_LeakyReLU.cpp index cae8a88bab7b59189dfbc6528cd653f1c97cb73a..af7689f0e64dd4ca8f798dcb34ea968972ace464 100644 --- a/python_binding/operator/pybind_LeakyReLU.cpp +++ b/python_binding/operator/pybind_LeakyReLU.cpp @@ -18,7 +18,9 @@ namespace py = pybind11; namespace Aidge { void init_LeakyReLU(py::module& m) { - py::class_<LeakyReLU_Op, std::shared_ptr<LeakyReLU_Op>, Operator, Attributes>(m, "LeakyReLU_Op", py::multiple_inheritance()); + py::class_<LeakyReLU_Op, std::shared_ptr<LeakyReLU_Op>, Operator, Attributes>(m, "LeakyReLUOp", py::multiple_inheritance()) + .def("get_inputs_name", &LeakyReLU_Op::getInputsName) + .def("get_outputs_name", &LeakyReLU_Op::getOutputsName); m.def("LeakyReLU", &LeakyReLU, py::arg("negative_slope") = 0.0f, py::arg("name") = ""); } diff --git a/python_binding/operator/pybind_Matmul.cpp b/python_binding/operator/pybind_Matmul.cpp index 2f738550041bcdb1ae809d68fa24fdf5a72e9164..fdb51b24a87ce358c1e7808873ebc569ca2227c8 100644 --- a/python_binding/operator/pybind_Matmul.cpp +++ b/python_binding/operator/pybind_Matmul.cpp @@ -20,7 +20,9 @@ namespace py = pybind11; namespace Aidge { void declare_MatMul(py::module &m) { - py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, Operator, Attributes>(m, "MatMul_Op", py::multiple_inheritance()); + py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, Operator, Attributes>(m, "MatMulOp", py::multiple_inheritance()) + .def("get_inputs_name", &MatMul_Op::getInputsName) + .def("get_outputs_name", &MatMul_Op::getOutputsName); m.def("MatMul", &MatMul, py::arg("out_channels"), py::arg("name") = ""); } diff --git a/python_binding/operator/pybind_MaxPooling.cpp b/python_binding/operator/pybind_MaxPooling.cpp index 2efd18c816c2d588e574872b3d3776a3409dc4ba..84313f298f90298726773630602e90a5ab3d3efd 100644 --- a/python_binding/operator/pybind_MaxPooling.cpp +++ b/python_binding/operator/pybind_MaxPooling.cpp @@ -30,27 +30,23 @@ template <DimIdx_t DIM> void declare_MaxPoolingOp(py::module &m) { m, ("MaxPoolingOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) .def(py::init<const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, (DIM<<1)> &>(), + const std::array<DimSize_t, DIM> &>(), py::arg("kernel_dims"), - py::arg("stride_dims"), - py::arg("padding_dims")); - - m.def(("MaxPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + py::arg("stride_dims")) + .def("get_inputs_name", &MaxPooling_Op<DIM>::getInputsName) + .def("get_outputs_name", &MaxPooling_Op<DIM>::getOutputsName); + + m.def(("MaxPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, const std::string& name, - const std::vector<DimSize_t> &stride_dims, - const std::vector<DimSize_t> &padding_dims) { + const std::vector<DimSize_t> &stride_dims) { // Lambda function wrapper because PyBind fails to convert const array. - // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. + // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. if (kernel_dims.size() != DIM) { throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); } if (stride_dims.size() != DIM) { throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); } - if (padding_dims.size() != (DIM<<1)) { - throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); - } DimSize_t tmp_kernel_dims_array[DIM]; for (size_t i = 0; i < DIM; ++i) { tmp_kernel_dims_array[i] = kernel_dims[i]; @@ -59,19 +55,13 @@ template <DimIdx_t DIM> void declare_MaxPoolingOp(py::module &m) { for (size_t i = 0; i < DIM; ++i) { tmp_stride_dims_array[i] = stride_dims[i]; } - DimSize_t tmp_padding_dims_array[DIM<<1]; - for (size_t i = 0; i < (DIM<<1); ++i) { - tmp_padding_dims_array[i] = padding_dims[i]; - } const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; - const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; - return MaxPooling<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array)); + return MaxPooling<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array)); }, py::arg("kernel_dims"), py::arg("name") = "", - py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), - py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0)); - + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1)); + } @@ -79,10 +69,10 @@ void init_MaxPooling(py::module &m) { declare_MaxPoolingOp<1>(m); declare_MaxPoolingOp<2>(m); declare_MaxPoolingOp<3>(m); - + // FIXME: // m.def("MaxPooling1D", static_cast<NodeAPI(*)(const char*, int, int, int const // (&)[1])>(&MaxPooling)); } } // namespace Aidge -#endif \ No newline at end of file +#endif diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index ac9a34e0a14ace2cf264188302f52a27bf0f7222..d945b212ff6fb643302ca7512e91c7a778a39419 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -20,6 +20,7 @@ void init_Operator(py::module& m){ py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") .def("output", &Operator::output, py::arg("outputIdx")) .def("input", &Operator::input, py::arg("inputIdx")) + .def("nb_data_inputs", &Operator::nbDataInputs) .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDatatype, py::arg("datatype")) .def("set_backend", &Operator::setBackend, py::arg("name")) diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp index 1c62cd0adf6b8712073ec0674754ce7c8c2014a5..107b7ba00e4077d9f7c215257bf7fd46629481c1 100644 --- a/python_binding/operator/pybind_Producer.cpp +++ b/python_binding/operator/pybind_Producer.cpp @@ -35,7 +35,9 @@ void init_Producer(py::module &m) { "ProducerOp", py::multiple_inheritance()) .def("dims", &Producer_Op::dims) - .def("set_output_tensor", &Producer_Op::setOutputTensor); + .def("set_output_tensor", &Producer_Op::setOutputTensor) + .def("get_inputs_name", &Producer_Op::getInputsName) + .def("get_outputs_name", &Producer_Op::getOutputsName); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = ""); declare_Producer<1>(m); diff --git a/python_binding/operator/pybind_ReLU.cpp b/python_binding/operator/pybind_ReLU.cpp index 820589d76507b39ca65ac2397614aabd1221fe3e..dbcb483e8089373bc8599c2d09fed00049e2a2ac 100644 --- a/python_binding/operator/pybind_ReLU.cpp +++ b/python_binding/operator/pybind_ReLU.cpp @@ -18,7 +18,9 @@ namespace py = pybind11; namespace Aidge { void init_ReLU(py::module& m) { - py::class_<ReLU_Op, std::shared_ptr<ReLU_Op>, Operator>(m, "ReLU_Op", py::multiple_inheritance()); + py::class_<ReLU_Op, std::shared_ptr<ReLU_Op>, Operator>(m, "ReLUOp", py::multiple_inheritance()) + .def("get_inputs_name", &ReLU_Op::getInputsName) + .def("get_outputs_name", &ReLU_Op::getOutputsName); m.def("ReLU", &ReLU, py::arg("name") = ""); } diff --git a/python_binding/operator/pybind_Softmax.cpp b/python_binding/operator/pybind_Softmax.cpp index 72ac1107181c1d7e2f578e31a965636dbb5c111b..8e50ab7c83bf43285b357cb803c0ce3eb42f4cc7 100644 --- a/python_binding/operator/pybind_Softmax.cpp +++ b/python_binding/operator/pybind_Softmax.cpp @@ -19,7 +19,9 @@ namespace py = pybind11; namespace Aidge { void init_Softmax(py::module& m) { - py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator>(m, "Softmax_Op", py::multiple_inheritance()); + py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator>(m, "SoftmaxOp", py::multiple_inheritance()) + .def("get_inputs_name", &Softmax_Op::getInputsName) + .def("get_outputs_name", &Softmax_Op::getOutputsName); m.def("Softmax", &Softmax, py::arg("name") = ""); } diff --git a/setup.py b/setup.py index 16305afdfdfa5de2e328460d9e96c77eb96a9d98..b88329e54feab78e39bd79be0a129030098e216a 100644 --- a/setup.py +++ b/setup.py @@ -70,7 +70,8 @@ class CMakeBuild(build_ext): self.spawn(['cmake', str(cwd), param_py, '-DTEST=OFF', f'-DCMAKE_INSTALL_PREFIX:PATH={install_path}']) if not self.dry_run: - self.spawn(['make', 'all', 'install', '-j', max_jobs]) + self.spawn(['cmake', '--build', '.', '--config', 'Debug', '-j', max_jobs]) + self.spawn(['cmake', '--install', '.', '--config', 'Debug']) os.chdir(str(cwd)) aidge_package = build_lib / (get_project_name()) @@ -81,7 +82,7 @@ class CMakeBuild(build_ext): # Copy all shared object files from build_temp/lib to aidge_package for root, _, files in os.walk(build_temp.absolute()): for file in files: - if file.endswith('.so') and (root != str(aidge_package.absolute())): + if (file.endswith('.so') or file.endswith('.pyd')) and (root != str(aidge_package.absolute())): currentFile=os.path.join(root, file) shutil.copy(currentFile, str(aidge_package.absolute())) @@ -100,7 +101,6 @@ if __name__ == '__main__': long_description_content_type="text/markdown", long_description="\n".join(DOCLINES[2:]), classifiers=[c for c in CLASSIFIERS.split('\n') if c], - platforms=["Linux"], packages=find_packages(where="."), include_package_data=True, ext_modules=[CMakeExtension(get_project_name())], diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 03b2a9adb439eb00d0ba59a13fead4f25d617b36..8f8f51c89bbcc380963f355f781e8fda940dcffc 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -125,21 +125,17 @@ Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::dataInputs() const { - IOIndex_t nbDataIn = 0U; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - nbDataIn += inputNode->nbDataInputs(); - } - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbDataIn); - nbDataIn = 0U; + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; + for (const std::shared_ptr<Node>& inputNode : mInputNodes) { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->dataInputs(); - std::move(inputNodeinputs.begin(), inputNodeinputs.end(), - res.begin() + nbDataIn); - nbDataIn += inputNode->nbDataInputs(); - // res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode -> - // inputs()).end()); + + for (const auto& input : inputNodeinputs) { + if (mNodes.find(input.first) == mNodes.end()) { + res.push_back(input); + } + } } return res; } @@ -147,21 +143,17 @@ Aidge::GraphView::dataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::inputs() const { - std::size_t nbIn = 0U; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - nbIn += inputNode->nbInputs(); - } - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbIn); - nbIn = 0U; + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; + for (const std::shared_ptr<Node>& inputNode : mInputNodes) { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->inputs(); - std::move(inputNodeinputs.begin(), inputNodeinputs.end(), - res.begin() + nbIn); - nbIn += inputNode->nbInputs(); - // res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode -> - // inputs()).end()); + + for (const auto& input : inputNodeinputs) { + if (mNodes.find(input.first) == mNodes.end()) { + res.push_back(input); + } + } } return res; } diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 54fdac808642f3ae603e237737e265ba394fccbd..e6a53c871f5312c68f40dc5c9a2777729470298b 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -341,6 +341,35 @@ Aidge::NodePtr Aidge::Node::clone() const { return std::make_shared<Node>(mOperator->clone(), mName); } + +std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){ + + std::set<Aidge::NodePtr> out; + nodeSee.insert(shared_from_this()); + + if(delta == 0) { + out.insert(shared_from_this()); + + }else if (delta > 0){ + for (const NodePtr& node : getChildren()) { + if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance + for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){ + out.insert(ch); + } + } + } + }else{ + for (const NodePtr& node : getParents()) { + if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance + for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){ + out.insert(pr); + } + } + } + } + + return out; +} ///////////////////////////////////////////////////////////////////////////////////////////// // private diff --git a/src/graphRegex/GraphFsmInterpreter.cpp b/src/graphRegex/GraphFsmInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2984ab4fb3864244c9e32dbfcda9ef2ae080acf0 --- /dev/null +++ b/src/graphRegex/GraphFsmInterpreter.cpp @@ -0,0 +1,182 @@ +#include "aidge/graphRegex/GraphFsmInterpreter.hpp" + +using namespace Aidge; + + +GraphFsmInterpreter::GraphFsmInterpreter(const std::string graphMatchExpr,std::map<std::string,std::shared_ptr<ConditionalInterpreter>> nodesCondition):mParser(graphMatchExpr){ + mActGroupe = 0; + mNodesCondition = nodesCondition; +} +std::shared_ptr<FsmGraph> GraphFsmInterpreter::interpret(void){ + mActGroupe = 0; + std::shared_ptr<AstNode<gRegexTokenTypes>> tree = mParser.parse(); + return visit(tree); +} +std::shared_ptr<FsmGraph> GraphFsmInterpreter::visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree){ + + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>> nextAstNodes = AstTree->getChilds(); + + if(AstTree->getType() == gRegexTokenTypes::SEP){ + return sepF(visit(nextAstNodes[0]),visit(nextAstNodes[1])); + }else if(AstTree->getType() == gRegexTokenTypes::NEXT){ + return nextF(visit(nextAstNodes[0]),visit(nextAstNodes[1])); + }else if(AstTree->getType() == gRegexTokenTypes::QOM){ + return qomF(visit(nextAstNodes[0])); + }else if(AstTree->getType() == gRegexTokenTypes::QZM){ + return qzmF(visit(nextAstNodes[0])); + }else if(AstTree->getType() == gRegexTokenTypes::KEY || AstTree->getType() == gRegexTokenTypes::CKEY){ + return keyF(AstTree); + }else if(AstTree->getType() == gRegexTokenTypes::LPAREN){ + mActGroupe += 1; + std::shared_ptr<FsmGraph> out = visit(nextAstNodes[0]); + mActGroupe -= 1; + return out; + }else{ + throw std::logic_error("visit Bad token type" ); + } +} + + + + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::keyF(std::shared_ptr<AstNode<gRegexTokenTypes>> AstNode){ + + + std::shared_ptr<FsmNode> start = std::make_shared<FsmNode>(false,true); + std::shared_ptr<FsmNode> valid = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + std::shared_ptr<FsmEdge> edge; + + + if(AstNode->getType() == gRegexTokenTypes::CKEY){ + edge = FsmEdgeFactory::make(start,valid,FsmEdgeTypes::COMMON,mNodesCondition,AstNode->getValue()); + }else if (AstNode->getType() == gRegexTokenTypes::KEY) + { + edge = FsmEdgeFactory::make(start,valid,FsmEdgeTypes::UNIQUE,mNodesCondition,AstNode->getValue()); + }else{ + + throw std::logic_error("keyF Bad in AST" ); + } + + graph->addEdge(edge); + graph->setGroupe(mActGroupe); + return graph; +} + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::sepF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm){ + + size_t idxLeft = leftFsm->getNbSubFsm(); + rigthFsm->incOrigineAllNodeBy(idxLeft); + leftFsm->unionG(rigthFsm); + //the rigthFsm is no longer usfull + return leftFsm; +} + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::nextF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm){ + /* + combine the 2 Graph + all valid node of A are merge with Start B, Start B is un Start + update the relative reference + + A B + SA -> VA + SB -> VB + A B + SA -> q -> VB + */ + leftFsm->mergeOneStartOneValid(rigthFsm); + //the rigthFsm is no longer usfull + return leftFsm; +} + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::qomF(std::shared_ptr<FsmGraph> fsm){ + /* + + + valid node is connect to the child of Start with the same edge condition + A + S -> V + + A + S -> V + (E|R) + V -> S + */ + + std::vector<std::shared_ptr<FsmNode>> allStart = fsm->getStartNodes(); + std::set<std::shared_ptr<FsmNode>> allValid = fsm->getValidNodes(); + std::shared_ptr<FsmEdge> edge; + + if(allStart.size() != 1){ + throw std::logic_error("qomF Bad in AST" ); + } + + for(auto start : allStart ){ + for(auto edgeStart :start->getEdges() ){ + if (auto sharedEdge = edgeStart.lock()) { + + const std::map<size_t, int> commonRef = sharedEdge->getRelative(); + bool haveCommon = !commonRef.empty(); + + for(auto valid : allValid){ + if(haveCommon){ + /* + the // quantif case + get the go back and make a lexeme id(number) + we need to go back to the ref delta min #TODO + */ + bool hasMinRef = false; + std::pair<size_t, int> minRef; + for (const auto& entry : commonRef) { + if (!hasMinRef || std::abs(minRef.second) > std::abs(entry.second)) { + hasMinRef = true; + minRef = entry; + } + } + std::stringstream lexem; + lexem << "(" << minRef.first << ", " << minRef.second << ")"; + edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::REF,mNodesCondition, lexem.str()); + }else{ + /* + the sequensial quantif case + no reference to common + */ + edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::EMPTY,mNodesCondition,""); + + } + fsm->addEdge(edge); + } + }else{ + throw std::runtime_error("edgeStart weak pointer is expired" ); + } + } + + } + return fsm; + +} + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::qzmF(std::shared_ptr<FsmGraph> fsm){ + /* + qomf and a bypass empty start to valide + */ + fsm = qomF(fsm); + + std::vector<std::shared_ptr<FsmNode>> allStart = fsm->getStartNodes(); + std::set<std::shared_ptr<FsmNode>> allValid = fsm->getValidNodes(); + std::shared_ptr<FsmEdge> edge; + + if(allStart.size() != 1){ + throw std::logic_error("qzmF Bad in AST" ); + } + + for(auto start : allStart ){ + + for(auto valid : allValid){ + edge = FsmEdgeFactory::make(start,valid,FsmEdgeTypes::EMPTY,mNodesCondition,""); + fsm->addEdge(edge); + } + } + + return fsm; + + +} \ No newline at end of file diff --git a/src/graphRegex/GraphLexer.cpp b/src/graphRegex/GraphLexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..61214f96a090fef5d28cb0ce1a009644d9570880 --- /dev/null +++ b/src/graphRegex/GraphLexer.cpp @@ -0,0 +1,155 @@ + +#include "aidge/graphRegex/GraphLexer.hpp" + +using namespace Aidge; + + +GraphLexer::GraphLexer( const std::string gRegexExpressions ): +mRegularExpressions(gRegexExpressions){ + mPosition = 0; +} + +std::shared_ptr<ParsingToken<gRegexTokenTypes>> GraphLexer::getNextToken(void){ + std::string currentChars = ""; + while (mPosition < mRegularExpressions.length()) + { + //erase all space + if (mRegularExpressions[mPosition] != ' ') + { + currentChars += mRegularExpressions[mPosition]; + } + else + { + mPosition++; + continue; + } + + ///// + // const lent token + ///// + + if (std::regex_match(currentChars,std::regex("\\->")))// the next TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::NEXT,""); + } + else if (std::regex_match(currentChars,std::regex("\\*")))// the * TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::QZM,""); + } + else if (std::regex_match(currentChars,std::regex("\\+")))// the + TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::QOM,""); + } + else if (std::regex_match(currentChars,std::regex("\\(")))// the LPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::LPAREN,""); + } + else if (std::regex_match(currentChars,std::regex("\\)")))// the RPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::RPAREN,""); + } + + // + else if (std::regex_match(currentChars,std::regex(";")))// the SEP TOKEN + { + //test if the last sep + //std::string subStr = mRegularExpressions.substr(mPosition); + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::SEP,""); + } + + ///// + //unconst lent token + ///// + + else if (std::regex_match(currentChars,std::regex("[A-Za-z_0-9]")))// the KEY or CKEY + { + + //read all the key + bool isCKey = false; + std::regex keyRegex("[A-Za-z_0-9]+"); + std::regex cKeyRegex("[A-Za-z_0-9]+\\#[0-9]*"); + + while ( mPosition < mRegularExpressions.length()) { + + if(!std::regex_match(currentChars,keyRegex) && !std::regex_match(currentChars,cKeyRegex)) + { + currentChars.pop_back(); //the last char is the problemes + break; + } + else if (std::regex_match(currentChars,cKeyRegex)){ + isCKey = true; + } + mPosition++; + if (mPosition < mRegularExpressions.length()) currentChars += mRegularExpressions[mPosition]; + + } + //we end the match 2 posibility + //we are at the end of the mConditionalExpressions and we need to ensure the match + //we are not we can continu + if (mPosition == mRegularExpressions.length()-1) + { + if (!std::regex_match(currentChars,keyRegex) && !std::regex_match(currentChars,cKeyRegex)) + { + throw badTokenError(currentChars,mPosition); + } + } + + + if (isCKey){ + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::CKEY,currentChars); + } else{ + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::KEY,currentChars); + } + } + + mPosition++; + } + + + //no more to find no one match the currentChars + if (currentChars.empty()) { + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::STOP,""); // Null shared pointer ; + }else{ + throw badTokenError(currentChars,mPosition); + } + +} + +void GraphLexer::rstPosition(void){ + if (isEnd()){ + mPosition = 0; + }else{ + throw badTokenError("end rst",mPosition); + } +} + +bool GraphLexer::isEnd(void){ + return mPosition >= mRegularExpressions.length(); +} + +std::runtime_error GraphLexer::badTokenError(const std::string& currentChars,std::size_t position){ + std::ostringstream errorMessage; + errorMessage << "\nBad syntax " << currentChars << " :\n" << mRegularExpressions << "\n"; + for (std::size_t i = 0; i < position; i++) { + errorMessage << ' '; + } + errorMessage << "^\n"; + + return std::runtime_error(errorMessage.str()); +} + + const std::string GraphLexer::rep(){ + std::string out = mRegularExpressions; + out += "\n"; + for (std::size_t i = 0; i < mPosition; i++) { + out += ' '; + } + out += "^\n"; + return out ; + } \ No newline at end of file diff --git a/src/graphRegex/GraphParser.cpp b/src/graphRegex/GraphParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5aa653c482dae82c2e9fa02bfc36b2ffc821785f --- /dev/null +++ b/src/graphRegex/GraphParser.cpp @@ -0,0 +1,181 @@ +#include "aidge/graphRegex/GraphParser.hpp" + +using namespace Aidge; + +GraphParser::GraphParser(const std::string gRegexExpressions): +mLexer(gRegexExpressions) +{ + mCurrentToken = mLexer.getNextToken(); +} + + +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::parse(void){ + + std::shared_ptr<AstNode<gRegexTokenTypes>> astTree = constructAstAllExpr(); + rstParser(); + return astTree; +} + + +void GraphParser::rstParser(void){ + mLexer.rstPosition(); + mCurrentToken = mLexer.getNextToken(); +} + + +void GraphParser::ackToken(gRegexTokenTypes tokenType){ + + if(mCurrentToken->getType() == tokenType ){ + try { + mCurrentToken = mLexer.getNextToken(); + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "Graph Lexer error in Parser :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } + }else{ + std::ostringstream errorMessage; + errorMessage << "Bad syntax GraphParser " << static_cast<int>(mCurrentToken->getType()) <<"!="<< static_cast<int>(tokenType) << "\n"; + errorMessage << mLexer.rep(); + throw std::runtime_error(errorMessage.str()); + } +} + +/* +exp : KEY(QOM | QZM)? | CKEY | domain +*/ +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::constructAstExp(void) +{ + + try{ + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = mCurrentToken->copy(); + std::shared_ptr<AstNode<gRegexTokenTypes>> node = std::make_shared<AstNode<gRegexTokenTypes>>(token); + + if (mCurrentToken->getType() == gRegexTokenTypes::KEY ){ + ackToken(gRegexTokenTypes::KEY ); + if (mCurrentToken->getType() == gRegexTokenTypes::QOM ){ + token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::QOM ); + std::shared_ptr<AstNode<gRegexTokenTypes>> newNode = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{node}); + return newNode; + }else if (mCurrentToken->getType() == gRegexTokenTypes::QZM ){ + token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::QZM ); + std::shared_ptr<AstNode<gRegexTokenTypes>> newNode = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{node}); + return newNode; + } + return node; + }else if (mCurrentToken->getType() == gRegexTokenTypes::CKEY){ + ackToken(gRegexTokenTypes::CKEY ); + return node; + }else{ + return constructAstDomain(); + } + + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "GraphParser constructAstExp :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } +} + +/* +seq :exp (NEXT seq)* +*/ +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::constructAstSeq(void) +{ + + try{ + + std::shared_ptr<AstNode<gRegexTokenTypes>> left = constructAstExp(); + if(mCurrentToken->getType() == gRegexTokenTypes::NEXT ) + { + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::NEXT); + std::shared_ptr<AstNode<gRegexTokenTypes>> newNode = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{left,constructAstSeq()}); + left = newNode; + } + return left; + + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "GraphParser constructAstSeq :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } + +} + + +/* +LPAREN seq RPAREN (QOM | QZM) +*/ +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::constructAstDomain(void) +{ + + try{ + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token ; + std::shared_ptr<AstNode<gRegexTokenTypes>> node ; + + token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::LPAREN); + node = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{constructAstSeq()}); + ackToken(gRegexTokenTypes::RPAREN); + //(QOM | QZM) + + token = mCurrentToken->copy(); + if (mCurrentToken->getType() == gRegexTokenTypes::QOM){ + ackToken(gRegexTokenTypes::QOM); + node = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{node}); + }else if (mCurrentToken->getType() == gRegexTokenTypes::QZM){ + ackToken(gRegexTokenTypes::QZM); + node = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{node}); + }else{ + std::ostringstream errorMessage; + errorMessage << "Bad syntax constructAstDomain must have quantifier \n"; + throw std::runtime_error(errorMessage.str()); + } + + return node; + + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "GraphParser constructAstDomain :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } +} + +/* + allExpr: seq (SEP allExpr)* | STOP +*/ +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::constructAstAllExpr(void) +{ + + try{ + std::shared_ptr<AstNode<gRegexTokenTypes>> left = constructAstSeq(); + if(mCurrentToken->getType() == gRegexTokenTypes::SEP ) + { + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::SEP); + + if(mCurrentToken->getType() == gRegexTokenTypes::STOP ) + { + return left; + } + std::shared_ptr<AstNode<gRegexTokenTypes>> newNode = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{left,constructAstAllExpr()}); + left = newNode; + } + return left; + + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "GraphParser constructAstDomain :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } +} diff --git a/src/graphRegex/GraphStrInterpreter.cpp b/src/graphRegex/GraphStrInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ad24b5b9b0fee5fba34dd7397132bec2410fd23 --- /dev/null +++ b/src/graphRegex/GraphStrInterpreter.cpp @@ -0,0 +1,38 @@ +#include "aidge/graphRegex/GraphStrInterpreter.hpp" + +using namespace Aidge; + +GraphStrInterpreter::GraphStrInterpreter(const std::string graphMatchExpr):mParser(graphMatchExpr){ + mToTest = graphMatchExpr; + mToTest.erase(std::remove_if(mToTest.begin(), mToTest.end(), ::isspace), mToTest.end()); +} + + +std::string GraphStrInterpreter::visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree){ + + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>> nextAstNodes = AstTree->getChilds(); + + if(AstTree->getType() == gRegexTokenTypes::SEP){ + return visit(nextAstNodes[0])+";"+visit(nextAstNodes[1]); + }else if(AstTree->getType() == gRegexTokenTypes::NEXT){ + return visit(nextAstNodes[0])+"->"+visit(nextAstNodes[1]); + }else if(AstTree->getType() == gRegexTokenTypes::QOM){ + return visit(nextAstNodes[0])+"+"; + }else if(AstTree->getType() == gRegexTokenTypes::QZM){ + return visit(nextAstNodes[0])+"*"; + }else if(AstTree->getType() == gRegexTokenTypes::KEY || AstTree->getType() == gRegexTokenTypes::CKEY){ + return AstTree->getValue(); + }else if(AstTree->getType() == gRegexTokenTypes::LPAREN){ + return "("+visit(nextAstNodes[0])+")"; + }else{ + throw std::logic_error("visit Bad token type" ); + } + + +} + + +std::string GraphStrInterpreter::interpret(void){ + std::shared_ptr<AstNode<gRegexTokenTypes>> tree = mParser.parse(); + return visit(tree); +} \ No newline at end of file diff --git a/src/graphRegex/matchFsm/FsmEdge.cpp b/src/graphRegex/matchFsm/FsmEdge.cpp new file mode 100644 index 0000000000000000000000000000000000000000..593da06abe18576d435ae55718d379aa5b682d60 --- /dev/null +++ b/src/graphRegex/matchFsm/FsmEdge.cpp @@ -0,0 +1,277 @@ +#include "aidge/graphRegex/matchFsm/FsmEdge.hpp" +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" + +using namespace Aidge; + +std::map<std::string,int> FsmEdgeCommon::mCommonIdxMap; + +bool FsmEdge::isCommon(void){ + return false; +} + +size_t FsmEdge::getCommonIdx(void){ + return std::numeric_limits<std::size_t>::max(); +} +const std::map<size_t,int>& FsmEdge::getRelative(void){ + return mRelativePos; +} +void FsmEdge::updateRelative( const std::map<size_t,int>& relativePos ){ + for (const auto& kvp : relativePos) { + mRelativePos.insert(kvp); + } +} +std::shared_ptr<FsmNode> FsmEdge::getSourceNode(void){ + return mNodeSource; +} +void FsmEdge::reSetSouceNode(const std::shared_ptr<FsmNode>& newSource){ + mNodeSource->rmEdge(shared_from_this()); + mNodeSource = newSource; + mNodeSource->addEdge(shared_from_this()); + propagateRelativePos(); + +} +std::shared_ptr<FsmNode> FsmEdge::getDestNode(void){ + return mNodeDest; +} +void FsmEdge::reSetDestNode(const std::shared_ptr<FsmNode>& newDest){ + mNodeDest->rmParent(mNodeSource); + mNodeDest = newDest; + mNodeDest->addParent(mNodeSource); + propagateRelativePos(); +} +void FsmEdge::propagateRelativePos(void){ + + std::set<int> myRelativeID; + for (const auto& kvp : mRelativePos) { + myRelativeID.insert(kvp.first); + } + + for (const auto& nextWeakEdge : mNodeDest->getEdges()){ + + if (auto nextEdge = nextWeakEdge.lock()) { + + if(this == nextEdge.get()){ + continue; + } + + + std::set<int> nextRelativeID; + for (const auto& kvp : nextEdge->getRelative()) { + nextRelativeID.insert(kvp.first); + } + + // Find elements in myRelativeID but not in nextRelativeID + std::set<int> idxsToPush; + std::set_difference(myRelativeID.begin(), myRelativeID.end(), + nextRelativeID.begin(), nextRelativeID.end(), + std::inserter(idxsToPush, idxsToPush.begin())); + + // Find elements in nextRelativeID but not in myRelativeID + std::set<int> idxsToGet; + std::set_difference(nextRelativeID.begin(), nextRelativeID.end(), + myRelativeID.begin(), myRelativeID.end(), + std::inserter(idxsToGet, idxsToGet.begin())); + + // test for integrity we look if 2 edge refert to the samme + // ref and are link the ref dif is one + // not working for common node + // we can go deeper by find the all pass to a ref and see if the delta is good + + // Find elements present in both myRelativeID and nextRelativeID + std::set<int> idxsTotest; + for (int idx : nextRelativeID){ + if (myRelativeID.find(idx) != myRelativeID.end()){ + if (std::abs(getRelative().at(idx) - nextEdge->getRelative().at(idx)) != 1) { + throw std::runtime_error("Bad relative"); + } + } + } + + + + // this egde have more relative info than the next + std::map<size_t,int> tmpRelative; + // we push this info to the next + for( auto idxToPush :idxsToPush ){ + tmpRelative.insert( std::make_pair(idxToPush, getRelative().at(idxToPush) +1)); + } + if(tmpRelative.size() != 0){ + nextEdge->updateRelative(tmpRelative); + nextEdge->propagateRelativePos(); + } + tmpRelative.clear(); + + + // the next node have more info than me i need to get it + for( auto idxToGet :idxsToGet ){ + tmpRelative.insert( std::make_pair(idxToGet, nextEdge->getRelative().at(idxToGet) -1)); + } + if(tmpRelative.size() != 0){ + updateRelative(tmpRelative); + + for(auto weakParent : getSourceNode()->getParentNodes()){ + if (auto parent = weakParent.lock()) { + for(auto weakPEdge : parent->getEdges()){ + if (auto pEdge = weakPEdge.lock()) { + pEdge->propagateRelativePos(); + }else{ + throw std::runtime_error("propagateRelativePos parent edge weak pointer is expired" ); + } + } + }else{ + throw std::runtime_error("propagateRelativePos parent weak pointer is expired" ); + } + } + } + tmpRelative.clear(); + }else{ + throw std::runtime_error("propagateRelativePos edge weak pointer is expired" ); + } + } +} + +void FsmEdge::updateWeak(void){ + mNodeSource->addEdge(shared_from_this()); + mNodeDest->addParent(mNodeSource); +} + +FsmEdge::FsmEdge(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest) +:mToTest(toTest) +{ + mNodeSource = source; + mNodeDest = dest; + // wen i make the edge I init the nodes + // mNodeSource->addEdge(shared_from_this()); + // mNodeDest->addParent(mNodeSource); +} + + +/////surchage + +FsmEdgeUnique::FsmEdgeUnique(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest) +:FsmEdge(source,dest,toTest) +{ +} +const EdgeTestResult FsmEdgeUnique::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ + auto opNode = stmContext->getActNode(); + + if(opNode == nullptr){ + return {false,std::set<NodePtr>()};//none + } + + if(mToTest->test(opNode) && opNode->getChildren().size() <= 1){ + stmContext->setValid(opNode,mToTest); + return {true,opNode->getChildren()} ; + }else{ + stmContext->addRejectedNode(opNode); + return {false,std::set<NodePtr>()}; + } +} +///////////////////// +FsmEdgeCommon::FsmEdgeCommon(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey) +:FsmEdge(source,dest,toTest) +{ + //make a uid for common node + if(mCommonIdxMap.find(commonKey) == mCommonIdxMap.end()){ + mCommonIdxMap.insert(std::make_pair(commonKey, mCommonIdxMap.size())); + } + mCommonIdx = mCommonIdxMap[commonKey]; + propagateRelativePos(); +} + + +const EdgeTestResult FsmEdgeCommon::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ + + auto opNode = stmContext->getActNode(); + + if(opNode == nullptr){ + return {false,std::set<NodePtr>()};//none + } + if(mToTest->test(opNode)){ + stmContext->setCommon(opNode,mCommonIdx); + stmContext->setValid(opNode,mToTest); + return {true,opNode->getChildren()} ; + }else{ + stmContext->addRejectedNode(opNode); + return {false,std::set<NodePtr>()}; + } +} +bool FsmEdgeCommon::isCommon(void){ + return true; + } +//////////////////// TODO FsmEdgeEmpty must be size_t +FsmEdgeRef::FsmEdgeRef(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const size_t refCommonIdx,const int deltaCommonIdx) +:FsmEdge(source,dest,nullptr),mRefCommonIdx(refCommonIdx),mdeltaCommonIdx(deltaCommonIdx) +{ + +} +const EdgeTestResult FsmEdgeRef::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ + + NodePtr refNode = stmContext->getCommonNodeFromIdx(mRefCommonIdx); + if (refNode){ + std::set<std::shared_ptr<Node>> see; + return {true,refNode->getNodeDelta(mdeltaCommonIdx,see)}; + } + return {false,std::set<NodePtr>()}; +} +//////////////////// +FsmEdgeEmpty::FsmEdgeEmpty(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest) +:FsmEdge(source,dest,nullptr) +{} +const EdgeTestResult FsmEdgeEmpty::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ + auto opNode = stmContext->getActNode(); + if(opNode == nullptr){ + return {false,std::set<NodePtr>()}; + } + return {true,std::set<NodePtr>({opNode})};//none +} + +/// factory +std::shared_ptr<FsmEdge> FsmEdgeFactory::make( +std::shared_ptr<FsmNode> source, +std::shared_ptr<FsmNode> dest, FsmEdgeTypes type, +std::map<std::string, std::shared_ptr<ConditionalInterpreter>> allTest, +const std::string lexeme) +{ + if (type == FsmEdgeTypes::EMPTY) { + if (lexeme.empty()) { + return std::make_shared<FsmEdgeEmpty>(source, dest); + } else { + throw std::invalid_argument("error lexem EMPTY"); + } + } else if (type == FsmEdgeTypes::REF) { + std::smatch m; + std::regex refRegex("\\s*\\(\\s*(\\d+)\\s*,\\s*(-?\\d+)\\s*\\)\\s*"); + if (std::regex_match(lexeme, m, refRegex)) { + int refCommonIdx = std::stoi(m[1]); + int deltaCommonIdx = std::stoi(m[2]); + return std::make_shared<FsmEdgeRef>(source, dest, refCommonIdx, deltaCommonIdx); + } else { + throw std::invalid_argument("error lexem REF " + lexeme); + } + } else if (type == FsmEdgeTypes::COMMON) { + std::smatch m; + std::regex commonRegex("\\s*(\\w+)#(\\d*)"); + if (std::regex_match(lexeme, m, commonRegex)) { + std::string edgeType = m[1]; + std::string commonId = m[2]; + size_t commonIdx = commonId.empty() ? 0 : std::stoi(commonId) + 1; + std::string commonKey = edgeType + std::to_string(commonIdx); + return std::make_shared<FsmEdgeCommon> (source, dest, allTest.at(edgeType), commonKey); + } else { + throw std::invalid_argument("error lexem COMMON " + lexeme); + } + } else if (type == FsmEdgeTypes::UNIQUE) { + std::regex uniqueRegex("\\s*(\\w+)"); + std::smatch m; + if (std::regex_match(lexeme, m, uniqueRegex)) { + std::string edgeType = m[1]; + return std::make_shared<FsmEdgeUnique>(source, dest, allTest.at(edgeType)); + } else { + throw std::invalid_argument("error lexem UNIQUE \"" + std::string(lexeme) +" eee\""); + } + } else { + throw std::invalid_argument("Bad edge Type"); + } + } \ No newline at end of file diff --git a/src/graphRegex/matchFsm/FsmGraph.cpp b/src/graphRegex/matchFsm/FsmGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..09bc25d636c1cc882439f50107bf728714fdfb20 --- /dev/null +++ b/src/graphRegex/matchFsm/FsmGraph.cpp @@ -0,0 +1,199 @@ +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" + +using namespace Aidge; + + + +FsmGraph::FsmGraph(/* args */){ + +} + +//TODO + std::shared_ptr<MatchResult> FsmGraph::test(std::vector<NodePtr>& startNodes){ + std::vector<std::shared_ptr<Aidge::FsmNode>> startNodesFsm = getStartNodes(); + if(startNodes.size() != startNodesFsm.size()){ + throw std::runtime_error("bad number of Start nodes"); + } + + std::vector<std::shared_ptr<FsmRunTimeContext>> walks; + for(std::size_t i = 0; i < startNodes.size(); i++){ + walks.push_back(std::make_shared<FsmRunTimeContext>(startNodesFsm[i],startNodes[i])); + } + std::vector<std::shared_ptr<FsmRunTimeContext>> nextWalks; + + std::vector<std::shared_ptr<FsmRunTimeContext>> allValidContext; + std::vector<std::shared_ptr<FsmRunTimeContext>> allContextSee; + + + + + while (!walks.empty()) + { + for(auto fsmContext : walks){ + allContextSee.push_back(fsmContext); + //if we are in a valid st we save it + //it's one solution of the posible solution of the matching + if(fsmContext->isOnValidState()){ + //not save 2 time the same end point + if(!std::any_of(allValidContext.begin(), allValidContext.end(), + [&](std::shared_ptr<Aidge::FsmRunTimeContext> oldValid) { + return fsmContext->areEqual(oldValid); + })){ + allValidContext.push_back(fsmContext); + } + + } + + //dont test 2 time a fsmContext + std::vector<std::shared_ptr<FsmRunTimeContext>> tmpNextWalks = fsmContext->getActState()->test(fsmContext); + for(auto PotentialFsmContext : tmpNextWalks){ + + if(!std::any_of(allContextSee.begin(), allContextSee.end(), + [&](std::shared_ptr<Aidge::FsmRunTimeContext> oldSee) { + return PotentialFsmContext->areEqual(oldSee); + })){ + nextWalks.push_back(PotentialFsmContext); + } + } + + } + walks.swap(nextWalks); + nextWalks.clear(); + } + + + return std::make_shared<MatchResult>(allValidContext,getNbSubFsm()); + +} + + +/////////////// +// FSM construction +/////////////// +const std::set<std::shared_ptr<FsmEdge>>& FsmGraph::getEdge(void){ + return mEdges; +} + +void FsmGraph::addEdge(std::shared_ptr<FsmEdge>& edge){ + edge->updateWeak(); + mEdges.insert(edge); + mAllOrigine.insert(edge->getDestNode()->getOrigine()); + mAllOrigine.insert(edge->getSourceNode()->getOrigine()); +} + +const std::vector<std::shared_ptr<FsmNode>> FsmGraph::getStartNodes(void){ + std::set<std::shared_ptr<FsmNode>> nodes = getNodes(); + std::vector<std::shared_ptr<FsmNode>> startNodes; + for(auto node :nodes){ + if(node->isStart()){ + startNodes.push_back(node); + } + } + return startNodes; +} + +const std::set<std::shared_ptr<FsmNode>> FsmGraph::getValidNodes(void){ + std::set<std::shared_ptr<FsmNode>> nodes = getNodes(); + std::set<std::shared_ptr<FsmNode>> ValidNodes; + for(auto node :nodes){ + if(node->isValid()){ + ValidNodes.insert(node); + } + } + //may short + return ValidNodes; +} + +const std::set<std::shared_ptr<FsmNode>> FsmGraph::getNodes(void){ + std::set<std::shared_ptr<FsmNode>> nodes; + for(auto edge : mEdges){ + nodes.insert(edge->getDestNode()); + nodes.insert(edge->getSourceNode()); + } + return nodes; +} + +void FsmGraph::setGroupe(std::size_t groupeIdx){ + std::set<std::shared_ptr<FsmNode>> nodes = getNodes(); + for(auto node :nodes){ + node->setGroupe(groupeIdx); + } +} + +void FsmGraph::unionG(const std::shared_ptr<FsmGraph> fsmGraph){ + + for(auto edge : fsmGraph->getEdge()){ + addEdge(edge); + } +} + +void FsmGraph::mergeOneStartOneValid(const std::shared_ptr<FsmGraph> fsmGraph){ + std::set<std::shared_ptr<FsmNode>> validNodes = getValidNodes(); + std::vector<std::shared_ptr<FsmNode>> startNodes = fsmGraph->getStartNodes(); + + if (startNodes.size() != 1 || validNodes.size() != 1){ + + std::ostringstream errorMessage; + errorMessage <<"mergeOneStartOneValid start size: " << startNodes.size() << " valide size : " << validNodes.size() + <<" can only merge FSM 1 start 1 valide"; + throw std::runtime_error(errorMessage.str()); + } + + unionG(fsmGraph); + //for loop useless but for future merge it's coudl be used + for(auto valid : validNodes){ + valid->unValid(); + for(auto start : startNodes){ + start->unStart(); + _mergeNode(start,valid); + } + } +} + +std::size_t FsmGraph::getNbSubFsm(void){ + return mAllOrigine.size(); +} + +void FsmGraph::incOrigineAllNodeBy(std::size_t incr){ + std::set<std::shared_ptr<FsmNode>> nodes = getNodes(); + for(auto node :nodes){ + node->incOrigine(incr); + } + for(auto origin : mAllOrigine){ + origin += incr; + } +} + +void FsmGraph::_mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest){ + std::set<std::shared_ptr<FsmNode>> nodes = getNodes(); + + if(nodes.find(source) == nodes.end() || nodes.find(dest) == nodes.end()){ + throw std::runtime_error("FsmGraph can not merge node not in the graph"); + } + nodes.clear(); + + //probagate source attribut + if(source->isValid()){ + dest->valid(); + } + if(source->isStart()){ + dest->start(); + } + + //merge source to dest by replace source by dest in all EDGE + for(auto edge : mEdges){ + if(edge->getDestNode() == source ){ + edge->reSetDestNode(dest); + }else if(edge->getSourceNode() == source ){ + edge->reSetSouceNode(dest); + } + + } + //check is source is not in graph + nodes = getNodes(); + if(nodes.find(source) != nodes.end() ){ + throw std::runtime_error("FsmGraph merge node not effective"); + } + nodes.clear(); + +} diff --git a/src/graphRegex/matchFsm/FsmNode.cpp b/src/graphRegex/matchFsm/FsmNode.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84b4a0c3fdbe0730a12a2a62db9158e2538d646f --- /dev/null +++ b/src/graphRegex/matchFsm/FsmNode.cpp @@ -0,0 +1,132 @@ +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" +#include "aidge/graphRegex/matchFsm/FsmEdge.hpp" +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" + +using namespace Aidge; + + + +FsmNode::FsmNode(bool isAValid,bool isAStart ){ + mIsAStart =isAStart; + mIsAValid =isAValid; + +} +const std::vector<std::shared_ptr<FsmRunTimeContext>> FsmNode::test( std::shared_ptr<FsmRunTimeContext> fsmContext){ + + + std::vector<std::shared_ptr<FsmRunTimeContext>> out; + + for(auto edge : mEdges){ + if (auto sharedEdge = edge.lock()) { + + std::shared_ptr<FsmNode> nextState = sharedEdge->getDestNode(); + + //make copy of the fsmContext + std::shared_ptr<FsmRunTimeContext> newFsmContext = std::make_shared<FsmRunTimeContext>(fsmContext); + + EdgeTestResult edgeRes = sharedEdge->test(newFsmContext); + + if(edgeRes.success){ + if(edgeRes.node.size() != 0){ + for(auto nextNode :edgeRes.node ){ + if(!newFsmContext->isAlreadyValid(nextNode)|| newFsmContext->isCommonDefined(nextNode) ){ + out.push_back( std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nextNode)); + + }else{ + out.push_back( std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nullptr)); + } + + } + }else{ + out.push_back( std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nullptr)); + } + } + newFsmContext.reset(); + + }else{ + throw std::runtime_error("test FsmNode weak pointer is expired" ); + } + + } + return out; +} + + + +std::size_t FsmNode::getOrigine(void){ + return mOrigineStm; +} +void FsmNode::incOrigine(std::size_t inc){ + mOrigineStm += inc; +} +void FsmNode::rmEdge(std::shared_ptr<FsmEdge> edge){ + mEdges.erase(edge); +} + +void FsmNode::addEdge(std::shared_ptr<FsmEdge> edge){ + std::weak_ptr<FsmEdge> edgeW(edge); + if (!edgeW.expired()) { + mEdges.insert(edgeW); + }else{ + throw std::runtime_error("addEdge FsmNode weak pointer is expired" ); + } +} + +// const std::set<std::shared_ptr<FsmNode>> FsmNode::getChildNodes(void){ +// std::set<std::shared_ptr<FsmNode>> children; +// for(auto edge : mEdges){ +// if (auto sharedEdge = edge.lock()) { +// children.insert(sharedEdge->getDestNode()); +// }else{ +// throw std::runtime_error("getChildNodes FsmNode weak pointer is expired" ); +// } +// } +// return children; +// } + + +const std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>>& FsmNode::getParentNodes(void){ + return mParents; +} +const std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>>& FsmNode::getEdges(void){ + return mEdges; +} + +void FsmNode::setGroupe(std::size_t groupeIdx){ + mGroupeStm = groupeIdx; + +} + +bool FsmNode::isValid(void){ + return mIsAValid; +} +bool FsmNode::isStart(void){ + return mIsAStart; +} +void FsmNode::unValid(void){ + mIsAValid =false; +} +void FsmNode::valid(void){ + mIsAValid =true; +} +void FsmNode::unStart(void){ + mIsAStart =false; +} +void FsmNode::start(void){ + mIsAStart =true; +} + + + +void FsmNode::addParent(std::shared_ptr<FsmNode> node){ + + std::weak_ptr<FsmNode> nodeW(node); + if (!nodeW.expired()) { + mParents.insert(nodeW); + }else{ + throw std::runtime_error("addParent FsmNode weak pointer is expired" ); + } +} +void FsmNode::rmParent(std::shared_ptr<FsmNode> node){ + mParents.erase(node); +} \ No newline at end of file diff --git a/src/graphRegex/matchFsm/FsmRunTimeContext.cpp b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..787cf2322a5b8e7001cdc59325345000dbb61553 --- /dev/null +++ b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp @@ -0,0 +1,226 @@ +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" + +using namespace Aidge; + +std::vector<std::set<NodePtr>> FsmRunTimeContext::mRejectedNodes; + +FsmRunTimeContext::FsmRunTimeContext(std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ,std::size_t idxRejeced ){ + mActOpNode = actOpNode; + mActState = actState; + + //not define case + if(idxRejeced == std::numeric_limits<std::size_t>::max()){ + mLocalIdxRejeced = mRejectedNodes.size(); + mRejectedNodes.push_back(std::set<NodePtr>()); + }else{ + if(idxRejeced > mRejectedNodes.size()-1 ){ + throw std::runtime_error("FsmRunTimeContext idxRejeced"); + } + mLocalIdxRejeced =idxRejeced; + } +} + + + +FsmRunTimeContext::FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime){ + mActOpNode = fsmRunTime->mActOpNode; + mActState = fsmRunTime->mActState; + mCommonNodes = fsmRunTime->mCommonNodes; + mValidNodes = fsmRunTime->mValidNodes; + mLocalIdxRejeced = fsmRunTime->mLocalIdxRejeced; +} +FsmRunTimeContext::FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime,std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ){ + mActOpNode = actOpNode; + mActState = actState; + mCommonNodes = fsmRunTime->mCommonNodes; + mValidNodes = fsmRunTime->mValidNodes; + mLocalIdxRejeced = fsmRunTime->mLocalIdxRejeced; +} + +void FsmRunTimeContext::addRejectedNode(NodePtr node){ + mRejectedNodes[mLocalIdxRejeced].insert(node); +} + +std::set<NodePtr> FsmRunTimeContext::getRejectedNodes(void){ + return mRejectedNodes[mLocalIdxRejeced]; +} + +bool FsmRunTimeContext::isOnValidState(void){ + return mActState->isValid(); +} + +bool FsmRunTimeContext::isCommonDefined(NodePtr node){ + //return mCommonNodes.find(node) != mCommonNodes.end(); + + std::set<NodePtr> nodes = getCommonNodes(); + for(const auto& nodeC : nodes){ + if(nodeC.get() == node.get()){ + return true; + } + } + return false; +} + +bool FsmRunTimeContext::isAlreadyValid(NodePtr node){ + + std::set<NodePtr> nodes = getValidNodes(); + for(const auto& nodeV : nodes){ + if(nodeV.get() == node.get()){ + return true; + } + } + return false; + + //return getValidNodes().find(node) != getValidNodes().end(); +} + +bool FsmRunTimeContext::areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext){ + /* + see if 2 context can be merge + it need to have different mValidNodes exept for common + and the same idx for the common + */ + + //common node + + for (const auto& ref : getCommon()) { + for (const auto& test : fsmContext->getCommon()) { + //same index + if(ref.second == test.second){ + if(ref.first != test.first){ + return false; + } + } + } + } + + //valid nodes + std::set<NodePtr> commonElements; + std::set<NodePtr> A = getValidNodesNoCommon(); + std::set<NodePtr> B = fsmContext->getValidNodesNoCommon(); + std::set_intersection( + A.begin(),A.end(), + B.begin(), B.end(), + std::inserter(commonElements, commonElements.end()) + ); + + + if (!commonElements.empty()) { + return false; + } + + return true; +} + +bool FsmRunTimeContext::areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext){ + if(getActNode() != fsmContext->getActNode()){ + return false; + } + if (getActState() != fsmContext->getActState()){ + return false; + } + if (getValidNodes() != fsmContext->getValidNodes()){ + return false; + } + if (getCommon() != fsmContext->getCommon()){ + return false; + } + + + return true; +} + +void FsmRunTimeContext::setCommon(NodePtr node,std::size_t commonIdx){ + if(isCommonDefined(node)){ + if (mCommonNodes.at(node) != commonIdx){ + throw std::runtime_error("conflict idx in the Common node"); + } + }else{ + mCommonNodes[node] = commonIdx; + } +} + +void FsmRunTimeContext::setValid(NodePtr node,std::shared_ptr<ConditionalInterpreter> tag){ + //we already find a node of this type + if(mValidNodes.find(tag) != mValidNodes.end()){ + if(isAlreadyValid(node) && !isCommonDefined(node) ){ + throw std::runtime_error("setValid you valid tow time"); + } + mValidNodes[tag].insert(node); + }else{ + mValidNodes[tag] = {node}; + } + +} + +std::size_t FsmRunTimeContext::getSubStmId(void){ + return mActState->getOrigine(); +} + +NodePtr FsmRunTimeContext::getCommonNodeFromIdx(std::size_t commonIdx){ + for (const auto& pair : mCommonNodes) { + if (pair.second == commonIdx) { + return pair.first; // Return the key when the value is found + } + } + throw std::runtime_error("getCommonNodeFromIdx Value not found in the map"); +} + +std::size_t FsmRunTimeContext::getCommonNodeIdx(NodePtr node){ + if(isCommonDefined(node)){ + return mCommonNodes.at(node); + } + throw std::runtime_error("getCommonNodeIdx node not found"); +} + +std::set<NodePtr> FsmRunTimeContext::getCommonNodes(void){ + std::set<NodePtr> nodes; + // Iterate over the map and insert values into the set + for (const auto& pair : mCommonNodes) { + nodes.insert(pair.first); + } + return nodes; +} + +std::map<NodePtr,std::size_t> FsmRunTimeContext::getCommon(void){ + return mCommonNodes; +} + +std::set<NodePtr> FsmRunTimeContext::getValidNodes(void){ + + auto sharedSet = std::make_shared<std::set<NodePtr>>(); + // Create a set to store the values from the map + std::set<NodePtr> nodes; + // Iterate over the map and insert values into the set + for (const auto& pair : mValidNodes) { + nodes.insert(pair.second.begin(),pair.second.end()); + } + return nodes; +} + +std::set<NodePtr> FsmRunTimeContext::getValidNodesNoCommon(void){ + std::set<NodePtr> differenceSet; + std::set<NodePtr> valide = getValidNodes(); + std::set<NodePtr> common = getCommonNodes(); + std::set_difference(valide.begin(), valide.end(), common.begin(), common.end(),std::inserter(differenceSet, differenceSet.end())); + return differenceSet; +} + +std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> FsmRunTimeContext::getValid(void){ + return mValidNodes; +} + +NodePtr FsmRunTimeContext::getActNode(void){ + return mActOpNode; +} + +std::shared_ptr<FsmNode> FsmRunTimeContext::getActState(){ + return mActState; +} + + +void FsmRunTimeContext::rst(void){ + mRejectedNodes.clear(); +} + diff --git a/src/graphRegex/matchFsm/MatchResult.cpp b/src/graphRegex/matchFsm/MatchResult.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c35f1a7348e365baa8a27854ee6b0a833e342ee7 --- /dev/null +++ b/src/graphRegex/matchFsm/MatchResult.cpp @@ -0,0 +1,93 @@ +#include "aidge/graphRegex/matchFsm/MatchResult.hpp" + +using namespace Aidge; + + +MatchResult::MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm):mIdToRunTime(nbSubStm){ + mAllValid = allValid; + mNbSubStm = nbSubStm; + + //mIdToRunTimm + for (const auto& contextPtr : allValid) { + mIdToRunTime[contextPtr->getSubStmId()].push_back(contextPtr); + } + + std::vector<std::shared_ptr<FsmRunTimeContext>> precedence; + //make all solution posible + _generateCombinationd(0,precedence); + //sort by solution number of elements + std::sort(mSolve.begin(), mSolve.end(), [](const std::set<NodePtr>& set1, const std::set<NodePtr>& set2) { + return set1.size() < set2.size(); + }); + + +} + +void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence){ + + //it's end , we are below the number of stm + if (idxSubStm == mNbSubStm) + { + //precedence containe a liste of FSM compatible, we just need to + //check if all the node have been valide by at least one contetext + + //1) make the set of all node for the comput graph that are valide in all the FsmRunTimeContext + std::set<NodePtr> validNode; + std::set<NodePtr> rejectNode; + for (const auto& contextPtr : precedence) { + std::set<NodePtr> tmpV = contextPtr->getValidNodes(); + validNode.insert(tmpV.begin(), tmpV.end()); + std::set<NodePtr> tmpR = contextPtr->getRejectedNodes(); + rejectNode.insert(tmpR.begin(),tmpR.end()); + } + // 2) all RejectedNodes need to be valide by an others stm + // if it's not the case the match is not valid + if(std::includes(validNode.begin(), validNode.end(), rejectNode.begin(), rejectNode.end())){ + //we can save the solution + mSolve.push_back(validNode); + } + precedence.pop_back(); + return; + } + + + for (const auto& contextPtrOneFsm : mIdToRunTime[idxSubStm]) + { + if(idxSubStm == 0){ + precedence.push_back(contextPtrOneFsm); + _generateCombinationd(idxSubStm+1,precedence); + + }else{ + //test if the new context is compatible whith all the context in the precedence + // + bool compatibleSolutionFsm = true; + for (const auto& contextPtrOfOtherFsm : precedence) { + if(!(contextPtrOneFsm->areCompatible(contextPtrOfOtherFsm))){ + compatibleSolutionFsm = false; + break; + } + } + + if(compatibleSolutionFsm){ + precedence.push_back(contextPtrOneFsm); + _generateCombinationd(idxSubStm+1,precedence); + } + + } + } + + if(idxSubStm != 0){ + precedence.pop_back(); + } + return; + +} + +std::set<NodePtr> MatchResult::getBiggerSolution(void){ + if(mSolve.empty()){ + return std::set<NodePtr>(); + }else{ + return mSolve[0]; + } + +} \ No newline at end of file diff --git a/src/nodeTester/ConditionalInterpreter.cpp b/src/nodeTester/ConditionalInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8cb20ac2f2348821b245dfc9f61be1072d76b9c9 --- /dev/null +++ b/src/nodeTester/ConditionalInterpreter.cpp @@ -0,0 +1,344 @@ + +#include "aidge/nodeTester/ConditionalInterpreter.hpp" + +using namespace Aidge; + + +/////////////////////////////// +//ConditionalRegisterFunction +/////////////////////////////// + + ConditionalData* ConditionalRegisterFunction::run(const std::string key,std::vector<ConditionalData*> & datas){ + + auto lambdaIt = mWlambda.find(key); + if (lambdaIt != mWlambda.end()) { + return lambdaIt->second(datas); + }else { + throw std::runtime_error("can not run Lambda due to invalid key: " + key); + } + } + +////////////////////// +//ConditionalInterpreter +/////////////////////// + ConditionalInterpreter::ConditionalInterpreter(const std::string ConditionalExpressions) + :mLambdaRegiter() + { + + ConditionalParser conditionalParser = ConditionalParser(ConditionalExpressions); + mTree = conditionalParser.parse(); + ///lambda by default + mLambdaRegiter.insert("getType",+[](NodePtr NodeOp){return NodeOp->type();}); + + } + + + bool ConditionalInterpreter::test( const NodePtr nodeOp) + { + + clearRes(); + try{ + std::vector<ConditionalData*> r = visit({mTree},nodeOp); + + if (mResolution.size() != 1){ + throw std::runtime_error("Multy output interpretation output"); + }else{ + if (!mResolution[0]->isTypeEqualTo<bool>()){ + throw std::runtime_error("TEST OUT MUST BE A BOOL "); + }else{ + return mResolution[0]->getValue<bool>(); + } + } + + }catch(const std::exception& e){ + std::ostringstream errorMessage; + errorMessage << "Error in test " << "\n\t" << e.what() << "\n"; + throw std::runtime_error(errorMessage.str()); + } + } + + void ConditionalInterpreter::insertLambda(const std::string key,std::function<bool(Aidge::NodePtr)> f){ + mLambdaRegiter.insert<std::function<bool(Aidge::NodePtr)> >(key, f); + } + + ///// + std::vector<ConditionalData*> ConditionalInterpreter::visit(const ASTNodeCh& nodes, const NodePtr nodeOp ){ + std::vector<ConditionalData*> dataVector; + + for ( std::shared_ptr<AstNode<ConditionalTokenTypes>> node : nodes) { + try{ + switch (node->getType()){ + /////////////////////////////////// + //OPERATOR + /////////////////////////////////// + case ConditionalTokenTypes::NOT: + { + visit(node->getChilds(),nodeOp); + fNot(); + } + break; + case ConditionalTokenTypes::AND: + { + visit(node->getChilds(),nodeOp); + fAnd(); + } + break; + case ConditionalTokenTypes::OR: + { + visit(node->getChilds(),nodeOp); + fOr(); + } + break; + case ConditionalTokenTypes::EQ: + { + visit(node->getChilds(),nodeOp); + fEq(); + //dataVector.insert(dataVector.end(), tmp.begin(), tmp.end()); + } + break; + case ConditionalTokenTypes::NEQ: + { + visit(node->getChilds(),nodeOp); + fNeq(); + } + break; + + /////////////////////////////////// + //VALUE + /////////////////////////////////// + + case ConditionalTokenTypes::KEY: + + break; + case ConditionalTokenTypes::INTEGER: + { + fStrToInteger(node); + } + break; + case ConditionalTokenTypes::FLOAT: + { + fStrToFloat(node); + + } + break; + case ConditionalTokenTypes::STRING: + { + fStrToStr(node); + } + break; + + case ConditionalTokenTypes::NODE: //TODO + { + + ConditionalData* data = new ConditionalData; + data->setValue<NodePtr>(nodeOp); + mResolution.push_back(data); + + } + break; + + case ConditionalTokenTypes::LAMBDA: + { + visit(node->getChilds(),nodeOp); + fLambda(node); + + } + break; + + case ConditionalTokenTypes::BOOL: //TODO + { + ConditionalData* data = new ConditionalData; + + if(node->getValue() == "true"){ + data->setValue<bool>(true); + }else{ + data->setValue<bool>(false); + } + + mResolution.push_back(data); + + } + break; + + case ConditionalTokenTypes::ARGSEP: + case ConditionalTokenTypes::LPAREN: + case ConditionalTokenTypes::RPAREN: + case ConditionalTokenTypes::STOP: + default: + throw std::runtime_error("NODE TYPE NOT SUPORTED IN ConditionalInterpreter"); + } + }catch(const std::exception& e){ + std::ostringstream errorMessage; + errorMessage << "Error in visiting AST for node"<< nodeOp->name() << "\n\t" << e.what() << "\n"; + throw std::runtime_error(errorMessage.str()); + } + } + + return dataVector; + } + + + ////////////////////// + //value convertor + ///////////////////// + + + void ConditionalInterpreter::fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) + { + ConditionalData* data = new ConditionalData; + data->setValue<int>(std::stoi(node->getValue())); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) + { + + ConditionalData* data = new ConditionalData; + data->setValue<float>(std::stof(node->getValue())); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) + { + ConditionalData* data = new ConditionalData; + data->setValue<std::string>(node->getValue()); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) + { + //if the lambda have input + ConditionalData* data; + try { + data = mLambdaRegiter.run(node->getValue(),mResolution); + } catch (const std::exception& e) { + std::ostringstream errorMessage; + errorMessage << "Error in conditional interpretation when run the "<< node->getValue() <<" Lambda\n\t" << e.what() << "\n"; + throw std::runtime_error(errorMessage.str()); + } + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fEq(void) + { + if (mResolution.size() != 2){ + throw std::runtime_error("EQ need 2 arg and get :" + mResolution.size()); + } + auto a = mResolution[0]; + auto b = mResolution[1]; + + if (a->getType() != b->getType()){ + throw std::runtime_error("EQ Unsuported between type :" + a->getType() +" "+ b->getType()); + } + + + + ConditionalData* data = new ConditionalData; + + if (a->isTypeEqualTo<int>()) { + data->setValue<bool>( a->getValue<int>() == b->getValue<int>()); + }else if (a->isTypeEqualTo<float>()){ + data->setValue<bool>( a->getValue<float>() == b->getValue<float>()); + }else if (a->isTypeEqualTo<std::string>()){ + data->setValue<bool>( a->getValue<std::string>() == b->getValue<std::string>()); + }else if (a->isTypeEqualTo<bool>()){ + data->setValue<bool>( a->getValue<bool>() == b->getValue<bool>()); + }else{ + throw std::runtime_error("EQ Unknown type encountered :" + a->getType() ); + } + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fNeq(void) + { + if (mResolution.size() != 2){ + throw std::runtime_error("NEQ need 2 arg and get :" + mResolution.size()); + } + auto a = mResolution[0]; + auto b = mResolution[1]; + + if (a->getType() != b->getType()){ + throw std::runtime_error("NEQ Unsuported between type :" + a->getType() +" "+ b->getType()); + } + + ConditionalData* data = new ConditionalData; + + if (a->isTypeEqualTo<int>()) { + data->setValue<bool>( a->getValue<int>() != b->getValue<int>()); + }else if (a->isTypeEqualTo<float>()){ + data->setValue<bool>( a->getValue<float>() != b->getValue<float>()); + }else if (a->isTypeEqualTo<std::string>()){ + data->setValue<bool>( a->getValue<std::string>() != b->getValue<std::string>()); + }else + { + throw std::runtime_error("NEQ Unknown type encountered :" + a->getType() ); + } + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fAnd(void) + { + if (mResolution.size() != 2){ + throw std::runtime_error("AND need 2 arg and get :" + mResolution.size()); + } + auto a = mResolution[0]; + auto b = mResolution[1]; + + + if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){ + throw std::runtime_error("AND Unknown type encountered need bool get :" + a->getType() ); + } + + ConditionalData* data = new ConditionalData; + data->setValue<bool>( a->getValue<bool>() && b->getValue<bool>()); + + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fOr(void) + { + if (mResolution.size() != 2){ + throw std::runtime_error("OR need 2 arg and get :" + mResolution.size()); + } + auto a = mResolution[0]; + auto b = mResolution[1]; + + + if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){ + throw std::runtime_error("OR Unknown type encountered need bool get :" + a->getType() ); + } + + ConditionalData* data = new ConditionalData; + data->setValue<bool>( a->getValue<bool>() || b->getValue<bool>()); + + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fNot() + { + if (mResolution.size() != 1){ + throw std::runtime_error("not need 1 arg and get :" + mResolution.size()); + } + auto a = mResolution[0]; + + if (a->getType() != typeid(bool).name()){ + throw std::runtime_error("NOT Unknown type encountered need bool get :" + a->getType() ); + } + + ConditionalData* data = new ConditionalData; + data->setValue<bool>( !a->getValue<bool>() ); + + clearRes(); + mResolution.push_back(data); + + } diff --git a/src/nodeTester/ConditionalLexer.cpp b/src/nodeTester/ConditionalLexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9379bd8409f8f7ec4bae3e0122f88de79718e9dd --- /dev/null +++ b/src/nodeTester/ConditionalLexer.cpp @@ -0,0 +1,242 @@ +#include "aidge/nodeTester/ConditionalLexer.hpp" + +using namespace Aidge; + +////////////////// +//ConditionalLexer +////////////////// + + +ConditionalLexer::ConditionalLexer( const std::string ConditionalExpressions): +mConditionalExpressions(ConditionalExpressions) +{ + mPosition = 0; +} + +std::shared_ptr<ParsingToken<ConditionalTokenTypes>> ConditionalLexer::getNextToken(void){ + std::string currentChars = ""; + + while (mPosition < mConditionalExpressions.length()) + { + //erase all space + if (mConditionalExpressions[mPosition] != ' ') + { + currentChars += mConditionalExpressions[mPosition]; + } + else + { + mPosition++; + continue; + } + //performe tokenisation, find a regex and make a new token + + if (std::regex_match(currentChars,std::regex("\\&\\&")))// the AND TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::AND,""); + } + else if (std::regex_match(currentChars,std::regex("\\|\\|")))// the OR TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::OR,""); + } + else if (std::regex_match(currentChars,std::regex("\\!")))// the Not and not equ + { + mPosition++; + if ( mPosition < mConditionalExpressions.length()){ + currentChars += mConditionalExpressions[mPosition]; + if(std::regex_match(currentChars,std::regex("!="))){ + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::NEQ,""); + }else{ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::NOT,""); + } + } + //a not at the end not ok but it's the parseur work + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::NOT,""); + } + else if (std::regex_match(currentChars,std::regex("==")))// the EQ TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::EQ,""); + } + else if (std::regex_match(currentChars,std::regex("\\(")))// the LPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::LPAREN,""); + } + else if (std::regex_match(currentChars,std::regex("\\)")))// the RPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::RPAREN,""); + } + else if (std::regex_match(currentChars,std::regex(",")))// the RPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::ARGSEP,""); + } + else if (std::regex_match(currentChars,std::regex("\\$")))// the ACTNode TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::NODE,""); + } + + + ///// + //non const lent token + ///// + + //LAMBDA, KEY , bool //the fuction TAG + else if (std::regex_match(currentChars,std::regex("[A-Za-z_]")))// the KEY TOKEN (a char next ) + { + //read all the key + bool isLambda = false; + std::regex keyRegex("[A-Za-z_0-9]+"); + std::regex LambdaRegex("[A-Za-z_0-9]+\\("); + + while ( mPosition < mConditionalExpressions.length()) { + if(!std::regex_match(currentChars,keyRegex) && !std::regex_match(currentChars,LambdaRegex)) + { + currentChars.pop_back(); //the last char is the problemes + break; + } + else if (std::regex_match(currentChars,LambdaRegex)){ + isLambda = true; + } + mPosition++; + if (mPosition < mConditionalExpressions.length()) currentChars += mConditionalExpressions[mPosition]; + //currentChars += mConditionalExpressions[mPosition]; + } + //we end the match 2 posibility + //we are at the end of the mConditionalExpressions and we need to ensure the match + //we are not we can continu + if (mPosition == mConditionalExpressions.length()-1) + { + if (!std::regex_match(currentChars,keyRegex) && !std::regex_match(currentChars,LambdaRegex)) + { + throw badTokenError(currentChars,mPosition); + } + //mPosition++; // we stop all by going pos > lengt + } + + + if (std::regex_match(currentChars,std::regex("(true|false)"))){ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::BOOL,currentChars); + + } else if (isLambda){ + currentChars.pop_back();//pop the ( of the lambda + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::LAMBDA,currentChars); + } else{ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::KEY,currentChars); + } + + } + //numeric value + else if (std::regex_match(currentChars,std::regex("[0-9]")))// the KEY TOKEN (a char next ) + { + //read all the key + bool isFloat = false; + std::regex integerRegex("[0-9]+$"); + std::regex floatRegex("[0-9]+\\.[0-9]*$"); + + while ( mPosition < mConditionalExpressions.length()) { + + if(!std::regex_match(currentChars,integerRegex) && !std::regex_match(currentChars,floatRegex)) + { + currentChars.pop_back(); // the last char match is not a good one + break; + } + else if (std::regex_match(currentChars,floatRegex)){ + isFloat = true; + } + mPosition++; + if (mPosition < mConditionalExpressions.length()) currentChars += mConditionalExpressions[mPosition]; + //currentChars += mConditionalExpressions[mPosition]; + } + //we end the match 2 posibility + //we are at the end of the mConditionalExpressions and we need to ensure the match + //we are not we can continu + if (mPosition == mConditionalExpressions.length()-1) + { + if (!std::regex_match(currentChars,integerRegex) && !std::regex_match(currentChars,floatRegex)) + { + throw badTokenError(currentChars,mPosition); + } + } + + if(isFloat){ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::FLOAT,currentChars); + }else{ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::INTEGER,currentChars); + } + + } + //string TODO + else if (std::regex_match(currentChars,std::regex("\'"))) // TODO ' or \' + { + std::regex strRegex("\'[A-Za-z_0-9\\s]*\'$"); + while ( mPosition < mConditionalExpressions.length()) { + if(std::regex_match(currentChars,strRegex)){ + break; + } + mPosition++; + if (mPosition < mConditionalExpressions.length()) currentChars += mConditionalExpressions[mPosition]; + //currentChars += mConditionalExpressions[mPosition]; + } + + //test the end condition + if (mPosition == mConditionalExpressions.length()-1 ){ + if (!std::regex_match(currentChars,strRegex)){ + throw badTokenError(currentChars,mPosition); + } + //mPosition++; // we stop all by going pos > lengt + } + + mPosition++; // go after the last " + //erase the " char + currentChars.pop_back(); + currentChars.erase(0,1); + + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::STRING,currentChars); + + } + + //Array TODO + + mPosition++; + } + + //no more to find no one match the currentChars + if (currentChars.empty()) { + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::STOP,""); // Null shared pointer ; + }else{ + //std::ostringstream errorMessage; + //errorMessage << "\nBad syntax " << currentChars << " :\n" << mConditionalExpressions; + throw badTokenError(currentChars,mPosition); + } + +} + +void ConditionalLexer::rstPosition(void){ + if (isEnd()){ + mPosition = 0; + }else{ + throw badTokenError("end rst",mPosition); + } + +} + +bool ConditionalLexer::isEnd(void){ + return mPosition >= mConditionalExpressions.length(); +} + +std::runtime_error ConditionalLexer::badTokenError(const std::string& currentChars,std::size_t position){ + std::ostringstream errorMessage; + errorMessage << "\nBad syntax " << currentChars << " :\n" << mConditionalExpressions << "\n"; + for (std::size_t i = 0; i < position; i++) { + errorMessage << ' '; + } + errorMessage << "^\n"; + + return std::runtime_error(errorMessage.str()); +} \ No newline at end of file diff --git a/src/nodeTester/ConditionalParser.cpp b/src/nodeTester/ConditionalParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3ca2843aabefe9f98bc8ad46a36fe03883d0baef --- /dev/null +++ b/src/nodeTester/ConditionalParser.cpp @@ -0,0 +1,188 @@ + +#include "aidge/nodeTester/ConditionalParser.hpp" + +using namespace Aidge; + + +////////////////////////////// +//ConditionalParser +////////////////////////////// + +ConditionalParser::ConditionalParser(const std::string ConditionalExpressions):mLexer(ConditionalExpressions){ + mCurrentToken = mLexer.getNextToken(); +} + +void ConditionalParser::rstParser(void){ + mLexer.rstPosition(); + mCurrentToken = mLexer.getNextToken(); +} + +void ConditionalParser::ackToken(ConditionalTokenTypes tokenType){ + if(mCurrentToken->getType() == tokenType ){ + + try { + mCurrentToken = mLexer.getNextToken(); + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "Conditional Lexer error in Parser :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } + }else{ + + std::ostringstream errorMessage; + errorMessage << "Bad syntax ConditionalParser " << static_cast<int>(mCurrentToken->getType()) <<"!="<< static_cast<int>(tokenType) << "\n"; + errorMessage << mLexer.rep(); + throw std::runtime_error(errorMessage.str()); + } +} + + + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstVal(void){ + /* + val : (KEY|INTEGER|FOAT|STRING|LAMBDA) + */ + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = mCurrentToken->copy(); + + if (token->getType() == ConditionalTokenTypes::KEY){ + ackToken(ConditionalTokenTypes::KEY); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + } + else if(token->getType() == ConditionalTokenTypes::INTEGER){ + ackToken(ConditionalTokenTypes::INTEGER); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + } + else if(token->getType() == ConditionalTokenTypes::FLOAT){ + ackToken(ConditionalTokenTypes::FLOAT); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + } + else if(token->getType() == ConditionalTokenTypes::BOOL){ + ackToken(ConditionalTokenTypes::BOOL); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + } + else if(token->getType() == ConditionalTokenTypes::STRING){ + ackToken(ConditionalTokenTypes::STRING); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + + }else if(token->getType() == ConditionalTokenTypes::NODE){ + ackToken(ConditionalTokenTypes::NODE); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + + }else if(token->getType() == ConditionalTokenTypes::LAMBDA){ + return constructAstLambda(); + } + + throw std::runtime_error("ConditionalParser unknow val type "+ token->rep().str() + "\n" + mLexer.rep()); + +} + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstLambda(void){ + /* + AstLambda : LAMBDA val (ARGSEP val)* RPAREN + */ + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> tokenLdb = mCurrentToken->copy(); + ackToken(ConditionalTokenTypes::LAMBDA); + ASTNodeCh paramLambda; + //AT LEAST ONE VALUE AS INPUT OF A LAMBDA + paramLambda.push_back(constructAstVal()); + while (mCurrentToken->getType() != ConditionalTokenTypes::RPAREN) + { + ackToken(ConditionalTokenTypes::ARGSEP); + paramLambda.push_back(constructAstVal()); + } + ackToken(ConditionalTokenTypes::RPAREN); + return std::make_shared<AstNode<ConditionalTokenTypes>>(tokenLdb,paramLambda); +} + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstCmpr(void){ + /* + cmpr : val (EQ|NEQ) val | LPAREN expr RPAREN + NOT ir ? + */ + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = mCurrentToken->copy(); + //we can check the type relation ir key (EQ|NEQ) val | val (EQ|NEQ) key , but val (EQ|NEQ) val is valid ? + if (token->getType() == ConditionalTokenTypes::LPAREN) + { + ackToken(ConditionalTokenTypes::LPAREN); + std::shared_ptr<AstNode<ConditionalTokenTypes>> node = constructAstExpr(); + ackToken(ConditionalTokenTypes::RPAREN); + return node; + }else{ + + std::shared_ptr<AstNode<ConditionalTokenTypes>> node = constructAstVal(); + token = mCurrentToken->copy(); + if (token->getType() == ConditionalTokenTypes::EQ){ + ackToken(ConditionalTokenTypes::EQ); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{node,constructAstVal()}); + }else if(token->getType() == ConditionalTokenTypes::NEQ){ + ackToken(ConditionalTokenTypes::NEQ); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{node,constructAstVal()}); + }else{ + + throw std::runtime_error("constructAstCmpr "+ token->rep().str() + "\n" + mLexer.rep()); + } + + } +} + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstExpr(std::size_t precLimit /*= 0*/){ + /* + expr : cmpr ((AND | OR) cmpr)* + the NOT is not binary OP can be use in pratt + precedence H to L: TODO + AND + OR + */ + + //the not + std::shared_ptr<AstNode<ConditionalTokenTypes>> left; + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = mCurrentToken->copy(); + + if (mCurrentToken->getType() == ConditionalTokenTypes::NOT ){ + ackToken(ConditionalTokenTypes::NOT ); + left= std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{constructAstCmpr()}); + }else{ + left= constructAstCmpr(); + } + + //pratt + while (mCurrentToken->getType() != ConditionalTokenTypes::STOP ) //security + { + token = mCurrentToken->copy(); + //if the token is not in the map is not a operator so we consider a prec of 0 + if (ConditionalPrec.find(token->getType()) ==ConditionalPrec.end() ){ + return left; + } + + //if my actual operator have a prec <= of the last operator + std::size_t prec = ConditionalPrec.at(token->getType()); + if (prec <= precLimit){ + return left; + } + + //Act all AND and OR + ackToken(token->getType()); + + std::shared_ptr<AstNode<ConditionalTokenTypes>> right = constructAstExpr(prec); + + //i'm not sur what append to newNode + //std::shared_ptr<AstNode<ConditionalTokenTypes>> newNode = std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{left,constructAstCmpr()}); + std::shared_ptr<AstNode<ConditionalTokenTypes>> newNode = std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{left,right}); + left = newNode; + } + return left; +} + + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::parse(void){ + /* + expr : cmpr ((AND | OR) cmpr)* + cmpr : val (EQ|NEQ) val | LPAREN expr RPAREN | BOOL | LAMBDA + val : (KEY|INTEGER|FOAT|STRING|LAMBDA) + lambda : LAMBDA val (ARGSEP val)* RPAREN + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> astTree = constructAstExpr(); + + rstParser(); + return astTree; +} \ No newline at end of file diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c1f58c68686d9359fa3b8ea4b5eb54244e988895 --- /dev/null +++ b/src/operator/MetaOperator.cpp @@ -0,0 +1,141 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/utils/ErrorHandling.hpp" + +Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, + std::vector<NodePtr> inputNodes, + std::vector<NodePtr> outputNodes) + : Operator(type), + mGraph(graph) +{ + mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->inputs().size()); + for (std::size_t i = 0; i < mInputs.size(); ++i) { + mInputs[i] = std::make_shared<Tensor>(); + } + mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->outputs().size()); + for (std::size_t i = 0; i < mOutputs.size(); ++i) { + mOutputs[i] = std::make_shared<Tensor>(); + } + + // Fill inputsNodes and outputsNodes when there is no ambiguity + if (inputNodes.empty()) { + AIDGE_ASSERT(mGraph->inputNodes().size() == 1, "need to specify internal nodes input mapping"); + inputNodes.push_back(*mGraph->inputNodes().begin()); + } + + if (outputNodes.empty()) { + AIDGE_ASSERT(mGraph->outputNodes().size() == 1, "need to specify internal nodes output mapping"); + outputNodes.push_back(*mGraph->outputNodes().begin()); + } + + AIDGE_ASSERT(mGraph->inputNodes().size() == inputNodes.size(), "wrong number of specified input nodes"); + AIDGE_ASSERT(mGraph->outputNodes().size() == outputNodes.size(), "wrong number of specified output nodes"); + + // Identify inputs that are outside the micro-graph + for (const auto& inputNode : inputNodes) { + AIDGE_ASSERT(mGraph->inView(inputNode), "input node must be in the graph"); + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + inputNode->inputs(); + + int inputIdx = 0; // input idx relative to the current node + for (const auto& in : inputNodeinputs) { + if (in.first == nullptr || !mGraph->inView(in.first)) { + // The input is not connected inside the micro-graph + // (no connection to this input or connection outside the micro-graph) + // => it is therefore an input for the meta-operator + mInputOps.push_back(std::make_pair(inputNode->getOperator(), inputIdx)); + } + + ++inputIdx; + } + } + + // The outputs of the output nodes are also the outputs of the meta-operator + for (const auto& outputNode : outputNodes) { + AIDGE_ASSERT(mGraph->inView(outputNode), "output node must be in the graph"); + const std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> outputNodeoutputs = + outputNode->outputs(); + + for (size_t outputIdx = 0; outputIdx < outputNodeoutputs.size(); ++outputIdx) { + mOutputOps.push_back(std::make_pair(outputNode->getOperator(), outputIdx)); + } + } + + AIDGE_INTERNAL_ASSERT(mInputOps.size() == mGraph->inputs().size()); + AIDGE_INTERNAL_ASSERT(mOutputOps.size() == mGraph->outputs().size()); +} + +Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { + if (mImpl) { + return mImpl->getNbRequiredData(inputIdx); + } + else { + const auto& inputOp = mInputOps[inputIdx]; + return inputOp.first->getNbRequiredData(inputOp.second); + } +} + +Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const { + if (mImpl) { + return mImpl->getNbConsumedData(inputIdx); + } + else { + const auto& inputOp = mInputOps[inputIdx]; + return inputOp.first->getNbConsumedData(inputOp.second); + } +} + +Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const { + if (mImpl) { + return mImpl->getNbProducedData(outputIdx); + } + else { + const auto& outputOp = mOutputOps[outputIdx]; + return outputOp.first->getNbProducedData(outputOp.second); + } +} + +void Aidge::MetaOperator_Op::updateConsummerProducer() { + if (mImpl) { + mImpl->updateConsummerProducer(); + } + else { + if (!mScheduler) { + // Lazy initialization + mScheduler = std::make_shared<SequentialScheduler>(mGraph); + } + + // TODO: check that generateScheduling() can be called multiple time to iteratively update the schedule. + // It could be a good idea to unify updateConsummerProducer() and generateScheduling() into a "updateScheduling()" + mScheduler->generateScheduling(); + } +} + +void Aidge::MetaOperator_Op::forward() { + if (mImpl) { + // A custom implementation exists for this meta operator + mImpl->forward(); + } + else { + // No custom implementation, use the individual operators implementations + if (!mScheduler) { + // Lazy initialization + // TODO: should we assert that a scheduler already exists at this point? + // => should be created in updateConsummerProducer() + mScheduler = std::make_shared<SequentialScheduler>(mGraph); + mScheduler->generateScheduling(); + } + + mScheduler->forward(false); + } +} diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp index 0ceec368db03cd6cd8387cb6081868c1d71a23ef..f06e88d3d76166696ca15c7ed8eec962ada74592 100644 --- a/src/recipies/FuseBatchNorm.cpp +++ b/src/recipies/FuseBatchNorm.cpp @@ -79,10 +79,10 @@ void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){ printf("variance < 1e-12 for all outputs! Is the network correctly trained?\n"); } - const DimSize_t channelsSize = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<DimSize_t>("InChannels"); + const DimSize_t channelsSize = std::dynamic_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<DimSize_t>("InChannels"); // TODO : suppose we have Conv2D ... - const std::array<DimSize_t, 2> kernelDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<std::array<DimSize_t, 2>>("KernelDims"); + const std::array<DimSize_t, 2> kernelDims = std::dynamic_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<std::array<DimSize_t, 2>>("KernelDims"); std::shared_ptr<Tensor> weight = conv->input(1).first->getOperator()->getOutput(conv->input(1).second); std::shared_ptr<Tensor> bias = conv->input(2).first->getOperator()->getOutput(conv->input(2).second); diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 4dc8eb5c84ddb25546a32a672bdc84685a6f79f0..1f34091e54c0f83dae6b60589c20fb8fdf1d5064 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -40,13 +40,10 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // TODO: optimize memory usage // setup initial producers list - mComputationNumber = 0; std::set<std::shared_ptr<Node>> producers; for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) { if (nodePtr->type() == "Producer") { producers.insert(nodePtr); - } else { - ++mComputationNumber; } } // add Data Input @@ -112,6 +109,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // Push consumers in the list of nodes to run and update the consumer producer system for (const auto& runnable : runnableConsumers) { + if (verbose) printf("Runnable: %s\n", (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str()); runnable->getOperator()->updateConsummerProducer(); mStaticSchedule.push_back(runnable); } @@ -177,14 +175,19 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // TODO: handle multiple inputs/outputs void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { + // Forward dims (if allowed) if (forwardDims) {mGraphView->forwardDims(); } - // add each Producer Node. - std::set<std::shared_ptr<Node>> computationOver; + // Generate scheduling *only if empty* + // If scheduling was already generated (in one or several steps, i.e. one or + // several successive call to generateScheduling()), do not generate it twice + if (mStaticSchedule.empty()) { + this->generateScheduling(); + } + // Clear previous scheduling results mScheduling.clear(); - this->generateScheduling(); int cpt = 0; for (const auto& runnable : mStaticSchedule) { if (verbose) @@ -202,12 +205,11 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { } if (!verbose) drawProgressBar(1.0, 50, " "); printf("\n"); - } void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { FILE* fp = std::fopen((fileName + ".mmd").c_str(), "w"); - std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%s ms\n\n"); + std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%Q ms\n\n"); if (!mScheduling.empty()) { const auto globalStart = mScheduling[0].start; diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 4b929286ba494a452c7f9cb71ce944c7d576c03a..9f014364636c70031b522b09c893e1144af3f133 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -161,7 +161,7 @@ TEST_CASE("[core/graph] GraphView(addChild)") { TEST_CASE("[core/graph] GraphView(inputs)") { auto g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = Conv(3, 32, {3, 3}); - g1->add(conv); + g1->add(conv, false); REQUIRE(g1->inputs() == conv->inputs()); } diff --git a/unit_tests/graph/Test_get.cpp b/unit_tests/graph/Test_get.cpp new file mode 100644 index 0000000000000000000000000000000000000000..afd1f42ee9f5d6cd668dd5cab82172cdc298e149 --- /dev/null +++ b/unit_tests/graph/Test_get.cpp @@ -0,0 +1,55 @@ + + +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" + + +using namespace Aidge; +TEST_CASE("get Delta") { + + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 1, 1, "c3.5"); + std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); + std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5"); + + g1->add(conv); + g1->addChild(conv1, "c"); + + + std::set<Aidge::NodePtr> see; +conv->getNodeDelta(1,see); + + SECTION("Self return") { + see.clear(); + REQUIRE(conv->getNodeDelta(0,see) == std::set<std::shared_ptr<Node>>{conv}); + } + + + SECTION("child") { + see.clear(); + REQUIRE(conv->getNodeDelta(1,see) == std::set<std::shared_ptr<Node>>{conv1}); + } + + +} \ No newline at end of file diff --git a/unit_tests/graphMatching/Test_GRegex.cpp b/unit_tests/graphMatching/Test_GRegex.cpp index 7184fad76a921239753d4752ae1a4a61bf3aec16..2c5907d82e7c5b1d32f1fb38493c7333b68f8731 100644 --- a/unit_tests/graphMatching/Test_GRegex.cpp +++ b/unit_tests/graphMatching/Test_GRegex.cpp @@ -53,6 +53,10 @@ TEST_CASE("Create good init GRegex", "[GRegex]") { // Perform tests REQUIRE(GReg.getStmInit().size() == 1); REQUIRE(GReg.getStmFab().getNumberOfStm() == 1); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } @@ -101,6 +105,10 @@ TEST_CASE("Function matchFromStartNodes | One Match of Nodes sequence", "[GRegex // Perform tests REQUIRE(result == true_result); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GRegex]") { @@ -166,6 +174,10 @@ TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GR // Perform tests REQUIRE(result == true_result); REQUIRE(wrong_start_result == empty_result); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } /* diff --git a/unit_tests/graphMatching/Test_SeqStm.cpp b/unit_tests/graphMatching/Test_SeqStm.cpp index baabbbc3c10ec751c64a65ab01c2c4d502f58cb5..db8662e3329abe153d4a0fb2b3c46b950208d6bc 100644 --- a/unit_tests/graphMatching/Test_SeqStm.cpp +++ b/unit_tests/graphMatching/Test_SeqStm.cpp @@ -79,6 +79,10 @@ TEST_CASE("Create good init SeqStm", "[SeqStm]") { REQUIRE(stm.getAllCommonNode().size() == 0); REQUIRE(stm.getAllNodeTested().size() == 0); REQUIRE(stm.getAllNodeValidated().size() == 0); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } TEST_CASE("Test testNode function", "[SeqStm]") { @@ -156,4 +160,8 @@ TEST_CASE("Test testNode function", "[SeqStm]") { REQUIRE(stm.isStmBlocked() == true); REQUIRE(stm.getAllNodeTested() == testAllNodeTested); REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } \ No newline at end of file diff --git a/unit_tests/graphMatching/Test_StmFactory.cpp b/unit_tests/graphMatching/Test_StmFactory.cpp index b595372fd97a56f2ecf2575429c63db92484bbc0..3c66d0fa817cea674de5ab849091290c976e5735 100644 --- a/unit_tests/graphMatching/Test_StmFactory.cpp +++ b/unit_tests/graphMatching/Test_StmFactory.cpp @@ -36,6 +36,10 @@ TEST_CASE("Create good init StmFactory", "[StmFactory]") { } StmFactory stmF(nodesRegex); REQUIRE(stmF.getNumberOfStm() == 0); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } TEST_CASE("Test in makeNewStm the getStmIdx StmFactory", "[SeqStm]") { @@ -66,6 +70,10 @@ TEST_CASE("Test in makeNewStm the getStmIdx StmFactory", "[SeqStm]") { //test the number of stm REQUIRE(stmF.getNumberOfStm() == 2); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } TEST_CASE("Test in makeNewStm the stm StmFactory", "[SeqStm]") { @@ -123,6 +131,9 @@ TEST_CASE("Test in makeNewStm the stm StmFactory", "[SeqStm]") { REQUIRE(stm->getAllNodeTested() == testAllNodeTested); REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } @@ -185,5 +196,9 @@ TEST_CASE("Test in duplicateStm StmFactory", "[SeqStm]") { REQUIRE(stmD->isStmBlocked() == false); REQUIRE(stmD->getAllNodeTested().size() == 0); REQUIRE(stmD->getAllNodeValidated().size() == 0); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } diff --git a/unit_tests/graphRegex/Test_Fsm.cpp b/unit_tests/graphRegex/Test_Fsm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5950f21b323f07b380ae95f70637ca48a173481 --- /dev/null +++ b/unit_tests/graphRegex/Test_Fsm.cpp @@ -0,0 +1,195 @@ +#include <memory> + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/nodeTester/ConditionalInterpreter.hpp" + +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" +#include "aidge/graphRegex/matchFsm/FsmEdge.hpp" +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" + +using namespace Aidge; + +TEST_CASE("matchFSM", "FsmEdge") { + + + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + FsmEdgeUnique EdgeToTest(nodeA,nodeB,toTest); + + SECTION("FsmEdgeUnique constructor") { + REQUIRE(EdgeToTest.getSourceNode() == nodeA); + REQUIRE(EdgeToTest.getDestNode() == nodeB); + REQUIRE(EdgeToTest.isCommon() == false); + } + + SECTION("FsmEdgeCommon constructor") { + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + + FsmEdgeCommon EdgeToTest(nodeA,nodeB,toTest,"A"); + + REQUIRE(EdgeToTest.getSourceNode() == nodeA); + REQUIRE(EdgeToTest.getDestNode() == nodeB); + REQUIRE(EdgeToTest.isCommon() == true); + } + + SECTION("FsmEdgeRef constructor") { + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + + FsmEdgeRef EdgeToTest(nodeA,nodeB,0,-1); + + REQUIRE(EdgeToTest.getSourceNode() == nodeA); + REQUIRE(EdgeToTest.getDestNode() == nodeB); + REQUIRE(EdgeToTest.isCommon() == false); + } + + SECTION("FsmEdgeEmpty constructor") { + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + + FsmEdgeEmpty EdgeToTest(nodeA,nodeB); + + REQUIRE(EdgeToTest.getSourceNode() == nodeA); + REQUIRE(EdgeToTest.getDestNode() == nodeB); + REQUIRE(EdgeToTest.isCommon() == false); + } + + + SECTION("FsmEdgeFactory"){ + + std::map<std::string, std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("true==true")}, + {"B",std::make_shared<ConditionalInterpreter>("true==true")}, + {"C",std::make_shared<ConditionalInterpreter>("true==true")} + }; + +// make(std::shared_ptr<FsmNode> source, std::shared_ptr<FsmNode> dest, +// FsmEdgeTypes type,std::map<std::string, const std::shared_ptr<ConditionalInterpreter>> allTest, +// const std::string& lexeme = ""); + + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(false,true); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(true,false); +// EMPTY = 0, +// REF, +// COMMON, +// UNIQUE + + std::shared_ptr<FsmEdge> edgeE = FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::EMPTY,allTest,""); + std::shared_ptr<FsmEdge> edgeU = FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::UNIQUE,allTest,"A"); + std::shared_ptr<FsmEdge> edgeC = FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::COMMON,allTest,"A#"); + std::shared_ptr<FsmEdge> edgeR = FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::REF,allTest,"(0,1)"); + + //test detection of bad syntax lexem + REQUIRE_THROWS(FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::EMPTY,allTest,"A")); + REQUIRE_THROWS(FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::UNIQUE,allTest,"A#")); + REQUIRE_THROWS(FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::COMMON,allTest,"A")); + REQUIRE_THROWS(FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::REF,allTest,"A")); + + REQUIRE(edgeE->getSourceNode() == nodeA); + REQUIRE(edgeE->getDestNode() == nodeB); + } + + SECTION("graph constructor") { + //make the nodes + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,false); + std::shared_ptr<FsmNode> nodeC = std::make_shared<FsmNode>(false,true); + + //make the edges + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); + std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); + + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + + graph->addEdge(edgeAB); + graph->addEdge(edgeBC); + + + REQUIRE(graph->getValidNodes() == std::set<std::shared_ptr<FsmNode>>{nodeA}); + REQUIRE(graph->getStartNodes() == std::vector<std::shared_ptr<FsmNode>>{nodeC}); + } + + + SECTION("graph merge") { + + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + + //make the nodes + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(false,true); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,false); + std::shared_ptr<FsmNode> nodeC = std::make_shared<FsmNode>(true,false); + + //make the edges + + std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); + std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); + + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + graph->addEdge(edgeAB); + graph->addEdge(edgeBC); + + REQUIRE(graph->getValidNodes() == std::set<std::shared_ptr<FsmNode>>{nodeC}); + REQUIRE(graph->getStartNodes() == std::vector<std::shared_ptr<FsmNode>>{nodeA}); + REQUIRE(graph->getNodes() == std::set<std::shared_ptr<FsmNode>>{nodeA,nodeB,nodeC}); + + //make the nodes + std::shared_ptr<FsmNode> node2A = std::make_shared<FsmNode>(false,true); + std::shared_ptr<FsmNode> node2B = std::make_shared<FsmNode>(false,false); + std::shared_ptr<FsmNode> node2C = std::make_shared<FsmNode>(true,false); + + + std::shared_ptr<FsmEdge> edge2AB = std::make_shared<FsmEdgeUnique>(node2A,node2B,toTest); + std::shared_ptr<FsmEdge> edge2BC = std::make_shared<FsmEdgeUnique>(node2B,node2C,toTest); + + std::shared_ptr<FsmGraph> graph2 = std::make_shared<FsmGraph>(); + + + graph2->addEdge(edge2AB); + graph2->addEdge(edge2BC); + + + REQUIRE(graph2->getValidNodes() == std::set<std::shared_ptr<FsmNode>>{node2C}); + REQUIRE(graph2->getStartNodes() == std::vector<std::shared_ptr<FsmNode>>{node2A}); + REQUIRE(graph2->getNodes() == std::set<std::shared_ptr<FsmNode>>{node2A,node2B,node2C}); + + + graph->mergeOneStartOneValid(graph2); + + REQUIRE(graph->getValidNodes() == std::set<std::shared_ptr<FsmNode>>{node2C}); + REQUIRE(graph->getStartNodes() == std::vector<std::shared_ptr<FsmNode>>{nodeA}); + REQUIRE(graph->getNodes() == std::set<std::shared_ptr<FsmNode>>{nodeA,nodeB,nodeC,node2B,node2C}); + } + + + + +} + +// TEST_CASE("matchFSM", "FsmGraph") { + +// SECTION("FsmEdgeUnique constructor") { +// //make the nodes +// std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); +// std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + +// //make the edges +// std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); +// std::shared_ptr<FsmEdgeUnique> edge = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); + +// std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + +// graph->addEdge(edge); + + + +// } + +// } \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_FsmMatch.cpp b/unit_tests/graphRegex/Test_FsmMatch.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1fe75be1a47033f75af7ccc4dc5202774444cd10 --- /dev/null +++ b/unit_tests/graphRegex/Test_FsmMatch.cpp @@ -0,0 +1,89 @@ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" + + +#include "aidge/graphRegex/GraphFsmInterpreter.hpp" + + +using namespace Aidge; +TEST_CASE("FsmMatch") { + + SECTION("Construction") { + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, + {"B",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, + {"C",std::make_shared<ConditionalInterpreter>("true==true")} + }; + + allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest["B"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + + std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A->A",allTest); + std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); + + + + //REQUIRE(fsm->getNodes().size() == 3); + //REQUIRE(fsm->getStartNodes().size() == 1); + + + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + + g1->add(conv); + g1->addChild(conv1, "c"); + + + REQUIRE(allTest["A"]->test(conv) == true); + REQUIRE(allTest["B"]->test(conv) == true); + + std::vector<std::shared_ptr<Node>> startNodes = {conv}; + + auto result = fsm->test(startNodes); + + REQUIRE( result->getBiggerSolution() == std::set<NodePtr>{conv,conv1}); + } + + + SECTION("2 branche graph"){ + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Fc", 1, 1, 1, "c2"); + + g1->add(conv); + g1->addChild(conv1,conv); + g1->addChild(conv2,conv); + + REQUIRE(g1->getNodes() == std::set<std::shared_ptr<Node>>({conv,conv1,conv2})); + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>({conv})); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>({conv1,conv2})); + + + ///////////// + + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, + {"B",std::make_shared<ConditionalInterpreter>("isFc($)==true")} + }; + allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest["B"]->insertLambda("isFc",+[](NodePtr NodeOp){return NodeOp->type() == "Fc";}); + + std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A#->A; A#->B",allTest); + std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); + + std::vector<std::shared_ptr<Node>> startNodes = {conv,conv}; + auto result = fsm->test(startNodes); + REQUIRE( result->getBiggerSolution() == std::set<NodePtr>{conv,conv1,conv2}); + + } + +} diff --git a/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp b/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ce090506c9a61abd928b3ae590ee838afb05999 --- /dev/null +++ b/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp @@ -0,0 +1,42 @@ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/graphRegex/GraphFsmInterpreter.hpp" + + +using namespace Aidge; +TEST_CASE("GraphFsmInterpreter", "GraphFsmInterpreter") { + + SECTION("Construction") { + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("true==true")}, + {"B",std::make_shared<ConditionalInterpreter>("true==true")}, + {"C",std::make_shared<ConditionalInterpreter>("true==true")} + }; + + //GraphFsmInterpreter("A->B",allTest); + std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A->B",allTest); + std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); + + + REQUIRE(fsm->getNodes().size() == 3); + REQUIRE(fsm->getStartNodes().size() == 1); + REQUIRE(fsm->getEdge().size() == 2); + + for(auto node : fsm->getNodes()){ + if(node->isValid()){ + REQUIRE(node->getEdges().size() == 0); + }else{ + REQUIRE(node->getEdges().size() == 1); + } + + } + + + } + + + + + +} \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_GraphLexer.cpp b/unit_tests/graphRegex/Test_GraphLexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1b8cc8e018546ebfe3f84202d9404db27b17449b --- /dev/null +++ b/unit_tests/graphRegex/Test_GraphLexer.cpp @@ -0,0 +1,118 @@ +#include <catch2/catch_test_macros.hpp> +#include "aidge/graphRegex/GraphLexer.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" + +#include "aidge/utilsParsing/ParsingToken.hpp" + + +#include <iostream> +#include <map> +#include <functional> + +using namespace Aidge; + +// NEXT +// QOM +// QZM +// KEY +// CKEY +// SEP +// LPAREN +// RPAREN + +TEST_CASE("GraphRegex", "Lexer") { + SECTION("RandomGenerateTest") { + + std::map<gRegexTokenTypes, std::function<std::pair<std::string, std::string>()>> LexerTestMap{ + {gRegexTokenTypes::NEXT, +[](){return std::pair<std::string, std::string>("-> ","");}}, + {gRegexTokenTypes::QOM, +[](){return std::pair<std::string, std::string>("+ ","");}}, + {gRegexTokenTypes::QZM, +[](){return std::pair<std::string, std::string>("* ","");}}, + {gRegexTokenTypes::SEP, +[](){return std::pair<std::string, std::string>("; ","");}}, + + + + {gRegexTokenTypes::KEY, +[](){ + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_"; + std::size_t randomIndex = std::rand() % characters.size(); + std::string key; + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + + return std::pair<std::string, std::string>(key+" ",key);} + }, + + {gRegexTokenTypes::CKEY, +[](){ + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_"; + const std::string num = "1234567890"; + + std::size_t randomIndex = std::rand() % characters.size(); + std::size_t randomNum = std::rand() % num.size(); + std::string key; + std::string idx; + + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + idx += num[randomNum]; + randomIndex = std::rand() % characters.size(); + randomNum = std::rand() % num.size(); + } + + return std::pair<std::string, std::string>(key+"#"+idx+" ",key+"#"+idx);} + }, + + {gRegexTokenTypes::LPAREN, +[](){return std::pair<std::string, std::string>("( ","");}}, + {gRegexTokenTypes::RPAREN, +[](){return std::pair<std::string, std::string>(") ","");}} + //{gRegexTokenTypes::STOP, +[](){return std::pair<std::string, std::string>("","");}} + }; + + + ////////////////// + //TEST GENERATOR + ////////////////// + const std::size_t numRandomElements = 10000; + std::vector<std::tuple<gRegexTokenTypes, std::string>> testVector; + + std::string testString; + + for (std::size_t i = 0; i < numRandomElements; ++i) { + + int randomIndex = std::rand() % LexerTestMap.size(); + // Get an iterator to the random element in the map + auto it = std::next(LexerTestMap.begin(), randomIndex); + // Access the random key and lambda value separately using structured binding + gRegexTokenTypes randomKey = it->first; + + std::function<std::pair<std::string, std::string>()> randomValue = it->second; + std::pair<std::string, std::string> result = randomValue(); + + testString += result.first; + testVector.emplace_back(randomKey, result.second); + + + } + + GraphLexer graphLexer = GraphLexer(testString); + + for (std::tuple<gRegexTokenTypes, std::string> testToken : testVector) { + gRegexTokenTypes tokenToFind = std::get<0>(testToken); + std::string lexemToFind = std::get<1>(testToken); + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = graphLexer.getNextToken(); + + + std::ostringstream errorMessage; + errorMessage << "\n we whant :"<< lexemToFind << "\n we get : "<< token->getLexeme() <<"\n"<< "on \n" << testString << " :\n " ; + + CAPTURE(errorMessage.str()); + REQUIRE(token->getLexeme() == lexemToFind); + REQUIRE(token->getType() == tokenToFind); + } + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = graphLexer.getNextToken(); + REQUIRE(token->getType() == gRegexTokenTypes::STOP); + } + + +} \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_GraphParser.cpp b/unit_tests/graphRegex/Test_GraphParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..857caa06f4e5fa383e79ea22bfe1ca28ac0973c8 --- /dev/null +++ b/unit_tests/graphRegex/Test_GraphParser.cpp @@ -0,0 +1,82 @@ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/graphRegex/GraphParser.hpp" +#include "aidge/utilsParsing/AstNode.hpp" +#include <iostream> + + +using namespace Aidge; + + //generative function , + std::string domain(); + std::string exp() { + int randomValue = std::rand() % 3; + switch (randomValue) { + case 0: + return "A"; + case 1 : + return "A#"; + default: + return domain(); + + } + } + + std::string seq() { + int randomValue = std::rand() % 2; + switch (randomValue) { + case 0: + return exp(); + default: + return exp()+"->"+seq(); + } + } + + std::string domain() { + int randomValue = std::rand() % 2; + + switch (randomValue) { + // case 0: + // return seq(); + // case 1: + // return seq() + "->" +domain(); + + case 0: + return "("+ seq() +")*"; + default: + return "("+ seq() +")+"; + + // case 4: + // return "("+ domain() +")*" + "->" +domain(); + // default: + // return "("+ domain() +")+" + "->" +domain(); + } + } + + std::string allExpr() { + int randomValue = std::rand() % 2; + switch (randomValue) { + case 0: + return seq(); + default : + return seq()+ ";" +allExpr(); + } + } + +/* +exp : KEY(QOM | QZM)? | CKEY | domain +seq :exp (NEXT seq)* +domain : LPAREN seq RPAREN (QOM | QZM) +allExpr: seq (SEP allExpr)* +*/ +TEST_CASE("GraphParser", "Test_GraphParser") { + + SECTION("Empty") { + for (int i = 0; i < 100; ++i) { + const std::string test = allExpr(); + std::cout << test <<"\n"; + GraphParser graphParser = GraphParser(test); + std::shared_ptr<AstNode<gRegexTokenTypes>> tree = graphParser.parse(); + } + } +} \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_graphRegexAST.cpp b/unit_tests/graphRegex/Test_graphRegexAST.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1cdb0bc1934983a26ab742bfe8879455077219cc --- /dev/null +++ b/unit_tests/graphRegex/Test_graphRegexAST.cpp @@ -0,0 +1,71 @@ +#include <catch2/catch_test_macros.hpp> +#include "aidge/graphRegex/GraphStrInterpreter.hpp" + + +using namespace Aidge; +TEST_CASE("GraphStrInterpreter") { + + + + std::vector<std::string> tests = { + + + //sequ + "A;", + "A->B", + "A->B->C", + //seq and common + "A#", + "A#->B", + "A#->B#", + "A#->B#->C", + "A#->B#->C#", + "A->B#->C", + //sequ quantif + + "A+", + "A+->B+", + "A->B+->C", + //sequ quantif * + "A*", + "A*->B*", + "A->B*->C", + + //sequ quantif + "A*", + "A*->B+", + "A+->B*->C", + //others + + "(A#->B->C#)+", + "(A#->B)+;A#->B->C", + "B+->B->B", + "B#->R*", + "(B#->R)*", + "A->C->B#->B;B#->R", + "B#->R", + "A->C#;A->C#;A->C#;A->C#;A->C#;A->C#", + "B#->R;B#->R", + "A# -> C -> B#; B#->A#", + + // Add more test cases here + }; + + SECTION("AST Regex bijection") { + + for (const std::string& test : tests) { + std::shared_ptr<GraphStrInterpreter> strGenerator = std::make_shared<GraphStrInterpreter>(test); + std::string astString = strGenerator->interpret(); + //supress space in the test becase erase in the AST + std::string testNoS = test; + testNoS.erase(std::remove_if(testNoS.begin(), testNoS.end(), ::isspace), testNoS.end()); + //if the last char is ; (SEP) it will not in the AST and it's not a bug erase it + if (!testNoS.empty() && testNoS.back() == ';') { + // Remove the last character + testNoS.pop_back(); + } + //test + REQUIRE(astString == testNoS); + } + + } +} \ No newline at end of file diff --git a/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp b/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8b502fb546e2f1396b629ebc78bc1bd4d67842e2 --- /dev/null +++ b/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp @@ -0,0 +1,66 @@ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/nodeTester/ConditionalInterpreter.hpp" +#include "aidge/operator/GenericOperator.hpp" + + +using namespace Aidge; + + + +TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { + + SECTION("custom Lambda") { + + const std::string test = " !toto($) == true " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + conditionalParser.insertLambda("toto",+[](NodePtr NodeOp){return false;}); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == true); + } + + SECTION("syntax error") { + + const std::string test = "'A' == 'A' ,&& "; + REQUIRE_THROWS_AS( ConditionalInterpreter(test), std::runtime_error); + + } + + + SECTION("test false int ") { + + const std::string test = " 10 == 11 " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == false); + } + + SECTION("test true int ") { + const std::string test = " 42 == 42 " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == true); + } + + SECTION("test false str ") { + const std::string test = " 'toto' == 'Corgi' " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == false); + } + + SECTION("test true str ") { + + const std::string test = " 'Corgi' == 'Corgi' " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == true); + } + +} \ No newline at end of file diff --git a/unit_tests/nodeTester/Test_ConditionalLexer.cpp b/unit_tests/nodeTester/Test_ConditionalLexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a937c27227dde4fa03ed7733df9e9552c3c1ac7b --- /dev/null +++ b/unit_tests/nodeTester/Test_ConditionalLexer.cpp @@ -0,0 +1,144 @@ +#include <catch2/catch_test_macros.hpp> +#include "aidge/nodeTester/ConditionalLexer.hpp" +#include "aidge/utilsParsing/ParsingToken.hpp" + +#include <iostream> +#include <map> +#include <functional> + +using namespace Aidge; + +TEST_CASE("nodeTester", "Lexer") { + SECTION("RandomGenerateTest") { + + std::map<ConditionalTokenTypes, std::function<std::pair<std::string, std::string>()>> LexerTestMap{ + {ConditionalTokenTypes::AND, +[](){return std::pair<std::string, std::string>("&& ","");}}, + {ConditionalTokenTypes::OR, +[](){return std::pair<std::string, std::string>("|| ","");}}, + {ConditionalTokenTypes::EQ, +[](){return std::pair<std::string, std::string>("== ","");}}, + {ConditionalTokenTypes::NEQ, +[](){return std::pair<std::string, std::string>("!= ","");}}, + + {ConditionalTokenTypes::KEY, +[](){return std::pair<std::string, std::string>("A ","A");}}, + + {ConditionalTokenTypes::BOOL, +[](){ + std::size_t keyLen = (std::rand() % 2); + const std::vector<std::string> characters = {"true","false"}; + + return std::pair<std::string, std::string>(characters[keyLen]+" ",characters[keyLen]);} + }, + + {ConditionalTokenTypes::INTEGER, +[](){ + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "1234567890"; + std::size_t randomIndex = std::rand() % characters.size(); + std::string key; + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + return std::pair<std::string, std::string>(key+" ",key);} + }, + + {ConditionalTokenTypes::FLOAT, +[](){ + std::size_t keyLen = (std::rand() % 20)+2; + const std::string characters = "1234567890"; + std::size_t randomIndex = std::rand() % characters.size(); + std::string key; + for (std::size_t i = 0; i < keyLen/2; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + key += "."; + for (std::size_t i = 0; i < keyLen/2; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + return std::pair<std::string, std::string>(key+" ",key);} + }, + + {ConditionalTokenTypes::STRING, +[](){ + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890 "; + std::size_t randomIndex = std::rand() % characters.size(); + std::string key; + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + + return std::pair<std::string, std::string>("'"+key+"' ",key);} + }, + + {ConditionalTokenTypes::LAMBDA, +[](){ + + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; + const std::string Startchar = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + + std::size_t randomIndex = std::rand() % characters.size(); + std::size_t randomStartIndex = std::rand() % Startchar.size(); + + std::string key; + key += Startchar[randomStartIndex]; + + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + + return std::pair<std::string, std::string>(key+"( ",key);} + }, + + {ConditionalTokenTypes::ARGSEP, +[](){return std::pair<std::string, std::string>(", ","");}}, + {ConditionalTokenTypes::NODE, +[](){return std::pair<std::string, std::string>("$ ","");}}, + {ConditionalTokenTypes::LPAREN, +[](){return std::pair<std::string, std::string>("( ","");}}, + {ConditionalTokenTypes::RPAREN, +[](){return std::pair<std::string, std::string>(") ","");}} + //{ConditionalTokenTypes::STOP, +[](){return std::pair<std::string, std::string>("","");}} + }; + + + ////////////////// + //TEST GENERATOR + ////////////////// + const std::size_t numRandomElements = 100; + std::vector<std::tuple<ConditionalTokenTypes, std::string>> testVector; + + std::string testString; + + for (std::size_t i = 0; i < numRandomElements; ++i) { + + int randomIndex = std::rand() % LexerTestMap.size(); + // Get an iterator to the random element in the map + auto it = std::next(LexerTestMap.begin(), randomIndex); + // Access the random key and lambda value separately using structured binding + ConditionalTokenTypes randomKey = it->first; + + std::function<std::pair<std::string, std::string>()> randomValue = it->second; + std::pair<std::string, std::string> result = randomValue(); + + testString += result.first; + testVector.emplace_back(randomKey, result.second); + + + } + + ConditionalLexer conditionalLexer = ConditionalLexer(testString); + + for (std::tuple<ConditionalTokenTypes, std::string> testToken : testVector) { + ConditionalTokenTypes tokenToFind = std::get<0>(testToken); + std::string lexemToFind = std::get<1>(testToken); + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = conditionalLexer.getNextToken(); + + + std::ostringstream errorMessage; + errorMessage << "\n we whant :"<< lexemToFind << "\n we get : "<< token->getLexeme() <<"\n"<< "on \n" << testString << " :\n " ; + + CAPTURE(errorMessage.str()); + REQUIRE(token->getLexeme() == lexemToFind); + REQUIRE(token->getType() == tokenToFind); + } + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = conditionalLexer.getNextToken(); + REQUIRE(token->getType() == ConditionalTokenTypes::STOP); + } + + +} \ No newline at end of file diff --git a/unit_tests/nodeTester/Test_ConditionalParser.cpp b/unit_tests/nodeTester/Test_ConditionalParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..56adb92b41745001e1790e087f07369918794c5d --- /dev/null +++ b/unit_tests/nodeTester/Test_ConditionalParser.cpp @@ -0,0 +1,75 @@ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/nodeTester/ConditionalParser.hpp" +#include "aidge/utilsParsing/AstNode.hpp" + +using namespace Aidge; + + std::string gVal() { + int randomValue = std::rand() % 5; + switch (randomValue) { + case 0: + return std::to_string(std::rand() % 101); + + case 1: + return std::to_string(std::rand() % 101)+"."+std::to_string(std::rand() % 101); + + case 2: + return " 'toto' "; + case 3: + return " A "; + + case 4: + return " A(10) "; + + default: + return " true "; + + } + } + + std::string gExpr() ; + std::string gCmpr() { + int randomValue = std::rand() % 3; + switch (randomValue) { + case 0: + return gVal() + " == " +gVal(); + case 1: + return "("+ gExpr() +")"; + default: + return gVal() + " != " +gVal(); + + } + + + return gVal() + " == " +gVal(); + } + + std::string gExpr() { + std::string out = gCmpr(); + int iterations = std::rand() % 100; + for (int i = 0; i < iterations; ++i) { + int randomValue = std::rand() % 2; + switch (randomValue) { + case 0: + return out +" && " + gCmpr(); + break; + default: + return out +" || " + gCmpr(); + break; + } + } + return out; + } + + +TEST_CASE("ConditionalParser", "ConditionalParser") { + + SECTION("Empty") { + for (int i = 0; i < 100; ++i) { + const std::string test = gExpr(); + ConditionalParser conditionalParser = ConditionalParser(test); + std::shared_ptr<AstNode<ConditionalTokenTypes>> tree = conditionalParser.parse(); + } + } +} \ No newline at end of file diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e9718fc694d29713797565d3ae8c8107cc7612de --- /dev/null +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -0,0 +1,53 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" +#include "aidge/graph/GraphView.hpp" +#include <cstddef> + +using namespace Aidge; + +TEST_CASE("[core/operators] MetaOperator", "[Operator]") { + SECTION("PaddedConv") { + auto op = PaddedConv(1, 3, {3, 3}, "padded_conv", {1, 1}, {{{1, 1}, {1, 1}}}); + + auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraph(); + + REQUIRE(microGraph->getNodes().size() == 2); + REQUIRE(microGraph->inputNodes().size() == 2); // 2 because Conv has inputs outside the meta-op (Producers for weight and bias) + // Order not garanteed by the GraphView + //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->type() == "Pad"); + REQUIRE(microGraph->outputNodes().size() == 1); + REQUIRE((*microGraph->outputNodes().begin())->getOperator()->type() == "Conv"); + REQUIRE(op->nbInputs() == 3); + REQUIRE(op->nbDataInputs() == 1); + REQUIRE(op->nbOutputs() == 1); + + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(); + myInput->resize({2,3,5,5}); + op->getOperator()->associateInput(0,myInput); + op->getOperator()->computeOutputDims(); + + REQUIRE(op->getOperator()->outputDimsForwarded()); + REQUIRE(op->getOperator()->getOutput(0)->dims() == std::vector<size_t>({2,3,5,5})); + REQUIRE(op->getOperator()->getInput(0) == myInput); + // Order not garanteed by the GraphView + //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->getInput(0) == myInput); + REQUIRE(op->getOperator()->getOutput(0) == (*microGraph->outputNodes().begin())->getOperator()->getOutput(0)); + + //op->getOperator()->updateConsummerProducer(); // require implementation + //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraphScheduler(); + //REQUIRE(microGraphScheduler->getStaticScheduling().size() == 2); + } +}