diff --git a/.gitlab/ci/build.gitlab-ci.yml b/.gitlab/ci/build.gitlab-ci.yml index cd56a55fa7e9cbcefba4715188fd270462e81976..a4579e2951ccbafc4335ae428c62eba94c0757e5 100644 --- a/.gitlab/ci/build.gitlab-ci.yml +++ b/.gitlab/ci/build.gitlab-ci.yml @@ -17,6 +17,66 @@ build:ubuntu_cpp: - build_cpp/ - install_cpp/ +build:ubuntu_cpp_g++10: + stage: build + needs: [] + tags: + - docker + + script: + - apt install -y g++-10 + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - export CXX=/usr/bin/g++-10 + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON -DCOVERAGE=ON .. + - make -j4 all install + +build:ubuntu_cpp_g++12: + stage: build + needs: [] + tags: + - docker + + script: + - apt install -y g++-12 + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - export CXX=/usr/bin/g++-12 + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON -DCOVERAGE=ON .. + - make -j4 all install + +build:ubuntu_cpp_clang12: + stage: build + needs: [] + tags: + - docker + + script: + - apt install -y clang-12 + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - export CXX=/usr/bin/clang++-12 + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON -DCOVERAGE=ON .. + - make -j4 all install + +build:ubuntu_cpp_clang15: + stage: build + needs: [] + tags: + - docker + + script: + - apt install -y clang-15 + - mkdir -p build_cpp + - mkdir -p install_cpp + - cd build_cpp + - export CXX=/usr/bin/clang++-15 + - cmake -DCMAKE_INSTALL_PREFIX:PATH=../install_cpp -DCMAKE_BUILD_TYPE=Debug -DWERROR=ON -DCOVERAGE=ON .. + - make -j4 all install + build:ubuntu_python: stage: build needs: [] @@ -28,8 +88,7 @@ build:ubuntu_python: - virtualenv venv - source venv/bin/activate # Numpy dependancy for unit test - - python3 -m pip install numpy - - export AIDGE_INSTALL=`pwd`/install + - python3 -m pip install -r requirements.txt - python3 -m pip install . artifacts: expire_in: 1 week @@ -65,3 +124,31 @@ build:windows_cpp: paths: - build_cpp/ - install_cpp/ + +build:windows_python: + stage: build + needs: [] + tags: + - windows + + image: buildtools + before_script: + # Install Chocolatey + - Set-ExecutionPolicy Bypass -Scope Process -Force; [System.Net.ServicePointManager]::SecurityProtocol = [System.Net.ServicePointManager]::SecurityProtocol -bor 3072; iex ((New-Object System.Net.WebClient).DownloadString('https://community.chocolatey.org/install.ps1')) + # Install dependencies + - choco install cmake.install --installargs '"ADD_CMAKE_TO_PATH=System"' -Y + - choco install git -Y + - choco install python -Y + # Update PATH + - $env:Path = [System.Environment]::GetEnvironmentVariable("Path","Machine") + ";" + [System.Environment]::GetEnvironmentVariable("Path","User") + script: + - python -m pip install virtualenv + - virtualenv venv + - venv\Scripts\Activate.ps1 + # Numpy dependancy for unit test + - python -m pip install -r requirements.txt + - python -m pip install . + artifacts: + expire_in: 1 week + paths: + - venv/ diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index 47ded2a462477958320bfad3ad84e6b8f6ef6082..e708c168421216fa249f26eee1f2b2eb80b588fd 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -35,8 +35,10 @@ #include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/MatMul.hpp" #include "aidge/operator/MaxPooling.hpp" -//#include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" #include "aidge/operator/Operator.hpp" +#include "aidge/operator/Pad.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/ReLU.hpp" #include "aidge/operator/Softmax.hpp" diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 7422a52eb171ee6dae0e14ad67c0562295fe5d8c..58c434bccc7c8dd39a93c46ecf74c38d7d834d1a 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -477,13 +477,14 @@ class Tensor : public Data, if (dims().empty()) { return "{}"; } std::string res; std::size_t dim = 0; - std::size_t *dimVals = new std::size_t[nbDims()]; - for (std::size_t i = 0; i < nbDims(); ++i) { - dimVals[i] = 0; - } std::size_t counter = 0; - res += "{\n"; - if (nbDims()>=2){ + if (nbDims()>=2) { + std::size_t *dimVals = new std::size_t[nbDims()]; + for (std::size_t i = 0; i < nbDims(); ++i) { + dimVals[i] = 0; + } + // std::vector<std::size_t> dimVals = std::vector<std::size_t>(nbDims(), 0); + res += "{\n"; while (counter < mSize) { std::string spaceString = std::string((dim+1)<<1,' '); if (dim < nbDims()-2) { @@ -532,31 +533,35 @@ class Tensor : public Data, } res += "\n"; } + if (dim == 0) { + break; + } dimVals[dim--] = 0; dimVals[dim]++; } } - for(int i = static_cast<int>(dim); i>=0; --i) { + delete[] dimVals; + + for(int i = static_cast<int>(dim); i > 0; --i) { res += std::string((dim+1)<<1,' ') + "}\n"; } - }else{ + } else { + res += "{"; for (DimSize_t j = 0; j < dims()[0]; ++j) { switch (mDataType) { case DataType::Int32: - res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n"); + res += " " + std::to_string(static_cast<int *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : ""); break; case DataType::Float64: - res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n"); + res += " " + std::to_string(static_cast<double *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : ""); break; default: - res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : "\n"); + res += " " + std::to_string(static_cast<float *>(mImpl->rawPtr())[j]) + ((j < dims()[0]-1) ? "," : ""); break; } } } - - res += "}"; return res; } diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 1f1eeafa859b116606613392a13a65ad398669ad..89ba148497709f0af475bbf953ff285c88036102 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -124,7 +124,7 @@ public: } /** - * @brief List dataInput connections of the GraphView object's inputNodes. + * @brief List outside dataInput connections of the GraphView object's inputNodes. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> dataInputs() const; @@ -137,7 +137,7 @@ public: inline auto dataInputs(const std::string name) const { return mNodeRegistry.at(name)->dataInputs(); } /** - * @brief List input connections of the GraphView object's inputNodes. + * @brief List outside input connections of the GraphView object's inputNodes. * @return std::vector<std::pair<NodePtr, IOIndex_t>> */ std::vector<std::pair<NodePtr, IOIndex_t>> inputs() const; diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index dbe017fc7f8935e83ff1672392992c75a2e0658c..1d8449ac25cf8c31192da0c350c14cbfa50a48f4 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -399,6 +399,18 @@ public: return node->clone(); } + + /** + * @brief Get the set of pointers to connected node at a distance of a delta. + * @details the recution are cut + * Return a nullptr is nofing found. + * @param delta Input delta. + * @return std::shared_ptr<Node> + */ + + std::set<NodePtr> getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee); + + private: /////////////////////////////////////////////////////// // OPERATORS diff --git a/include/aidge/graphRegex/GraphFsmInterpreter.hpp b/include/aidge/graphRegex/GraphFsmInterpreter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9e92b6fe8fc9d5e44cb8051e687e33d7192e0eb7 --- /dev/null +++ b/include/aidge/graphRegex/GraphFsmInterpreter.hpp @@ -0,0 +1,73 @@ +#ifndef AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ +#define AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ + +#include <string> +#include <memory> + +#include "aidge/utilsParsing/AstNode.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" +#include "aidge/graphRegex/GraphParser.hpp" +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" + +namespace Aidge { + + class GraphFsmInterpreter + { + private: + /* data */ + GraphParser mParser; + std::size_t mActGroupe; + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> mNodesCondition; + + public: + GraphFsmInterpreter(const std::string graphMatchExpr,std::map<std::string,std::shared_ptr<ConditionalInterpreter>> nodesCondition); + virtual ~GraphFsmInterpreter() =default; + + + std::shared_ptr<FsmGraph> interpret(void); + + private: + + + std::shared_ptr<FsmGraph> visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree); + + /** + * @defgroup graphFsmInterpreterF Functions for interpreting AST nodes + * @brief For each node type in the AST, define how build the FsmGraph + */ + + + /** + * @ingroup graphFsmInterpreterF + * @brief leaf of fsm make the fsm for test one transition + */ + std::shared_ptr<FsmGraph> keyF(std::shared_ptr<AstNode<gRegexTokenTypes>> AstNode); + /** + * @ingroup graphFsmInterpreterF + * @brief combine two fsm of two expression. + */ + std::shared_ptr<FsmGraph> sepF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm); + /** + * @ingroup graphFsmInterpreterF + * @brief combine two to make a new that match leftFsm next rigthFsm + */ + std::shared_ptr<FsmGraph> nextF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm); + /** + * @ingroup graphFsmInterpreterF + * @brief make the fsm match + + */ + std::shared_ptr<FsmGraph> qomF(std::shared_ptr<FsmGraph> fsm); + /** + * @ingroup graphFsmInterpreterF + * @brief make the fsm match * + */ + std::shared_ptr<FsmGraph> qzmF(std::shared_ptr<FsmGraph> fsm); + + }; + + + +} + + +#endif // AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ diff --git a/include/aidge/graphRegex/GraphLexer.hpp b/include/aidge/graphRegex/GraphLexer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e4137ab093c466b7349007da91e032dae48eda51 --- /dev/null +++ b/include/aidge/graphRegex/GraphLexer.hpp @@ -0,0 +1,68 @@ +#ifndef AIDGE_CORE_GRAPH_LEXER_H_ +#define AIDGE_CORE_GRAPH_LEXER_H_ + +#include <string> +#include <memory> +#include <regex> +#include <stdexcept> //error +#include <sstream> + +#include "aidge/utilsParsing/ParsingToken.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" + +namespace Aidge { + + class GraphLexer + { + + public: + GraphLexer( const std::string gRegexExpressions ); + + /** + * @brief Get the next token on the gRegexExpressions + * @return ConditionalToken + */ + std::shared_ptr<ParsingToken<gRegexTokenTypes>> getNextToken(void); + /** + * @brief Restart at the start of the gRegexExpressions + * + */ + void rstPosition(void); + + /** + * @brief Test if the string is completely read + * @return bool + */ + bool isEnd(void); + + + /** + * @brief Get the representation of the class + * @return string + */ + const std::string rep(); + + private: + + /** + * @brief Constructs an error message to display the character not understood by the lexer + * @return error mesage + */ + std::runtime_error badTokenError(const std::string& currentChars,std::size_t position); + + /** + * @brief The expression of the test to be performed on the nodes + */ + const std::string mRegularExpressions; + /** + * @brief The lexer's current position in mConditionalExpressions + */ + std::size_t mPosition; + + }; +} + + + + +#endif //AIDGE_CORE_GRAPH_LEXER_H_ diff --git a/include/aidge/graphRegex/GraphParser.hpp b/include/aidge/graphRegex/GraphParser.hpp new file mode 100644 index 0000000000000000000000000000000000000000..73406203a8be87e1df75cc694ab1ff281c27fbfa --- /dev/null +++ b/include/aidge/graphRegex/GraphParser.hpp @@ -0,0 +1,98 @@ +#ifndef AIDGE_CORE_GRAPH_PARSER_H_ +#define AIDGE_CORE_GRAPH_PARSER_H_ + + +#include <memory> // for shared_ptr +#include "aidge/graphRegex/GraphLexer.hpp" +#include "aidge/utilsParsing/AstNode.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" + +namespace Aidge{ + +/** + * @brief this class uses the lexer to create an AST according to a set of gramer rules + */ +class GraphParser{ + + public: + /** + * @brief AST graph creation function + * @param gRegexExpressions String representing the logical fuction to be performed + */ + GraphParser(const std::string gRegexExpressions); + + virtual ~GraphParser() = default; + + /** + * @brief AST graph creation function + * @return The AST tree + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> parse(void); + + + private: + /** + * @brief restart at the start of the ConditionalExpressions for LEXER and restart mCurrentToken + */ + void rstParser(void); + + ////////////////// + + /** + * @defgroup ParsingFunctions Function for creating AST + * @brief Functions for recursive construction of the AST representing grammar rules + */ + + /** + * @ingroup ParsingFunctions + * @brief Token reading and verification function + * + */ + void ackToken(gRegexTokenTypes tokenType); + + //TODO TODO + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for key : KEY(QOM | QZM)? | CKEY + * @return AST node + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> constructAstExp(void); + + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for sequence : seq :exp (NEXT seq)* + * @return AST node + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> constructAstSeq(void); + + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for domain : (seq NEXT domain)? | LPAREN domain RPAREN (QOM | QZM) (NEXT domain)? + * @return AST node + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> constructAstDomain(void); + + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for multiple exepresion : allExpr: domain (SEP allExpr)* + * @return AST node + */ + std::shared_ptr<AstNode<gRegexTokenTypes>> constructAstAllExpr(void); + + + /** + * @brief The actual token in the parce + */ + std::shared_ptr<ParsingToken<gRegexTokenTypes>> mCurrentToken; + + /** + * @brief The lexem use + */ + GraphLexer mLexer; + +}; + + +} + +#endif //AIDGE_CORE_GRAPH_PARSER_H_ diff --git a/include/aidge/graphRegex/GraphRegexTypes.hpp b/include/aidge/graphRegex/GraphRegexTypes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..9e35f8f027eb363b71358922ffe0caa4a55fff1d --- /dev/null +++ b/include/aidge/graphRegex/GraphRegexTypes.hpp @@ -0,0 +1,29 @@ + +#ifndef AIDGE_CORE_GREGEX_TOKEN_TYPES_H_ +#define AIDGE_CORE_GREGEX_TOKEN_TYPES_H_ + + +namespace Aidge { + /** + * @brief enum for all types of token use in the of the regex + * 7-5 type + * 4-0 id + */ + enum class gRegexTokenTypes + { + STOP, + NEXT, /**< -> */ + + QOM, /**< + */ + QZM, /**< * */ + + KEY, /**< [A-Za-z_0-9]+ */ + CKEY, /**< [A-Za-z_0-9]+#[0-9]* */ + + SEP, /**< \( */ + LPAREN, /**< \( */ + RPAREN, /**< \) */ + }; + +} +#endif //AIDGE_CORE_GREGEX_TOKEN_TYPES_H_ diff --git a/include/aidge/graphRegex/GraphStrInterpreter.hpp b/include/aidge/graphRegex/GraphStrInterpreter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..98dca0e9f84de0be2614aed0e47c9d86ae552674 --- /dev/null +++ b/include/aidge/graphRegex/GraphStrInterpreter.hpp @@ -0,0 +1,40 @@ +#ifndef AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ +#define AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ + +#include <iostream> +#include <sstream> +#include <memory> +#include <algorithm> + +#include "aidge/utilsParsing/AstNode.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" +#include "aidge/graphRegex/GraphParser.hpp" +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" + +namespace Aidge { + + class GraphStrInterpreter + { + private: + /* data */ + GraphParser mParser; + std::string mToTest; + public: + GraphStrInterpreter(const std::string graphMatchExpr); + virtual ~GraphStrInterpreter() =default; + + + std::string interpret(void); + + private: + + + std::string visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree); + }; + + + +} + + +#endif //AIDGE_CORE_GRAPH_FSM_INTERPRETER_H_ diff --git a/include/aidge/graphRegex/matchFsm/FsmEdge.hpp b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c3eae528808dbdb8023718c961b7c45cbf4afac9 --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/FsmEdge.hpp @@ -0,0 +1,228 @@ +#ifndef AIDGE_CORE_FSM_EDGE_H_ +#define AIDGE_CORE_FSM_EDGE_H_ + +#include <memory> +#include <set> +#include <string> + +#include "aidge/nodeTester/ConditionalInterpreter.hpp" + + +namespace Aidge{ + + class FsmNode; + class FsmRunTimeContext; + + struct EdgeTestResult { + bool success; + std::set<NodePtr> node; + }; + + /** + * @brief virtual class use test the node on the node to validate + */ + class FsmEdge: public std::enable_shared_from_this<FsmEdge> + { + private: + + /** + * @brief the relative position to this test relative to all the const key + * first is common id, second is the relative position + */ + std::map<size_t,int> mRelativePos; + /** + * @brief the ptr on the source node + */ + std::shared_ptr<FsmNode> mNodeSource; + /** + * @brief the ptr on the dest node + */ + std::shared_ptr<FsmNode> mNodeDest; + /** + * @brief the weak ptr + */ + std::weak_ptr<FsmEdge> weakPtr; + + public: + FsmEdge(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest); + + virtual ~FsmEdge(){}; + + FsmEdge() : weakPtr(shared_from_this()) {} + + + /** + * @brief test is the validation of the node, it must be defined for all types of edge + * it takes as argument an FSM traversal context and returns a set of next nodes + * @return set of next node or nullptr if not next + */ + + virtual const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) =0; + + /** + * @brief test is the egde test a common node + * @return true if is a common + */ + virtual bool isCommon(void); + /** + * @brief get the Common idx of the common test in this edge (if is a common edge) + * @return idx of the common + */ + virtual size_t getCommonIdx(void); + /** + * @brief get the relative postion to the common node deffine in this edge + * @return map + */ + const std::map<size_t,int>& getRelative(void); + /** + * @brief add new relative position + */ + void updateRelative( const std::map<size_t,int>& relativePos ); + /** + * @brief get source FsmNode + * @return FsmNode + */ + std::shared_ptr<FsmNode> getSourceNode(void); + /** + * @brief set a new source to the edge + * @return FsmNode + */ + void reSetSouceNode(const std::shared_ptr<FsmNode>& newSource); + /** + * @brief get dest FsmNode + * @return FsmNode + */ + std::shared_ptr<FsmNode> getDestNode(void); + /** + * @brief set a new dest to the edge + * @return FsmNode + */ + void reSetDestNode(const std::shared_ptr<FsmNode>& newDest); + /** + * @brief propagate the edge mRelativePos to the others Edge and recalcul the relative position + */ + void propagateRelativePos(void); + + /** + * @brief test to make on the node to validate + * @see ConditionalInterpreter + */ + const std::shared_ptr<ConditionalInterpreter> mToTest; + + /** + * @brief update week ptr for the node, TODO best + */ + void updateWeak(void); + }; + + /** + * @brief class spesialisation for not commun node (node that must be match one Unique) transition + */ + class FsmEdgeUnique:public FsmEdge + { + + public: + FsmEdgeUnique(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest); + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; + }; + + /** + * @brief class spesialisation for commun node transition + * @see FsmEdge + */ + class FsmEdgeCommon:public FsmEdge + { + + private: + /** + * @brief the map that defind the ralation between the commonKey find by the lexer and a unique id use to refer to the common node + */ + static std::map<std::string,int> mCommonIdxMap; + /** + * @brief the common id test in this transition + */ + int mCommonIdx; + public: + + /** + * @brief constructor commun node , + * @details during construction, + * the node key found by the lexer is converted to a unique id and the relative positions are updated. + */ + FsmEdgeCommon(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey); + // ~FsmEdgeCommon() override {} + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; + bool isCommon(void) override; + + }; + + + + /** + * @brief class spesialisation for ref transition + * @see FsmEdge + */ + class FsmEdgeRef:public FsmEdge + { + private: + /** + * @brief the id of one common node that we use as an anchor + */ + const int mRefCommonIdx; + /** + * @brief the delta in terme of child or parent refer to the anchor + */ + const int mdeltaCommonIdx; + public: + FsmEdgeRef(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const size_t refCommonIdx,const int deltaCommonIdx); + //~FsmEdgeRef() override {} + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; + + }; + + /** + * @brief class spesialisation for ref empty transition + * @see FsmEdge + */ + class FsmEdgeEmpty:public FsmEdge + { + + public: + FsmEdgeEmpty(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest); + //~FsmEdgeEmpty() override {} + const EdgeTestResult test(const std::shared_ptr<FsmRunTimeContext> stmContext) override; + + }; + + + +//////////////////////// +// FACTORY +//////////////////////// + +enum class FsmEdgeTypes { + EMPTY = 0, + REF, + COMMON, + UNIQUE +}; + + +class FsmEdgeFactory { + public: + /** + * @brief factory for making edge and read the info in the lexeme of the token + * @param source source node of the edge + * @param dest Dest node of the edge + * @param type type of the edge + * @param lexeme the additional information to build the edge + * @return s prt of the edge + */ + static std::shared_ptr<FsmEdge> make(std::shared_ptr<FsmNode> source, std::shared_ptr<FsmNode> dest, + FsmEdgeTypes type,std::map<std::string, std::shared_ptr<ConditionalInterpreter>> allTest, + const std::string lexeme = ""); + }; + +} + +#endif //AIDGE_CORE_FSM_EDGE_H_ diff --git a/include/aidge/graphRegex/matchFsm/FsmGraph.hpp b/include/aidge/graphRegex/matchFsm/FsmGraph.hpp new file mode 100644 index 0000000000000000000000000000000000000000..0a74551367dd492cb0abb820e4c5ce5a601d071e --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/FsmGraph.hpp @@ -0,0 +1,98 @@ + +#ifndef AIDGE_CORE_FSM_GRAPH_H_ +#define AIDGE_CORE_FSM_GRAPH_H_ + +#include <set> +#include <vector> +#include <memory> +#include <stdexcept> //error + +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" +#include "aidge/graphRegex/matchFsm/FsmEdge.hpp" +#include "aidge/graphRegex/matchFsm/MatchResult.hpp" +namespace Aidge{ + + + +class FsmGraph +{ +private: + /** + * @brief all node origine + */ + std::set<std::size_t> mAllOrigine; + std::set<std::shared_ptr<FsmEdge>> mEdges; +public: + FsmGraph(/* args */); + virtual ~FsmGraph() = default; + +std::shared_ptr<MatchResult> test(std::vector<NodePtr>& StartNodes); + + + +const std::set<std::shared_ptr<FsmEdge>>& getEdge(void); +/** + * @brief add edge in the graph, as FsmEdge know the source and dest FsmNode these nodes are also add to the graph +*/ +void addEdge(std::shared_ptr<FsmEdge>& edge); + +/** + * @brief get the liste of the starting states + * @details we need to use a vector because the order of the nodes is important for start node initialization \ref test() +*/ +const std::vector<std::shared_ptr<FsmNode>> getStartNodes(void); + +/** + * @brief get the set of the valide states + * @return set of valide state +*/ +const std::set<std::shared_ptr<FsmNode>> getValidNodes(void); + +/** + * @brief get the set of all the node in the graph + * @return set of all nodes +*/ +const std::set<std::shared_ptr<FsmNode>> getNodes(void); + +/** + * @brief set a groupe idx for all the nodes in the graph +*/ +void setGroupe(std::size_t groupeIdx); + +/** + * @brief make the union beteen this graph and an input graph + * @param fsmGraph graph to union +*/ +void unionG(const std::shared_ptr<FsmGraph> fsmGraph); + + +/** + * @brief make the union beteen this graph and an input graph and merge the valide state to the start state + * @param fsmGraph graph to merge +*/ +void mergeOneStartOneValid(const std::shared_ptr< FsmGraph> fsmGraph); +/** + * @brief get the number of sub FSM + * @return number of sub Fsm +*/ +std::size_t getNbSubFsm(void); + +/** + * @brief increment the origine of all node in the graph + * @param incr the incrémentation value +*/ +void incOrigineAllNodeBy(std::size_t incr); + +private: + +/** + * @brief merge tow node of the graph + * @param node +*/ +void _mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest); + +}; + + +} +#endif //AIDGE_CORE_FSM_GRAPH_H_ diff --git a/include/aidge/graphRegex/matchFsm/FsmNode.hpp b/include/aidge/graphRegex/matchFsm/FsmNode.hpp new file mode 100644 index 0000000000000000000000000000000000000000..2776ff8eb297fd5ad9a4c425fb386adde0a25269 --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/FsmNode.hpp @@ -0,0 +1,99 @@ +#ifndef AIDGE_CORE_FSM_NODE_H_ +#define AIDGE_CORE_FSM_NODE_H_ + +#include <set> +#include <vector> +#include <memory> + +//#include "graphRegex/matchFsm/FsmEdge.hpp" +//#include "graphRegex/matchFsm/FsmRunTimeContext.hpp" + +namespace Aidge{ + // Forward declaration of the class defined in graphRegex/matchFsm/FsmEdge.hpp + class FsmEdge; + struct EdgeTestResult; + class FsmRunTimeContext; + + + //------------------------------------------------------------------------------ + + // MAY BE IN UTILE + template <typename T> + struct lex_compare { + bool operator() (const std::weak_ptr<T> &lhs, const std::weak_ptr<T> &rhs)const { + auto lptr = lhs.lock(), rptr = rhs.lock(); + if (!rptr) return false; // nothing after expired pointer + if (!lptr) return true; + return lptr < rptr; + } + }; + + /** + * @brief is a node in the FSM graph, it's a state in the FSM + * @details a state can be and/or : + * - a valide state, the match is valide if it stop on this edge + * - a start state , the match start on this state + * The state is also define by this origine (is the unique id of it's expretion ) + * and it's groupe (for inner expression TODO) + */ + class FsmNode : public std::enable_shared_from_this<FsmNode> + { + private: + /** + * @brief the edge of the node + * @details the edge have a shared ref to the node so we use weak ref + */ + std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>> mEdges; + /** + * @brief the parent of the node + */ + std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>> mParents; + + std::size_t mOrigineStm = 0; + std::size_t mGroupeStm = 0; + + bool mIsAValid; + bool mIsAStart; + + public: + FsmNode(bool isAValid,bool isAStart ); + virtual ~FsmNode() = default; + /** + * @brief use to MAG the actual context , and return all the posible new context + * @details one input context can generate a multitude of contexts because a graph node + * can have more than one child, and each traversal possibility is a new context. + * @param actContext the actual context + * @return A vector of all the new context + */ + const std::vector<std::shared_ptr<FsmRunTimeContext>> test( std::shared_ptr<FsmRunTimeContext>); + + + std::size_t getOrigine(void); + void incOrigine(std::size_t inc); + + + void rmEdge(std::shared_ptr<FsmEdge>); + void addEdge(std::shared_ptr<FsmEdge>); + + //const std::set<std::shared_ptr<FsmNode>> getChildNodes(void); + + const std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>>& getParentNodes(void); + const std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>>& getEdges(void); + + void setGroupe(std::size_t groupeIdx); + + bool isValid(void); + bool isStart(void); + void unValid(void); + void valid(void); + void unStart(void); + void start(void); + + + + void addParent(std::shared_ptr<FsmNode>); + void rmParent(std::shared_ptr<FsmNode>); + }; + +} +#endif //AIDGE_CORE_FSM_NODE_H_ diff --git a/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6f1b9fc2bfe68195f67cfc0bf17d57aed5345219 --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp @@ -0,0 +1,173 @@ +#ifndef AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_ +#define AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_ + +#include <memory> +#include <vector> +#include <set> +#include <algorithm> + +#include "aidge/nodeTester/ConditionalInterpreter.hpp" +#include "aidge/graph/Node.hpp" + + + +namespace Aidge{ + + class FsmNode; + + class FsmNode; + + /** + * @brief a class used to save the execution context of state machines, that is the actual state in the FSM, the actual node in the graph + * all node that have been Validate,Rejecte or Considered common + */ + class FsmRunTimeContext + { + private: + /** + * @brief the list of node rejected for all the context + */ + static std::vector<std::set<NodePtr>> mRejectedNodes; + /** + * @brief the actual state of this Context (where it's in the FSM graph) + */ + std::shared_ptr<FsmNode> mActState; + /** + * @brief the actual node of this Context (where it's in the graph) + */ + NodePtr mActOpNode; + /** + * @brief the map of the node consider as common and the common ID + * @details we need to store what node it's consider as common because of the end + * resolution of the matching, all node consider as common need to be the same in all context + */ + std::map<NodePtr,std::size_t> mCommonNodes; + /** + * @brief the map of the node that as been valid in this context , and the test that valide the node + */ + std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> mValidNodes; + /** + * @brief the index in the rejected node of this context + */ + std::size_t mLocalIdxRejeced; + public: + /** + * @brief constructor + * @param actState the actual state in the FSM + * @param actOpNode the actual node in the graph + * @param idxRejeced the idx in the global regected node vector init max() as sentinel value of undefind + */ + FsmRunTimeContext(std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ,std::size_t idxRejeced =std::numeric_limits<std::size_t>::max() ); + FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime); + FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime,std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ); + + virtual ~FsmRunTimeContext()=default; + + /** + * @defgroup FsmRunTimeContextRejected Function for managing rejected nodes + */ + + /** + * @ingroup FsmRunTimeContextRejected + * @brief Add a node as rejected in this context + */ + void addRejectedNode(NodePtr node); + + /** + * @ingroup FsmRunTimeContextRejected + * @brief get the rejected nodes of this context + */ + std::set<NodePtr> getRejectedNodes(void); + + + /** + * @defgroup FsmRunTimeContextTest Function for test the context + */ + + /** + * @ingroup FsmRunTimeContextTest + * @brief test if the actual state is valide + * @return bool + */ + bool isOnValidState(void); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if the node is considered as common in this context + * @param node node to test + * @return bool + */ + bool isCommonDefined(NodePtr node); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if has already validated in this context + * @param node node to test + * @return bool + */ + bool isAlreadyValid(NodePtr node); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if this context is compatible with an others + * @details to say that two contexts are compatible is to check : + * that the contexts do not validate the same nodes (other than the common ones) + * and that the common ones have the same idx + * @param fsmContext the others context + * @return bool + */ + bool areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext); + /** + * @ingroup FsmRunTimeContextTest + * @brief test if this context is strictly equal with an others + * @param fsmContext the others context + * @return bool + */ + bool areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext); + + /** + * @defgroup FsmRunTimeContextSet Function set context + */ + + + void setCommon(NodePtr node,std::size_t commonIdx); + + + void setValid(NodePtr node,std::shared_ptr<ConditionalInterpreter> tag); + + /** + * @defgroup FsmRunTimeContextGet Function get context + */ + + + /** + * @ingroup FsmRunTimeContextGet + * @brief get the sub idx state + * @return bool + */ + std::size_t getSubStmId(void); + + NodePtr getCommonNodeFromIdx(std::size_t commonIdx); + std::size_t getCommonNodeIdx(NodePtr node); + std::set<NodePtr> getCommonNodes(void); + + std::map<NodePtr,std::size_t> getCommon(void); + std::set<NodePtr> getValidNodes(void); + + std::set<NodePtr> getValidNodesNoCommon(void); + std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> getValid(void); + + + NodePtr getActNode(void); + std::shared_ptr<FsmNode> getActState(void); + + + /** + * @defgroup FsmRunTimeContextMem + */ + + void rst(void); + + + }; + +} + +#endif //AIDGE_CORE_FSM_RUN_TIME_CONTEXT_H_ diff --git a/include/aidge/graphRegex/matchFsm/MatchResult.hpp b/include/aidge/graphRegex/matchFsm/MatchResult.hpp new file mode 100644 index 0000000000000000000000000000000000000000..ac2f2a627a9d88b3cabeac4b181af2f3b7566d72 --- /dev/null +++ b/include/aidge/graphRegex/matchFsm/MatchResult.hpp @@ -0,0 +1,60 @@ +#ifndef AIDGE_CORE_MATCH_RESULT_H_ +#define AIDGE_CORE_MATCH_RESULT_H_ + +#include <memory> +#include <vector> +#include <map> + + +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" +#include "aidge/graph/Node.hpp" + +namespace Aidge{ + +/** + * @brief class that old the result of a matching + * give acess to all node ant there tag in the expression +*/ +class MatchResult +{ +private: + /* data */ + std::vector<std::shared_ptr<FsmRunTimeContext>> mAllValid; + + /* + the Run time of eatch sub FSM , to have a valide match we need a set of one run time per FSM compatible + the id must be contigue + */ + std::vector<std::vector<std::shared_ptr<FsmRunTimeContext>>> mIdToRunTime; + + std::vector<std::set<NodePtr>> mSolve; + + std::size_t mNbSubStm; + +public: + MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm); + + virtual ~MatchResult() = default; + + /** + * @brief get the set of the node match for une expression + * @return the set of node of the graph that corresponding to an expression + */ + std::set<NodePtr> getBiggerSolution(void); + +private: + +/** + * @brief recurent function use to inite mSolve in the constructor + * + **/ + +void _generateCombinationd( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence); + +}; + + +} + + +#endif //AIDGE_CORE_MATCH_RESULT_H_ diff --git a/include/aidge/hook/execTime.hpp b/include/aidge/hook/ExecTime.hpp similarity index 100% rename from include/aidge/hook/execTime.hpp rename to include/aidge/hook/ExecTime.hpp diff --git a/include/aidge/hook/hook.hpp b/include/aidge/hook/Hook.hpp similarity index 94% rename from include/aidge/hook/hook.hpp rename to include/aidge/hook/Hook.hpp index 28f7ef5cddbc649af50209ba77527b8b75d731b7..5e00db5d68f11aadd4f3b6eb8174ba61b33e4a49 100644 --- a/include/aidge/hook/hook.hpp +++ b/include/aidge/hook/Hook.hpp @@ -31,11 +31,11 @@ protected: public: Hook(std::shared_ptr<Operator> op) : mOperator(op) {} - virtual ~Hook(); + virtual ~Hook() = default; virtual void call() = 0; }; } -#endif /* Hook_H_ */ \ No newline at end of file +#endif /* Hook_H_ */ diff --git a/include/aidge/hook/outputRange.hpp b/include/aidge/hook/OutputRange.hpp similarity index 100% rename from include/aidge/hook/outputRange.hpp rename to include/aidge/hook/OutputRange.hpp diff --git a/include/aidge/nodeTester/ConditionalData.hpp b/include/aidge/nodeTester/ConditionalData.hpp new file mode 100644 index 0000000000000000000000000000000000000000..12df32a728571678a3885f9981e526e1d73db785 --- /dev/null +++ b/include/aidge/nodeTester/ConditionalData.hpp @@ -0,0 +1,98 @@ + +#ifndef AIDGE_CORE_CONDITIONAL_DATA_H_ +#define AIDGE_CORE_CONDITIONAL_DATA_H_ + +#include <vector> +#include <string> +#include <stdexcept> //error +#include <memory> +#include <map> +namespace Aidge{ + + + +///////////////////////// +// The data type in AST Intepretation +//////////////////////// + +class BaseConditionalValue { +public: + virtual ~BaseConditionalValue() {} +}; + +template <typename T> +class ConditionalValue : public BaseConditionalValue { +public: + ConditionalValue(const T& data) : value(data) {} + T value; +}; + + +struct ConditionalData { + /** + * @brief generic type to propagate all the different values in the AST interpretation + */ + //void* value; + std::unique_ptr<BaseConditionalValue> value; + const std::type_info* type =nullptr; + + ///////////////////////////////// + // + //////////////////////////////// + /** + * @brief set a value + */ + template <typename T> + void setValue(const T& newValue) { + //make sure that the old value is free + deleteValue(); + value = std::make_unique<ConditionalValue<T>>(newValue); + type = &typeid(T); + } + + /** + * @brief get the actual value + * @details recaste the value to the templaited type and checks that the conversion type is compatible with type + * @tparam the type of the return value + * @return the value + */ + template <typename T> + T getValue() const { + if (type && *type == typeid(T)) { + //const Value<T>* typedValue = dynamic_cast<const Value<T>*>(static_cast<const BaseValue*>(value)); + const ConditionalValue<T>* typedValue = dynamic_cast<const ConditionalValue<T>*>(value.get()); + if (typedValue) { + return typedValue->value; + } + } + throw std::runtime_error(std::string("DATA ERROR ") + type->name() + " != " + typeid(T).name()); + } + /////////////////////////////////// + // + /////////////////////////////////// + std::string getType() const { + return type ? type->name() : "nullptr"; + } + + + template <typename T> + bool isTypeEqualTo() const { + return (type && *type == typeid(T)); + } + + void deleteValue() { + if (type) { + value.reset(); + type = nullptr; + } + } + + ~ConditionalData() { // TODO best can we have a list of type supported ? + deleteValue(); + } +}; + +} + + +#endif //AIDGE_CORE_CONDITIONAL_DATA_H_ diff --git a/include/aidge/nodeTester/ConditionalInterpreter.hpp b/include/aidge/nodeTester/ConditionalInterpreter.hpp new file mode 100644 index 0000000000000000000000000000000000000000..165fac1c2ae98bf76b73c039de9fc975e9845cc9 --- /dev/null +++ b/include/aidge/nodeTester/ConditionalInterpreter.hpp @@ -0,0 +1,339 @@ + + +#ifndef AIDGE_CORE_CONDITIONAL_INTERPRETER_H_ +#define AIDGE_CORE_CONDITIONAL_INTERPRETER_H_ + +#include "aidge/nodeTester/ConditionalParser.hpp" +#include "aidge/nodeTester/ConditionalData.hpp" + +#include <memory> // for shared_ptr +#include <unordered_map> +#include <functional> +#include "aidge/graph/Node.hpp" +#include <sstream> + + +namespace Aidge{ + + + +////////////////////////////// +// +///////////////////////////// +/** + * @brief class used to register any lambda function without context, + * it encapsulates the source lambda in a lambda which takes as argument ConditionalData* which are any type. + * @see ConditionalData + */ +class ConditionalRegisterFunction { + ////////////////////////// + //Safe recaste + ////////////////////////// + + /** + * @brief recast the ConditionalData* to the argument type of the lambda + * @tparam T type of the lambda argument + * @see ConditionalData + */ + template <typename T> + T safeCastInput(ConditionalData* data) { + //cnvertion and type cheking + if (data->isTypeEqualTo<T>()){ + return data->getValue<T>(); + }else{ + throw std::invalid_argument( "incompatible input type " + data->getType() +" "+ typeid(T).name() ); + } + + } + + + /** + * @brief recaste the output of the lambda to a ConditionalData* + * @tparam T type of the lambda return + * @see ConditionalData + */ + template <typename T> + ConditionalData* safeCastOutput(T data) { + + ConditionalData* out = new ConditionalData; + out->setValue<T>(data); + + return out; + } + + + + + ////////////////////// + // get all the type of the function + ////////////////////// + + /** + * @brief Retrieves information about a function's return type and argument types. + * @tparam T The function type. + */ + template <typename T> + struct function_traits; + + + /** + * @brief Specialization of function_traits for function pointers. + * @tparam R The return type of the function. + * @tparam Args The argument types of the function. + */ + template <typename R, typename... Args> + struct function_traits<R (*)(Args...)> { + using return_type = R; + static constexpr std::size_t arity = sizeof...(Args); + + template <std::size_t N> + struct argument { + static_assert(N < arity, "Index out of range."); + using type = typename std::tuple_element<N, std::tuple<Args...>>::type; + }; + }; + + /** + * @brief Specialization of function_traits for std::function types. + * @tparam R The return type of the function. + * @tparam Args The argument types of the function. + */ + template <typename R, typename... Args> + struct function_traits<std::function<R(Args...)>> { + using return_type = R; + static constexpr std::size_t arity = sizeof...(Args); + + template <std::size_t N> + struct argument { + static_assert(N < arity, "Index out of range."); + using type = typename std::tuple_element<N, std::tuple<Args...>>::type; + }; + }; + + ///////////////////// + //change the function to ConditionalData*(std::vector<ConditionalData*>) + ///////////////////// + + /** + * @brief Converts a function to a ConditionalData*(std::vector<ConditionalData*>). + * @tparam F The type of the function to convert. + * @tparam ParamsIdx The indices of the function parameters. + * @param f The function to convert. + * @return The pointer to the converted function. + */ + template <class F, std::size_t... ParamsIdx> + auto funcPointer(F f, std::index_sequence<ParamsIdx...>) { + //wrapp the lambda in a new one that as ConditionalData as inputs and output + return [this,f](std::vector<ConditionalData*> &args) { + if (args.size() != sizeof...(ParamsIdx)){ + std::ostringstream errorMessage; + errorMessage << "bad Number of argument: get " << args.size() << " need " << sizeof...(ParamsIdx) << "\n"; + throw std::runtime_error(errorMessage.str()); + } + //assert(args.size() == sizeof...(ParamsIdx));//the size of the vector valide + + using FuncTraits = function_traits<decltype(f)>; + using outType = typename FuncTraits::return_type; + + outType result = f(safeCastInput<typename FuncTraits::template argument<ParamsIdx>::type>(args[ParamsIdx])...); + //typename + return safeCastOutput<outType>(result); + }; + } + + /** + * @brief Converts a function pointer to a ConditionalData*(std::vector<ConditionalData*>). + * @tparam R The return type of the function. + * @tparam Params The parameter types of the function. + * @param f The function pointer to convert. + * @return The pointer to the converted function. + */ + template <class R,class... Params> + auto funcPointer(R (*f)(Params...)) { + return funcPointer(f, std::index_sequence_for<Params...>{}); + } + + /** + * @brief Converts a std::function to a ConditionalData*(std::vector<ConditionalData*>). + * @tparam R The return type of the function. + * @tparam Params The parameter types of the function. + * @param f The function pointer to convert. + * @return The pointer to the converted function. + */ + template <class R,class... Params> + auto funcPointer(std::function<R(Params...)> f) { + return funcPointer(f, std::index_sequence_for<Params...>{}); + } + + + /////////////////// + // interface + /////////////////// + + public: + + /** + * @brief Default constructor + */ + ConditionalRegisterFunction(){} + + + /** + * @brief Inserts a function into the map with the provided key. + * @tparam T The function type. + * @param key The key to associate with the function. + * @param f The function to insert. + */ + template <class T> + void insert(const std::string key,T f){ + mWlambda.insert({ key, funcPointer(f)}); + } + + + /** + * @brief Runs the function associated with the given key, using the provided vector of input data. + * @param key The key of the function to run. + * @param datas The vector of input data. + * @return A pointer to the output ConditionalData object. + */ + ConditionalData* run(const std::string key,std::vector<ConditionalData*> & datas); + + private: + /// @brief map of name and the converted function. + std::map<const std::string, std::function<ConditionalData*(std::vector<ConditionalData*> &)>> mWlambda; +}; + +/////////////////// +//AST tree node +// //////////////// +/** + * @brief this class interprets AST to generate a test on a graph node. For each AST node, + * it generates an interpretation and registers lambda functions that can be used in the test expression. + * there are two lambda control mechanisms: + * - A cpp mechanism which allows any lambda to be inserted into the constructor that use templaite + * - A user mechanism limited to lambda bool(NodePtr) + * @see ConditionalParser use to get the AST + */ +class ConditionalInterpreter +{ + private: + + /** + * @brief the AST generate by the Parser + * @see ConditionalParser + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> mTree; + /** + * @brief the registery for the lambda fuction + * @see ConditionalRegisterFunction + */ + ConditionalRegisterFunction mLambdaRegiter; + + + std::vector<ConditionalData*> mResolution ; + + void clearRes(){ + + for (std::size_t i = 0; i < mResolution.size(); ++i) { + delete mResolution[i]; + } + mResolution.clear(); + } + + public: + /** + * @brief Constructor + * @param ConditionalExpressions The expression of the test to be performed on the nodes + */ + + ConditionalInterpreter(const std::string ConditionalExpressions); + + ~ConditionalInterpreter(){clearRes();} + + /** + * @brief Test a node depending of the ConditionalExpressions + * @details the AST is visit using \ref visit() whith the $ init whit the nodeOp + * @return bool the match node has the initialized expresion + * @see visit() This function uses the visit() function to perform the evaluation. + */ + bool test( const NodePtr nodeOp); + + /** + * @brief Interface for inserting custom lambda bool(NodePtr) functions in AST interpretation, + * it will be available in the ConditionalExpressions expretion as : key($) + * @param key The key that will be used to call the function in the expression + * @param f The pointer to function + */ + void insertLambda(const std::string key,std::function<bool(Aidge::NodePtr)> f); + + + ///// + + private: + /** + * @brief Recursive AST traversal function, using the for interpreting AST nodes function, + * using \ref ASTnodeInterpreterF fuctions + * @param NodeOp The node currently being tested + * @param nodes The AST given by the parsing process + */ + std::vector<ConditionalData*> visit(const ASTNodeCh& nodes, const NodePtr NodeOp ); + + /** + * @defgroup ASTnodeInterpreterF Functions for interpreting AST nodes + * @brief For each node type in the AST, function defines the processing to be performed + * they return a std::vector<ConditionalData*> which corresponds to the value(s) obtained + */ + + /** + * @ingroup ASTnodeInterpreterF + * @brief Function that does something. + */ + void fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); + /** + * @ingroup ASTnodeInterpreterF + * @brief Converted the lexeme to a int and to ConditionalData* + */ + void fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); + /** + * @ingroup ASTnodeInterpreterF + * @brief Converted the lexeme to a float and to ConditionalData* + */ + void fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); + /** + * @ingroup ASTnodeInterpreterF + * @brief Converted the lexeme to a str and to ConditionalData* + */ + void fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); + + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the == operation between two previously converted ConditionalData* + */ + void fEq(void); + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the != operation between two previously converted ConditionalData* + */ + void fNeq(void); + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the && operation between two previously converted ConditionalData* in bool + */ + void fAnd(void); + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the || operation between two previously converted ConditionalData* in bool + */ + void fOr(void); + + /** + * @ingroup ASTnodeInterpreterF + * @brief makes the ! operation + */ + void fNot(void); +}; + + +} + +#endif //AIDGE_CORE_CONDITIONAL_INTERPRETER_H_ diff --git a/include/aidge/nodeTester/ConditionalLexer.hpp b/include/aidge/nodeTester/ConditionalLexer.hpp new file mode 100644 index 0000000000000000000000000000000000000000..fcfb9ebe783ac719076ce675e6fc3d78caf5be07 --- /dev/null +++ b/include/aidge/nodeTester/ConditionalLexer.hpp @@ -0,0 +1,88 @@ +/** + * @file + * @brief + * @version file 1.0.0 + * @author vl241552 + * @copyright + * Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. All + * rights reserved. + */ + + + +#ifndef AIDGE_CORE_CONDITIONAL_LEXER_H_ +#define AIDGE_CORE_CONDITIONAL_LEXER_H_ + +#include <string> +#include <regex> +#include <memory> // for shared_ptr + + +#include <stdexcept> //error +#include <sstream> + +#include "aidge/nodeTester/ConditionalTypes.hpp" +#include "aidge/utilsParsing/ParsingToken.hpp" + + +namespace Aidge{ + + + +class ConditionalLexer +{ + +public: +ConditionalLexer( const std::string ConditionalExpressions ); + +/** + * @brief Get the next token on the ConditionalExpressions + * @return ParsingToken<ConditionalTokenTypes> + */ +std::shared_ptr<ParsingToken<ConditionalTokenTypes>> getNextToken(void); +/** + * @brief Restart at the start of the ConditionalExpressions + * + */ +void rstPosition(void); + +/** + * @brief Test if the string is completely read + * @return bool + */ +bool isEnd(void); + + +/** + * @brief Get the representation of the class + * @return string + */ +const std::string rep(){ + return mConditionalExpressions; +} + +private: + +/** + * @brief Constructs an error message to display the character not understood by the lexer + * @return error mesage + */ +std::runtime_error badTokenError(const std::string& currentChars,std::size_t position); + +/** + * @brief The expression of the test to be performed on the nodes + */ +const std::string mConditionalExpressions; +/** + * @brief The lexer's current position in mConditionalExpressions + */ +std::size_t mPosition; + +}; + +///////////////////////////////////// + + +} + +#endif //AIDGE_CORE_CONDITIONAL_LEXER_H_ diff --git a/include/aidge/nodeTester/ConditionalParser.hpp b/include/aidge/nodeTester/ConditionalParser.hpp new file mode 100644 index 0000000000000000000000000000000000000000..a99f5374182f57c0adca3b4d44691ff4e37de44d --- /dev/null +++ b/include/aidge/nodeTester/ConditionalParser.hpp @@ -0,0 +1,109 @@ + + + +#ifndef AIDGE_CORE_CONDITIONAL_PARSER_H_ +#define AIDGE_CORE_CONDITIONAL_PARSER_H_ + + +#include <memory> // for shared_ptr +#include <map> +#include <vector> + +#include "aidge/nodeTester/ConditionalLexer.hpp" +#include "aidge/nodeTester/ConditionalTypes.hpp" +#include "aidge/utilsParsing/ParsingToken.hpp" +#include "aidge/utilsParsing/AstNode.hpp" + +namespace Aidge{ + +const std::map<ConditionalTokenTypes, std::size_t> ConditionalPrec{ + {ConditionalTokenTypes::AND,2}, + {ConditionalTokenTypes::OR,1} +}; + + + + +using ASTNodeCh = std::vector<std::shared_ptr<AstNode<ConditionalTokenTypes>>>; + +/** + * @brief this class uses the lexer to create an AST according to a set of gramer rules + */ +class ConditionalParser{ + + public: + /** + * @brief AST graph creation function + * @param ConditionalExpressions String representing the logical fuction to be performed + */ + ConditionalParser(const std::string ConditionalExpressions); + + virtual ~ConditionalParser() = default; + /** + * @brief AST graph creation function + * @return The AST tree + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> parse(void); + + + private: + /** + * @brief restart at the start of the ConditionalExpressions for LEXER and restart mCurrentToken + */ + void rstParser(void); + + ////////////////// + + /** + * @defgroup ParsingFunctions Function for creating AST + * @brief Functions for recursive construction of the AST representing grammar rules + */ + + /** + * @ingroup ParsingFunctions + * @brief Token reading and verification function + * + */ + void ackToken(ConditionalTokenTypes tokenType); + + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for values : (KEY|INTEGER|FOAT|STRING|LAMBDA lambda) + * @return AST node + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> constructAstVal(void); + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for comparison : val (EQ|NEQ) val | LPAREN expr RPAREN + * @return AST node + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> constructAstCmpr(void); + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for arguments of a lambda : LAMBDA val (ARGSEP val)* RPAREN + * @return AST node + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> constructAstLambda(void); + /** + * @ingroup ParsingFunctions + * @brief Function of grammar rules for a expresion : cmpr ((AND | OR) cmpr)* + * @return AST node + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> constructAstExpr(std::size_t precLimit = 0); + + + /** + * @brief The actual token in the parce + */ + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> mCurrentToken; + /** + * @brief The lexem use + */ + ConditionalLexer mLexer; + +}; + + +} + +#endif //AIDGE_CORE_CONDITIONAL_PARSER_H_ diff --git a/include/aidge/nodeTester/ConditionalTypes.hpp b/include/aidge/nodeTester/ConditionalTypes.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6cb2edfd78e6b43c4f2dbc89c49cdaa9ea79f7d2 --- /dev/null +++ b/include/aidge/nodeTester/ConditionalTypes.hpp @@ -0,0 +1,36 @@ + + +#ifndef AIDGE_CORE_CONDITIONAL_TYPES_H_ +#define AIDGE_CORE_CONDITIONAL_TYPES_H_ +namespace Aidge{ + /** + * @brief enum for all types of token use in the parsing + * 7-5 type + * 4-0 id + */ + enum class ConditionalTokenTypes + { + STOP, + + NOT, /**< ! */ + AND, /**< && */ + OR, /**< || */ + + EQ, /**< == */ + NEQ, /**< != */ + + KEY, /**< [A-Za-z][A-Za-z0-9_]* */ + INTEGER, /**< [0-9]+ */ + FLOAT, /**< [0-9]+\.[0-9]* */ + STRING , /**< \'.*\' */ + BOOL, /**< true|false */ + NODE, /**< \$ */ + LAMBDA , /**< [A-Za-z][A-Za-z0-9_]*\( */ + + ARGSEP, /**< , */ + LPAREN, /**< \( */ + RPAREN, /**< \) */ + + }; +} +#endif // AIDGE_CORE_CONDITIONAL_TYPES_H_ diff --git a/include/aidge/operator/Add.hpp b/include/aidge/operator/Add.hpp index 1e0f17e6db9278e7edf2a11918472c084561a308..65c7e8ce0e47bd470e2a1499a682ed2f2c8c2dbc 100644 --- a/include/aidge/operator/Add.hpp +++ b/include/aidge/operator/Add.hpp @@ -81,14 +81,14 @@ public: // return *in; // } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(static_cast<std::size_t>(inputIdx) < NUM && "wrong inputIdx for Add operator."); assert(strcmp(data->type(), Tensor::Type)==0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInputs[0]->empty()) { const auto expectedDims = mInputs[0]->dims(); std::size_t nonEmptyInputTensor = 1; @@ -140,7 +140,7 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<Add_Op<NUM>>::create(name)(*this); mOutput->setBackend(name); @@ -150,7 +150,7 @@ public: } } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -162,6 +162,12 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return NUM; } inline IOIndex_t nbDataInputs() const noexcept override final { return NUM; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input_0", "data_input_n"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::size_t NUM> diff --git a/include/aidge/operator/AvgPooling.hpp b/include/aidge/operator/AvgPooling.hpp index b29463c675eb8516e02b83ad47816e9e9aa5d147..dfcd0d5b3b4d892f201485e85710d42cd5b71dba 100644 --- a/include/aidge/operator/AvgPooling.hpp +++ b/include/aidge/operator/AvgPooling.hpp @@ -26,15 +26,14 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class AvgPoolingAttr { StrideDims, KernelDims, PaddingDims }; +enum class AvgPoolingAttr { StrideDims, KernelDims }; template <DimIdx_t DIM> class AvgPooling_Op : public Operator, public Registrable<AvgPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const AvgPooling_Op<DIM> &)>, public StaticAttributes<AvgPoolingAttr, std::array<DimSize_t, DIM>, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1) >> { + std::array<DimSize_t, DIM>> { private: // FIXME: change accessibility std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); @@ -47,18 +46,15 @@ public: using Attributes_ = StaticAttributes<AvgPoolingAttr, std::array<DimSize_t, DIM>, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1)> >; + std::array<DimSize_t, DIM>>; template <AvgPoolingAttr e> using attr = typename Attributes_::template attr<e>; constexpr AvgPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims, - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) : Operator(Type), Attributes_(attr<AvgPoolingAttr::StrideDims>(stride_dims), - attr<AvgPoolingAttr::KernelDims>(kernel_dims), - attr<AvgPoolingAttr::PaddingDims>(padding_dims)) { + attr<AvgPoolingAttr::KernelDims>(kernel_dims)) { setDatatype(DataType::Float32); } @@ -84,7 +80,7 @@ public: return std::make_shared<AvgPooling_Op<DIM>>(*this); } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 1 && "operators supports only 3 inputs"); (void) inputIdx; // avoid unused warning assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); @@ -92,16 +88,14 @@ public: mInput = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInput->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; for (std::size_t dim = 0; dim < this->template getAttr<AvgPoolingAttr::KernelDims>().size() ; ++dim) { outputDims[dim+2] = 1 + static_cast<DimSize_t>( std::floor(static_cast<float>(mInput->dims()[dim+2] - - this->template getAttr<AvgPoolingAttr::KernelDims>()[dim] + - this->template getAttr<AvgPoolingAttr::PaddingDims>()[dim] + - this->template getAttr<AvgPoolingAttr::PaddingDims>()[dim+DIM]) / + this->template getAttr<AvgPoolingAttr::KernelDims>()[dim]) / static_cast<float>(this->template getAttr<AvgPoolingAttr::StrideDims>()[dim]))); } outputDims[1] = mInput->dims()[1]; @@ -145,7 +139,7 @@ public: } - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<AvgPooling_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -153,7 +147,7 @@ public: mInput->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -163,34 +157,37 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> AvgPooling(const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) { - // FIXME: properly handle default w&b initialization in every cases + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by AvgPooling, not supported"); - auto avgPool = std::make_shared<Node>(std::make_shared<AvgPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, padding_dims), name); - return avgPool; + return std::make_shared<Node>(std::make_shared<AvgPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims), name); } +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <DimSize_t DIM> inline std::shared_ptr<Node> AvgPooling( DimSize_t const (&kernel_dims)[DIM], const std::string& name = "", - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) { + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by AvgPooling, not supported"); - return AvgPooling(to_array(kernel_dims), name, stride_dims, padding_dims); + return AvgPooling(to_array(kernel_dims), name, stride_dims); } } // namespace Aidge namespace { template <> const char *const EnumStrings<Aidge::AvgPoolingAttr>::data[] = {"StrideDims", - "KernelDims", "PaddingDims"}; + "KernelDims"}; } #endif /* AIDGE_CORE_OPERATOR_AVGPOOLING_H_ */ diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index 90a6be7222ee1b3e377520f2bc612a72c2ba4ab3..da7360c8ba3816cdfe1d2d00f80b08808a80f961 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -87,14 +87,14 @@ public: // return *in; // } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 5 && "operators supports only 5 inputs"); assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInputs[0]->empty()) { for (std::size_t i = nbDataInputs(); i < nbInputs(); ++i) { if(mInputs[i]->size() != mInputs[0]->dims()[1]) { @@ -136,7 +136,7 @@ public: } - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<BatchNorm_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -147,7 +147,7 @@ public: mInputs[4]->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -160,6 +160,12 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 5; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "scale", "shift", "mean", "variance"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <DimSize_t DIM> diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index 22553080c6d4d8359149b3b34c5d040e5e900c4d..b1e3e34b0eff681632d90cb8314ebd8c96722eec 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -26,13 +26,13 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelDims, PaddingDims }; +enum class ConvAttr { StrideDims, DilationDims, InChannels, OutChannels, KernelDims }; template <DimIdx_t DIM> class Conv_Op : public Operator, public Registrable<Conv_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Conv_Op<DIM> &)>, public StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, - DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >> { + DimSize_t, std::array<DimSize_t, DIM>> { public: // FIXME: change accessibility std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), @@ -45,7 +45,7 @@ public: Conv_Op() = delete; using Attributes_ = StaticAttributes<ConvAttr, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, - DimSize_t, DimSize_t, std::array<DimSize_t, DIM>, std::array<DimSize_t, (DIM<<1) >>; + DimSize_t, DimSize_t, std::array<DimSize_t, DIM>>; template <ConvAttr e> using attr = typename Attributes_::template attr<e>; @@ -53,15 +53,13 @@ public: DimSize_t out_channels, const std::array<DimSize_t, DIM> &kernel_dims, const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) : Operator(Type), Attributes_(attr<ConvAttr::StrideDims>(stride_dims), - attr<ConvAttr::DilationDims>(dilation_dims), - attr<ConvAttr::InChannels>(in_channels), - attr<ConvAttr::OutChannels>(out_channels), - attr<ConvAttr::KernelDims>(kernel_dims), - attr<ConvAttr::PaddingDims>(padding_dims)) { + attr<ConvAttr::DilationDims>(dilation_dims), + attr<ConvAttr::InChannels>(in_channels), + attr<ConvAttr::OutChannels>(out_channels), + attr<ConvAttr::KernelDims>(kernel_dims)) { setDatatype(DataType::Float32); } @@ -100,14 +98,14 @@ public: // } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInputs[0]->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; @@ -117,9 +115,7 @@ public: 1; outputDims[dim+2] = 1 + static_cast<DimSize_t>( - floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent + - this->template getAttr<ConvAttr::PaddingDims>()[dim] + - this->template getAttr<ConvAttr::PaddingDims>()[dim+DIM]) / + floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent) / static_cast<float>(this->template getAttr<ConvAttr::StrideDims>()[dim]))); } @@ -160,7 +156,7 @@ public: } - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -169,7 +165,7 @@ public: mInputs[2]->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -181,6 +177,12 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 3; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "weight", "bias"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::array<DimSize_t, 1>::size_type DIM> @@ -189,17 +191,17 @@ inline std::shared_ptr<Node> Conv(DimSize_t in_channels, const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { // FIXME: properly handle default w&b initialization in every cases static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not supported"); - auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, padding_dims, dilation_dims), name); + auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), name); // addProducer(conv, 1, append(append(kernel_dims, in_channels), out_channels), "w"); addProducer(conv, 1, append(out_channels, append(in_channels, kernel_dims)), "w"); - addProducer(conv, 2, {out_channels}, "b"); + addProducer(conv, 2, std::array<DimSize_t, 1>({out_channels}), "b"); return conv; } +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <DimSize_t DIM> inline std::shared_ptr<Node> Conv( DimSize_t in_channels, @@ -207,10 +209,9 @@ inline std::shared_ptr<Node> Conv( DimSize_t const (&kernel_dims)[DIM], const std::string& name = "", const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Conv, not supported"); - return Conv(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); + return Conv(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, dilation_dims); } } // namespace Aidge @@ -221,8 +222,7 @@ const char *const EnumStrings<Aidge::ConvAttr>::data[] = { "DilationDims", "InChannels", "OutChannels", - "KernelDims", - "PaddingDims" + "KernelDims" }; } diff --git a/include/aidge/operator/ConvDepthWise.hpp b/include/aidge/operator/ConvDepthWise.hpp index 7a4db68bae2f42eb892dd7240463e7363753b5a7..4caec2032a3c61529d452ae855f00c1da411af10 100644 --- a/include/aidge/operator/ConvDepthWise.hpp +++ b/include/aidge/operator/ConvDepthWise.hpp @@ -26,7 +26,7 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class ConvDepthWiseAttr { StrideDims, DilationDims, Channels, KernelDims, PaddingDims }; +enum class ConvDepthWiseAttr { StrideDims, DilationDims, Channels, KernelDims }; template <DimIdx_t DIM> class ConvDepthWise_Op : public Operator, @@ -35,8 +35,7 @@ class ConvDepthWise_Op : public Operator, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1) >> { + std::array<DimSize_t, DIM>> { public: // FIXME: change accessibility std::array<std::shared_ptr<Tensor>, 3> mInputs = {std::make_shared<Tensor>(), std::make_shared<Tensor>(), @@ -52,21 +51,18 @@ class ConvDepthWise_Op : public Operator, std::array<DimSize_t, DIM>, std::array<DimSize_t, DIM>, DimSize_t, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1) >>; + std::array<DimSize_t, DIM>>; template <ConvDepthWiseAttr e> using attr = typename Attributes_::template attr<e>; constexpr ConvDepthWise_Op(const std::array<DimSize_t, DIM> &kernel_dims, const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) : Operator(Type), Attributes_(attr<ConvDepthWiseAttr::StrideDims>(stride_dims), - attr<ConvDepthWiseAttr::DilationDims>(dilation_dims), - attr<ConvDepthWiseAttr::Channels>(0), - attr<ConvDepthWiseAttr::KernelDims>(kernel_dims), - attr<ConvDepthWiseAttr::PaddingDims>(padding_dims)) { + attr<ConvDepthWiseAttr::DilationDims>(dilation_dims), + attr<ConvDepthWiseAttr::Channels>(0), + attr<ConvDepthWiseAttr::KernelDims>(kernel_dims)) { setDatatype(DataType::Float32); } @@ -92,14 +88,14 @@ class ConvDepthWise_Op : public Operator, return std::make_shared<ConvDepthWise_Op<DIM>>(*this); } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 3 && "operators supports only 3 inputs"); assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInputs[0]->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; @@ -109,9 +105,7 @@ class ConvDepthWise_Op : public Operator, 1; outputDims[dim+2] = 1 + static_cast<DimSize_t>( - floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent + - this->template getAttr<ConvDepthWiseAttr::PaddingDims>()[dim] + - this->template getAttr<ConvDepthWiseAttr::PaddingDims>()[dim+DIM]) / + floor(static_cast<float>(mInputs[0]->dims()[dim+2] - kernelExtent) / static_cast<float>(this->template getAttr<ConvDepthWiseAttr::StrideDims>()[dim]))); } this->template getAttr<ConvDepthWiseAttr::Channels>() = mInputs[0]->dims()[1]; @@ -161,7 +155,7 @@ class ConvDepthWise_Op : public Operator, - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<ConvDepthWise_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -170,7 +164,7 @@ class ConvDepthWise_Op : public Operator, mInputs[2]->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -182,38 +176,43 @@ class ConvDepthWise_Op : public Operator, inline IOIndex_t nbInputs() const noexcept override final { return 3; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "weight", "bias"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> ConvDepthWise(const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { // FIXME: properly handle default w&b initialization in every cases static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ConvDepthWise, not supported"); - auto convDW = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, padding_dims, dilation_dims), name); + auto convDW = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), name); addProducer(convDW, 1, std::array<DimSize_t,0>({}), "w"); addProducer(convDW, 2, std::array<DimSize_t,0>({}), "b"); return convDW; } +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <DimSize_t DIM> inline std::shared_ptr<Node> ConvDepthWise( DimSize_t const (&kernel_dims)[DIM], const std::string& name = "", const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0), const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by ConvDepthWise, not supported"); - return ConvDepthWise(to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); + return ConvDepthWise(to_array(kernel_dims), name, stride_dims, dilation_dims); } } // namespace Aidge namespace { template <> const char *const EnumStrings<Aidge::ConvDepthWiseAttr>::data[] = {"StrideDims", "DilationDims", "Channels", - "KernelDims", "PaddingDims"}; + "KernelDims"}; } #endif /* AIDGE_CORE_OPERATOR_CONVDEPTHWISE_H_ */ diff --git a/include/aidge/operator/FC.hpp b/include/aidge/operator/FC.hpp index 127d39a8bdfdd233cdac9e1ca6cf0bf85f656d16..b949527c51b9330077dd3bd8f8b4bf1f1b9d719c 100644 --- a/include/aidge/operator/FC.hpp +++ b/include/aidge/operator/FC.hpp @@ -135,7 +135,7 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<FC_Op>::create(name)(*this); mOutput->setBackend(name); @@ -145,7 +145,7 @@ public: mInputs[2]->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -158,13 +158,19 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 3; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "weight", "bias"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> FC(DimSize_t out_channels, bool noBias = false, const std::string& name = "") { // FIXME: properly handle default w&b initialization in every cases auto fc = std::make_shared<Node>(std::make_shared<FC_Op>(out_channels, noBias), name); - addProducer(fc, 1, {out_channels, 1}, "w"); - addProducer(fc, 2, {(noBias ? 0 : out_channels)}, "b"); // already sets bias dims + addProducer(fc, 1, std::array<DimSize_t, 2>({out_channels, 1}), "w"); + addProducer(fc, 2, (noBias ? std::array<DimSize_t, 1>({0}) : std::array<DimSize_t, 1>({out_channels})), "b"); // already sets bias dims return fc; } } // namespace Aidge @@ -175,4 +181,4 @@ const char *const EnumStrings<Aidge::FCAttr>::data[] = {"OutChannels", "NoBias"}; } -#endif /* AIDGE_CORE_OPERATOR_FC_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_OPERATOR_FC_H_ */ diff --git a/include/aidge/operator/GenericOperator.hpp b/include/aidge/operator/GenericOperator.hpp index 76602ee25808be5c00cde963346f15ffd69f53fb..55ccbf1516fa79663d57e1e44bc4017bc5c8b843 100644 --- a/include/aidge/operator/GenericOperator.hpp +++ b/include/aidge/operator/GenericOperator.hpp @@ -24,6 +24,7 @@ #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" + namespace Aidge { class GenericOperator_Op : public Operator, @@ -165,8 +166,8 @@ class GenericOperator_Op ~GenericOperator_Op() = default; - void setBackend(const std::string & /*name*/) { printf("setBackend: not available yet.\n"); } - void setDatatype(const DataType & /*datatype*/) { printf("setDatatype: not available yet.\n"); } + void setBackend(const std::string & /*name*/) override { printf("setBackend: not available yet.\n"); } + void setDatatype(const DataType & /*datatype*/) override { printf("setDatatype: not available yet.\n"); } void forward() override final { if(mImpl){ mImpl->forward(); @@ -181,7 +182,6 @@ class GenericOperator_Op printf("backward: No implementation is linked.\n"); } } - inline IOIndex_t nbInputs() const noexcept override final { return mNbIn; }; inline IOIndex_t nbDataInputs() const noexcept override final { return mNbDataIn; }; inline IOIndex_t nbOutputs() const noexcept override final { return mNbOut; }; diff --git a/include/aidge/operator/LeakyReLU.hpp b/include/aidge/operator/LeakyReLU.hpp index c6ee01239e1ed065587276c1891d26ba3899fe89..bcdcbc7cabd8eda46a7c0c4930f317e562fb46a0 100644 --- a/include/aidge/operator/LeakyReLU.hpp +++ b/include/aidge/operator/LeakyReLU.hpp @@ -120,14 +120,14 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<LeakyReLU_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround mInput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -137,10 +137,15 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> LeakyReLU(float negativeSlope = 0.0f, const std::string& name = "") { - // FIXME: properly handle default w&b initialization in every cases return std::make_shared<Node>(std::make_shared<LeakyReLU_Op>(negativeSlope), name); } } diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index d0dadd847a59c9d2a1c0dd97f2f200437da71863..eed1ec04535aa5896aa3d01a27d8023d37a42183 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -127,7 +127,7 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<MatMul_Op>::create(name)(*this); mOutput->setBackend(name); @@ -136,7 +136,7 @@ public: mInputs[1]->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -148,12 +148,18 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 2; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input", "weight"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> MatMul(DimSize_t out_channels, const std::string& name = "") { // FIXME: properly handle default w initialization in every cases auto matmul = std::make_shared<Node>(std::make_shared<MatMul_Op>(out_channels), name); - addProducer(matmul, 1, {out_channels, 1}, "w"); + addProducer(matmul, 1, std::array<DimSize_t, 2>({out_channels, 1}), "w"); return matmul; } } // namespace Aidge diff --git a/include/aidge/operator/MaxPooling.hpp b/include/aidge/operator/MaxPooling.hpp index eae7e30df039c0514443e567032427f7a6556360..874ea81778e0b357a4890b6bb052e85fa266216e 100644 --- a/include/aidge/operator/MaxPooling.hpp +++ b/include/aidge/operator/MaxPooling.hpp @@ -26,15 +26,14 @@ #include "aidge/utils/Types.h" namespace Aidge { -enum class MaxPoolingAttr { StrideDims, KernelDims, PaddingDims }; +enum class MaxPoolingAttr { StrideDims, KernelDims }; template <DimIdx_t DIM> class MaxPooling_Op : public Operator, public Registrable<MaxPooling_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const MaxPooling_Op<DIM> &)>, public StaticAttributes<MaxPoolingAttr, std::array<DimSize_t, DIM>, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1) >> { + std::array<DimSize_t, DIM>> { private: // FIXME: change accessibility std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); @@ -47,18 +46,15 @@ public: using Attributes_ = StaticAttributes<MaxPoolingAttr, std::array<DimSize_t, DIM>, - std::array<DimSize_t, DIM>, - std::array<DimSize_t, (DIM<<1)> >; + std::array<DimSize_t, DIM>>; template <MaxPoolingAttr e> using attr = typename Attributes_::template attr<e>; constexpr MaxPooling_Op(const std::array<DimSize_t, DIM> &kernel_dims, - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) : Operator(Type), Attributes_(attr<MaxPoolingAttr::StrideDims>(stride_dims), - attr<MaxPoolingAttr::KernelDims>(kernel_dims), - attr<MaxPoolingAttr::PaddingDims>(padding_dims)), + attr<MaxPoolingAttr::KernelDims>(kernel_dims)), mOutput(std::make_shared<Tensor>()) { setDatatype(DataType::Float32); } @@ -85,7 +81,7 @@ public: return std::make_shared<MaxPooling_Op<DIM>>(*this); } - constexpr void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { assert(inputIdx < 1 && "operators supports only 3 inputs"); (void) inputIdx; // avoid unused warning assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); @@ -93,16 +89,14 @@ public: mInput = std::dynamic_pointer_cast<Tensor>(data); } - constexpr void computeOutputDims() override final { + void computeOutputDims() override final { if (!mInput->empty()) { std::array<DimSize_t, DIM + 2> outputDims = {}; for (std::size_t dim = 0; dim < this->template getAttr<MaxPoolingAttr::KernelDims>().size() ; ++dim) { outputDims[dim+2] = 1 + static_cast<DimSize_t>( std::floor(static_cast<float>(mInput->dims()[dim+2] - - this->template getAttr<MaxPoolingAttr::KernelDims>()[dim] + - this->template getAttr<MaxPoolingAttr::PaddingDims>()[dim] + - this->template getAttr<MaxPoolingAttr::PaddingDims>()[dim+DIM]) / + this->template getAttr<MaxPoolingAttr::KernelDims>()[dim]) / static_cast<float>(this->template getAttr<MaxPoolingAttr::StrideDims>()[dim]))); } outputDims[1] = mInput->dims()[1]; @@ -146,7 +140,7 @@ public: } - void setBackend(const std::string &name) { + void setBackend(const std::string &name) override { mImpl = Registrar<MaxPooling_Op<DIM>>::create(name)(*this); mOutput->setBackend(name); @@ -154,7 +148,7 @@ public: mInput->setBackend(name); } - void setDatatype(const DataType &datatype) { + void setDatatype(const DataType &datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -164,33 +158,36 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; template <std::array<DimSize_t, 1>::size_type DIM> inline std::shared_ptr<Node> MaxPooling(const std::array<DimSize_t, DIM> &kernel_dims, const std::string& name = "", - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) { - // FIXME: properly handle default w&b initialization in every cases + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by MaxPooling, not supported"); - auto avgPool = std::make_shared<Node>(std::make_shared<MaxPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, padding_dims), name); - return avgPool; + return std::make_shared<Node>(std::make_shared<MaxPooling_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims), name); } +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <DimSize_t DIM> inline std::shared_ptr<Node> MaxPooling( DimSize_t const (&kernel_dims)[DIM], const std::string& name = "", - const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), - const std::array<DimSize_t, (DIM<<1)> &padding_dims = create_array<DimSize_t,(DIM<<1)>(0)) { + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1)) { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by MaxPooling, not supported"); - return MaxPooling(to_array(kernel_dims), name, stride_dims, padding_dims); + return MaxPooling(to_array(kernel_dims), name, stride_dims); } } // namespace Aidge namespace { template <> -const char *const EnumStrings<Aidge::MaxPoolingAttr>::data[] = {"StrideDims", "KernelDims", "PaddingDims"}; +const char *const EnumStrings<Aidge::MaxPoolingAttr>::data[] = {"StrideDims", "KernelDims"}; } #endif /* AIDGE_CORE_OPERATOR_MAXPOOLING_H_ */ diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 0c77a752493d251303c036c4061823c4f8bc499d..72058dfcba6e811a01a22e261208741879638cad 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -13,21 +13,38 @@ #define AIDGE_CORE_OPERATOR_METAOPERATOR_H_ #include "aidge/operator/Operator.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/scheduler/Scheduler.hpp" namespace Aidge { -class MetaOperator : public Operator { +class MetaOperator_Op : public Operator, + public Registrable<MetaOperator_Op, std::array<std::string, 2>, std::unique_ptr<OperatorImpl>(const MetaOperator_Op &)> { public: - MetaOperator() - : Operator("MetaOp") - { - } + std::vector<std::shared_ptr<Tensor>> mInputs; + std::vector<std::shared_ptr<Tensor>> mOutputs; // These are shared with micro-graph outputs tensors + + // Micro-graph handling: + std::shared_ptr<GraphView> mGraph; // Meta operator micro-graph + std::shared_ptr<SequentialScheduler> mScheduler; + // Need to store an ordored list of input/output operators for the micro-graph, + // because input/output nodes in a GraphView are unordered. + // TODO: refactor GraphView to handle ordered input/output? + std::vector<std::pair<std::shared_ptr<Operator>, IOIndex_t>> mInputOps; + std::vector<std::pair<std::shared_ptr<Operator>, IOIndex_t>> mOutputOps; + + public: + MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, + std::vector<NodePtr> inputNodes = std::vector<NodePtr>(), + std::vector<NodePtr> outputNodes = std::vector<NodePtr>()); /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). * @param op Operator to copy. */ - MetaOperator(const MetaOperator& op) - : Operator("MetaOp") + MetaOperator_Op(const MetaOperator_Op& op) + : Operator(op.type().c_str()), + mGraph(op.mGraph->clone()) { // cpy-ctor } @@ -37,11 +54,114 @@ public: * @see Operator::MatMul_Op */ std::shared_ptr<Operator> clone() const override { - return std::make_shared<MetaOperator>(*this); + return std::make_shared<MetaOperator_Op>(*this); + } + + const std::shared_ptr<GraphView>& getMicroGraph() const { + return mGraph; + } + + const std::shared_ptr<SequentialScheduler>& getMicroGraphScheduler() const { + return mScheduler; + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + + const auto& inputOp = mInputOps[inputIdx]; + inputOp.first->associateInput(inputOp.second, data); + + // Associate inputs for custom implementation + mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + // Forward dims of micro-graph + mGraph->forwardDims(); + + // Associate outputs to micro-graph outputs for custom implementation + for (size_t outputIdx = 0; outputIdx < mOutputOps.size(); ++outputIdx) { + const auto& outputOp = mOutputOps[outputIdx]; + mOutputs[outputIdx] = outputOp.first->getOutput(outputOp.second); + } + } + + bool outputDimsForwarded() const override final { return !(mOutputs[0]->empty()); } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(inputIdx < mInputs.size() && "inputIdx out of range"); + return *(mInputs[inputIdx].get()); + } + + inline Tensor& output(const IOIndex_t outputIdx) const override final { + assert(outputIdx < mOutputs.size() && "outputIdx out of range"); + return *(mOutputs[outputIdx].get()); + } + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < mInputs.size() && "inputIdx out of range"); + return mInputs[inputIdx]; + } + + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx < mOutputs.size() && "outputIdx out of range"); + return mOutputs[outputIdx]; + } + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx < mInputs.size() && "inputIdx out of range"); + return std::static_pointer_cast<Data>(mInputs[inputIdx]); + } + + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx < mOutputs.size() && "outputIdx out of range"); + return std::static_pointer_cast<Data>(mOutputs[outputIdx]); + } + + void setBackend(const std::string &name) override { + if (Registrar<MetaOperator_Op>::exists({name, type()})) { + // A custom implementation exists for this meta operator + mImpl = Registrar<MetaOperator_Op>::create({name, type()})(*this); + } + + // The micro-graph should always be set to the right backend, since it + // shares input/output tensors. + // Input/output tensors backend are updated here. + mGraph->setBackend(name); + } + + void setDatatype(const DataType &datatype) override { + // The micro-graph should always be set to the right data type, since it + // shares input/output tensors. + // Input/output tensors data type are updated here. + mGraph->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return mGraph->inputs().size(); } + inline IOIndex_t nbDataInputs() const noexcept override final { return mGraph->dataInputs().size(); } + inline IOIndex_t nbOutputs() const noexcept override final { return mGraph->outputs().size(); } + + NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const override; + NbElts_t getNbConsumedData(IOIndex_t inputIdx) const override; + NbElts_t getNbProducedData(IOIndex_t outputIdx) const override; + + void updateConsummerProducer() override; + void forward() override; + void backward() override { + assert(false && "not implemented"); } - ~MetaOperator() = default; }; + +inline std::shared_ptr<Node> MetaOperator(const char *type, + const std::shared_ptr<GraphView>& graph, + const std::string& name = "", + std::vector<NodePtr> inputNodes = std::vector<NodePtr>(), + std::vector<NodePtr> outputNodes = std::vector<NodePtr>()) +{ + return std::make_shared<Node>(std::make_shared<MetaOperator_Op>(type, graph, inputNodes, outputNodes), name); } +} // namespace Aidge #endif /* MetaOperator_H_ */ diff --git a/include/aidge/operator/MetaOperatorDefs.hpp b/include/aidge/operator/MetaOperatorDefs.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6da76c930a3f08358c8c09ce75e66109370e292a --- /dev/null +++ b/include/aidge/operator/MetaOperatorDefs.hpp @@ -0,0 +1,140 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ +#define AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ + +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/AvgPooling.hpp" +#include "aidge/operator/MaxPooling.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/ConvDepthWise.hpp" +#include "aidge/operator/Pad.hpp" + +namespace Aidge { +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedConv(DimSize_t in_channels, + DimSize_t out_channels, + const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + // Construct micro-graph + auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0); + auto conv = std::make_shared<Node>(std::make_shared<Conv_Op<static_cast<DimIdx_t>(DIM)>>(in_channels, out_channels, kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : ""); + // Need to specify the ordered list of input operators + const std::vector<NodePtr> orderedInputNodes = {pad, conv}; + + auto metaOp = MetaOperator("PaddedConv", Sequential({pad, conv}), name, orderedInputNodes); + addProducer(metaOp, 1, append(out_channels, append(in_channels, kernel_dims)), "w"); + addProducer(metaOp, 2, {out_channels}, "b"); + return metaOp; +} + +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> PaddedConv( + DimSize_t in_channels, + DimSize_t out_channels, + DimSize_t const (&kernel_dims)[DIM], + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + return PaddedConv(in_channels, out_channels, to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); +} + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedConvDepthWise(const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + // Construct micro-graph + auto pad = Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : "", PadBorderType::Constant, 0.0); + auto conv = std::make_shared<Node>(std::make_shared<ConvDepthWise_Op<static_cast<DimIdx_t>(DIM)>>(kernel_dims, stride_dims, dilation_dims), (!name.empty()) ? name + "_conv" : ""); + // Need to specify the ordered list of input operators + const std::vector<NodePtr> orderedInputNodes = {pad, conv}; + + auto metaOp = MetaOperator("PaddedConvDepthWise", Sequential({pad, conv}), name, orderedInputNodes); + addProducer(metaOp, 1, std::array<DimSize_t,0>({}), "w"); + addProducer(metaOp, 2, std::array<DimSize_t,0>({}), "b"); + return metaOp; +} + +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> PaddedConvDepthWise( + DimSize_t const (&kernel_dims)[DIM], + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0), + const std::array<DimSize_t, DIM> &dilation_dims = create_array<DimSize_t,DIM>(1)) +{ + return PaddedConvDepthWise(to_array(kernel_dims), name, stride_dims, padding_dims, dilation_dims); +} + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedAvgPooling(const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0)) +{ + auto graph = Sequential({ + Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""), + AvgPooling(kernel_dims, (!name.empty()) ? name + "_avgpooling" : "", stride_dims) + }); + + return MetaOperator("PaddedAvgPooling", graph, name); +} + +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> PaddedAvgPooling( + DimSize_t const (&kernel_dims)[DIM], + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0)) +{ + return PaddedAvgPooling(to_array(kernel_dims), name, stride_dims, padding_dims); +} + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> PaddedMaxPooling(const std::array<DimSize_t, DIM> &kernel_dims, + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0)) +{ + auto graph = Sequential({ + Pad<DIM>(padding_dims, (!name.empty()) ? name + "_pad" : ""), + MaxPooling(kernel_dims, (!name.empty()) ? name + "_maxpooling" : "", stride_dims) + }); + + return MetaOperator("PaddedMaxPooling", graph, name); +} + +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> PaddedMaxPooling( + DimSize_t const (&kernel_dims)[DIM], + const std::string& name = "", + const std::array<DimSize_t, DIM> &stride_dims = create_array<DimSize_t,DIM>(1), + const std::array<DimSize_t, 2*DIM> &padding_dims = create_array<DimSize_t,2*DIM>(0)) +{ + return PaddedMaxPooling(to_array(kernel_dims), name, stride_dims, padding_dims); +} +} // namespace Aidge + +#endif /* AIDGE_CORE_OPERATOR_METAOPERATORDEFS_H_ */ diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 2582f83c45b431aad51e95ee4ba43e0db8abfe5a..903b6362adf3db0c867dc419086e0cb6ddaa65c7 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -20,7 +20,7 @@ #include "aidge/data/Data.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/utils/Types.h" -#include "aidge/hook/hook.hpp" +#include "aidge/hook/Hook.hpp" namespace Aidge { @@ -89,7 +89,7 @@ public: * @param inputIdx Index of the input analysed. * @return NbElts_t */ - NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; + virtual NbElts_t getNbRequiredData(const IOIndex_t inputIdx) const; /** * @brief Amount of data from a specific input actually used in one computation pass. @@ -97,7 +97,7 @@ public: * @param inputIdx Index of the input analysed. * @return NbElts_t */ - NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; + virtual NbElts_t getNbConsumedData(const IOIndex_t inputIdx) const; /** * @brief Amount of data ready to be used on a specific output. @@ -105,9 +105,9 @@ public: * @param outputIdx Index of the output analysed. * @return NbElts_t */ - NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; + virtual NbElts_t getNbProducedData(const IOIndex_t outputIdx) const; - void updateConsummerProducer(); + virtual void updateConsummerProducer(); virtual void forward(); @@ -124,6 +124,12 @@ public: virtual IOIndex_t nbInputs() const noexcept = 0; virtual IOIndex_t nbDataInputs() const noexcept = 0; virtual IOIndex_t nbOutputs() const noexcept = 0; + static const std::vector<std::string> getInputsName(){ + return {}; + } + static const std::vector<std::string> getOutputsName(){ + return {}; + } }; } // namespace Aidge diff --git a/include/aidge/operator/Pad.hpp b/include/aidge/operator/Pad.hpp new file mode 100644 index 0000000000000000000000000000000000000000..cbebb16e1e24501b0ea371fb45211047f6e2b5e7 --- /dev/null +++ b/include/aidge/operator/Pad.hpp @@ -0,0 +1,201 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CORE_OPERATOR_PAD_H_ +#define AIDGE_CORE_OPERATOR_PAD_H_ + +#include <array> +#include <numeric> +#include <vector> +#include <cmath> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/StaticAttributes.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { +enum class PadAttr { BeginEndBorders, BorderType, BorderValue }; +enum class PadBorderType { Constant, Edge, Reflect, Wrap }; + +template <DimIdx_t DIM> +class Pad_Op : public Operator, + public Registrable<Pad_Op<DIM>, std::string, std::unique_ptr<OperatorImpl>(const Pad_Op<DIM> &)>, + public StaticAttributes<PadAttr, + std::array<DimSize_t, 2*DIM>, + PadBorderType, + double> { +private: + // FIXME: change accessibility + std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); + const std::shared_ptr<Tensor> mOutput = std::make_shared<Tensor>(); + +public: + static constexpr const char *Type = "Pad"; + + Pad_Op() = delete; + + using Attributes_ = StaticAttributes<PadAttr, + std::array<DimSize_t, 2*DIM>, + PadBorderType, + double>; + template <PadAttr e> + using attr = typename Attributes_::template attr<e>; + + constexpr Pad_Op(const std::array<DimSize_t, 2*DIM> &beginEndTuples, + const PadBorderType &borderType = PadBorderType::Constant, + double borderValue = 0.0) + : Operator(Type), + Attributes_(attr<PadAttr::BeginEndBorders>(beginEndTuples), + attr<PadAttr::BorderType>(borderType), + attr<PadAttr::BorderValue>(borderValue)) { + setDatatype(DataType::Float32); + } + + /** + * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). + * @param op Operator to copy. + */ + Pad_Op(const Pad_Op& op) + : Operator(Type), + Attributes_(op), + mOutput(std::make_shared<Tensor>(*op.mOutput)) + { + // cpy-ctor + setDatatype(op.mOutput->dataType()); + } + + /** + * @brief Clone the operator using its copy-constructor. + * @see Operator::Pad_Op + */ + std::shared_ptr<Operator> clone() const override { + return std::make_shared<Pad_Op<DIM>>(*this); + } + + void associateInput(const IOIndex_t inputIdx, std::shared_ptr<Data> data) override final { + assert(inputIdx < 1 && "operators supports only 3 inputs"); + (void) inputIdx; // avoid unused warning + assert(strcmp(data->type(), Tensor::Type) == 0 && "input data must be of Tensor type"); + + mInput = std::dynamic_pointer_cast<Tensor>(data); + } + + void computeOutputDims() override final { + if (!mInput->empty()) { + std::array<DimSize_t, DIM + 2> outputDims = {}; + + for (std::size_t dim = 0; dim < DIM; ++dim) { + outputDims[dim+2] = this->template getAttr<PadAttr::BeginEndBorders>()[2*dim] + + mInput->dims()[dim+2] + + this->template getAttr<PadAttr::BeginEndBorders>()[2*dim+1]; + } + outputDims[1] = mInput->dims()[1]; + outputDims[0] = mInput->dims()[0]; + mOutput->resize(outputDims); + } + } + + bool outputDimsForwarded() const override final { return !(mOutput->empty()); } + + + inline Tensor& input(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operators supports only 1 inputs"); + (void) inputIdx; // avoid unused warning + return *(mInput.get()); + } + inline Tensor& output(const IOIndex_t /*outputIdx*/) const override final { return *(mOutput.get()); } + + + inline std::shared_ptr<Tensor> getInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "Pad Operators supports only 1 inputs"); + (void) inputIdx; // avoid unused warning + return mInput; + } + inline std::shared_ptr<Tensor> getOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "Pad Operators has only 1 outputs"); + (void) outputIdx; // avoid unused warning + return mOutput; + } + + + std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final { + assert(inputIdx == 0 && "operators supports only 1 inputs"); + (void) inputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mInput); + } + std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const override final { + assert(outputIdx == 0 && "operator supports only 1 output"); + (void) outputIdx; // avoid unused warning + return std::static_pointer_cast<Data>(mOutput); + } + + + void setBackend(const std::string &name) override { + mImpl = Registrar<Pad_Op<DIM>>::create(name)(*this); + mOutput->setBackend(name); + + // FIXME: temporary workaround + mInput->setBackend(name); + } + + void setDatatype(const DataType &datatype) override { + mOutput->setDatatype(datatype); + + // FIXME: temporary workaround + mInput->setDatatype(datatype); + } + + inline IOIndex_t nbInputs() const noexcept override final { return 1; } + inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } + inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } +}; + +template <std::array<DimSize_t, 1>::size_type DIM> +inline std::shared_ptr<Node> Pad(const std::array<DimSize_t, 2*DIM> &beginEndTuples, + const std::string& name = "", + const PadBorderType &borderType = PadBorderType::Constant, + double borderValue = 0.0) +{ + static_assert(DIM<=MaxDim,"Too many kernel dimensions required by Pad, not supported"); + return std::make_shared<Node>(std::make_shared<Pad_Op<static_cast<DimIdx_t>(DIM)>>(beginEndTuples, borderType, borderValue), name); +} + +// helper with C-style array instead of std::array for beginEndTuples to allow automatic template DIM deduction +template <DimSize_t DIM> +inline std::shared_ptr<Node> Pad( + DimSize_t const (&beginEndTuples)[2*DIM], + const std::string& name = "", + const PadBorderType &borderType = PadBorderType::Constant, + double borderValue = 0.0) +{ + return Pad<DIM>(to_array(beginEndTuples), name, borderType, borderValue); +} +} // namespace Aidge + +namespace { +template <> +const char *const EnumStrings<Aidge::PadAttr>::data[] = {"BeginEndBorders", "BorderType", "BorderValue"}; + +template <> +const char *const EnumStrings<Aidge::PadBorderType>::data[] = {"Constant", "Edge", "Reflect", "Wrap"}; +} + +#endif /* AIDGE_CORE_OPERATOR_PAD_H_ */ diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index 593192c9f402e2646ac94cff68aa0c805f5aecd1..d747b340618cc7e321f2cfc2ed9169798e5d77e9 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -79,7 +79,7 @@ public: * @brief Set the Output Tensor of the Producer operator. * This method will create a copy of the Tensor. * - * @param newOutput Tensor containing the values to copy + * @param newOutput Tensor containing the values to copy */ void setOutputTensor(const Tensor& newOutput) { *mOutput = newOutput; @@ -121,17 +121,23 @@ public: inline const std::vector<DimSize_t> dims() const noexcept { return mOutput->dims(); } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<Producer_Op>::create(name)(*this); mOutput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); } inline IOIndex_t nbInputs() const noexcept override final { return 0; }; inline IOIndex_t nbDataInputs() const noexcept override final { return 0; }; inline IOIndex_t nbOutputs() const noexcept override final { return 1; }; + static const std::vector<std::string> getInputsName(){ + return {}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } public: void forward() override final { @@ -148,6 +154,7 @@ inline std::shared_ptr<Node> Producer(const std::array<DimSize_t, DIM> &dims, co return std::make_shared<Node>(std::make_shared<Producer_Op>(dims), name); } +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <std::size_t DIM> inline std::shared_ptr<Node> Producer(DimSize_t const (&dims)[DIM], const std::string& name = "") { return Producer(to_array(dims), name); @@ -167,6 +174,7 @@ void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, con otherNode->getOperator()->associateInput(inputIdx, prod->getOperator()->getRawOutput(0)); } +// helper with C-style array instead of std::array for kernel_dims to allow automatic template DIM deduction template <std::size_t DIM> void addProducer(std::shared_ptr<Node>& otherNode, const IOIndex_t inputIdx, DimSize_t const (&dims)[DIM], const std::string& extension) { addProducer(otherNode, inputIdx, to_array(dims), extension); diff --git a/include/aidge/operator/ReLU.hpp b/include/aidge/operator/ReLU.hpp index 433e353f05f8b4ffc3cfc0e047464e7f9257da02..52f13f1c5ce1d0b7a0d4ccaa4d7fe9927bcc3e53 100644 --- a/include/aidge/operator/ReLU.hpp +++ b/include/aidge/operator/ReLU.hpp @@ -108,14 +108,14 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<ReLU_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround mInput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -125,10 +125,15 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> ReLU(const std::string& name = "") { - // FIXME: properly handle default w&b initialization in every cases return std::make_shared<Node>(std::make_shared<ReLU_Op>(), name); } } diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index 0ea6ba39b3e4def2011ae5c7b2b9c348df5e2929..353666fb3950d034a7dbe8ec1d3ebdb312679f95 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -130,13 +130,13 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<Scaling_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround mInput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -146,6 +146,12 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::string& name = "") { diff --git a/include/aidge/operator/Softmax.hpp b/include/aidge/operator/Softmax.hpp index 898bae4c31bb2c41947523a86bfb9cd5c7b732b4..ba6132a5ee00325d0f7de57db117a169d42352e9 100644 --- a/include/aidge/operator/Softmax.hpp +++ b/include/aidge/operator/Softmax.hpp @@ -108,14 +108,14 @@ public: } - void setBackend(const std::string& name) { + void setBackend(const std::string& name) override { mImpl = Registrar<Softmax_Op>::create(name)(*this); mOutput->setBackend(name); // FIXME: temporary workaround mInput->setBackend(name); } - void setDatatype(const DataType& datatype) { + void setDatatype(const DataType& datatype) override { mOutput->setDatatype(datatype); // FIXME: temporary workaround @@ -125,10 +125,15 @@ public: inline IOIndex_t nbInputs() const noexcept override final { return 1; } inline IOIndex_t nbDataInputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } + static const std::vector<std::string> getInputsName(){ + return {"data_input"}; + } + static const std::vector<std::string> getOutputsName(){ + return {"data_output"}; + } }; inline std::shared_ptr<Node> Softmax(const std::string& name = "") { - // FIXME: properly handle default w&b initialization in every cases return std::make_shared<Node>(std::make_shared<Softmax_Op>(), name); } } diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 9916ee2004bd1aa9f33acf96d95cae4703f692df..1896894ee8690cedaef696394da0829604e36211 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -89,11 +89,6 @@ private: * */ std::vector<std::shared_ptr<Node>> mStaticSchedule; - /** - * @brief Number of computation node (i.e: nb nodes != Producer) - * - */ - std::size_t mComputationNumber = 0; // TODO: Check if not inferable from mStaticSchedule }; } // namespace Aidge diff --git a/include/aidge/utils/Attributes.hpp b/include/aidge/utils/Attributes.hpp index 76875f15ff4229522e6208b0edb23ec519ff59ce..d3444000191022b575adaf1430319479daa5d4fc 100644 --- a/include/aidge/utils/Attributes.hpp +++ b/include/aidge/utils/Attributes.hpp @@ -18,6 +18,7 @@ #endif #include <vector> #include <string> +#include <set> #ifdef PYBIND namespace py = pybind11; diff --git a/include/aidge/utils/DynamicAttributes.hpp b/include/aidge/utils/DynamicAttributes.hpp index 60f586edf947cef0e139049814263a29b4d01e24..2af8f47e9420f266cc6eca21f167944c761db7ea 100644 --- a/include/aidge/utils/DynamicAttributes.hpp +++ b/include/aidge/utils/DynamicAttributes.hpp @@ -19,7 +19,7 @@ #include <cassert> #include <string> -#include "aidge/utils/Any.hpp" +#include "aidge/utils/future_std/any.hpp" #include "aidge/utils/Attributes.hpp" #ifdef PYBIND @@ -54,12 +54,12 @@ public: auto itPy = mAttrsPy.find(name); if (itPy != mAttrsPy.end()) { // Insert the attribute back in C++ - mAttrs.emplace(std::make_pair(name, libany::any(itPy->second.cast<T>()))); + mAttrs.emplace(std::make_pair(name, future_std::any(itPy->second.cast<T>()))); } } #endif - return libany::any_cast<T&>(mAttrs.at(name)); + return future_std::any_cast<T&>(mAttrs.at(name)); } template<class T> const T& getAttr(const std::string& name) const @@ -71,12 +71,12 @@ public: auto itPy = mAttrsPy.find(name); if (itPy != mAttrsPy.end()) { // Insert the attribute back in C++ - mAttrs.emplace(std::make_pair(name, libany::any(itPy->second.cast<T>()))); + mAttrs.emplace(std::make_pair(name, future_std::any(itPy->second.cast<T>()))); } } #endif - return libany::any_cast<const T&>(mAttrs.at(name)); + return future_std::any_cast<const T&>(mAttrs.at(name)); } ///\brief Add a new Attribute, identified by its name. If it already exists, asserts. @@ -85,7 +85,7 @@ public: ///\param value Attribute value template<class T> void addAttr(const std::string& name, const T& value) { - const auto& res = mAttrs.emplace(std::make_pair(name, libany::any(value))); + const auto& res = mAttrs.emplace(std::make_pair(name, future_std::any(value))); assert(res.second && "attribute already exists"); #ifdef PYBIND @@ -103,9 +103,9 @@ public: ///\param value Attribute value template<class T> void setAttr(const std::string& name, const T& value) { - auto res = mAttrs.emplace(std::make_pair(name, libany::any(value))); + auto res = mAttrs.emplace(std::make_pair(name, future_std::any(value))); if (!res.second) - res.first->second = libany::any(value); + res.first->second = future_std::any(value); #ifdef PYBIND // We cannot handle Python object if the Python interpreter is not running @@ -194,7 +194,7 @@ public: * generic type caster for std::any is not feasable. * The strategy here is to keep a copy of each attribute in py::object that is updated everytime. */ - py::object getAttrPy(const std::string& name) const { + py::object getAttrPy(const std::string& name) const override final { return mAttrsPy.at(name); }; #endif @@ -210,9 +210,9 @@ private: std::map<std::string, py::object> mAttrsPy; // Stores C++ attributes only // mutable because it may be updated in getAttr() from Python - mutable std::map<std::string, libany::any> mAttrs; + mutable std::map<std::string, future_std::any> mAttrs; #else - std::map<std::string, libany::any> mAttrs; + std::map<std::string, future_std::any> mAttrs; #endif }; diff --git a/include/aidge/utils/Utils.hpp b/include/aidge/utils/ErrorHandling.hpp similarity index 53% rename from include/aidge/utils/Utils.hpp rename to include/aidge/utils/ErrorHandling.hpp index 7c0c03c82ff252b6175d3c9bbe5395bb05127c9f..8fbeff30abecfec0077786b21825b6a6f36677c6 100644 --- a/include/aidge/utils/Utils.hpp +++ b/include/aidge/utils/ErrorHandling.hpp @@ -10,16 +10,21 @@ ********************************************************************************/ -#ifndef AIDGE_UTILS_H_ -#define AIDGE_UTILS_H_ +#ifndef AIDGE_ERRORHANDLING_H_ +#define AIDGE_ERRORHANDLING_H_ #include <cstdio> +#include <memory> -#ifdef NO_EXCEPTIONS +#define AIDGE_STRINGIZE_DETAIL(x) #x +#define AIDGE_STRINGIZE(x) AIDGE_STRINGIZE_DETAIL(x) + +#ifdef NO_EXCEPTION #define AIDGE_THROW_OR_ABORT(ex, ...) \ do { std::printf(__VA_ARGS__); std::abort(); } while (false) #else #include <stdexcept> +#include <memory> #define AIDGE_THROW_OR_ABORT(ex, ...) \ do { \ int n = 128; \ @@ -34,4 +39,21 @@ do { \ } while (false) #endif -#endif //AIDGE_UTILS_H_ \ No newline at end of file +/** + * Macro for specified API assertions. + * Used to check logic directly related to user's inputs. + * If it asserts, it means an user error. +*/ +#define AIDGE_ASSERT(stm, ...) \ +if (!(stm)) { printf("Assertion failed: " AIDGE_STRINGIZE(stm) " in " __FILE__ ":%d", __LINE__); \ + AIDGE_THROW_OR_ABORT(std::runtime_error, __VA_ARGS__); } + +/** + * Macro for internal assertions. + * Used to check internal logic not directly related to API user's inputs. + * If it asserts, it means a bug. +*/ +#define AIDGE_INTERNAL_ASSERT(stm) \ +assert((stm) && "Internal assertion failed: " #stm " in " __FILE__ ":" AIDGE_STRINGIZE(__LINE__)) + +#endif //AIDGE_ERRORHANDLING_H_ diff --git a/include/aidge/utils/Registrar.hpp b/include/aidge/utils/Registrar.hpp index 993519f380c8e70cf6d601985f6c8faf0b268369..ece74509d466800c870d73d1e0bbe1d639f8bf54 100644 --- a/include/aidge/utils/Registrar.hpp +++ b/include/aidge/utils/Registrar.hpp @@ -58,6 +58,11 @@ struct Registrar { //assert(newInsert && "registrar already exists"); } + static bool exists(const typename C::registrar_key& key) { + const auto it = C::registry().find(key); + return (it != C::registry().end()); + } + static auto create(const typename C::registrar_key& key){ const auto it = C::registry().find(key); assert(it != C::registry().end() && "invalid registrar key"); diff --git a/include/aidge/utils/StaticAttributes.hpp b/include/aidge/utils/StaticAttributes.hpp index fb800cffbcff5d4113961f8e62977417336f2cb8..b67f69ae7afc2c22f3b424812ec994b10974b668 100644 --- a/include/aidge/utils/StaticAttributes.hpp +++ b/include/aidge/utils/StaticAttributes.hpp @@ -18,7 +18,7 @@ #include <typeinfo> #include "aidge/utils/Attributes.hpp" -#include "aidge/utils/Utils.hpp" +#include "aidge/utils/ErrorHandling.hpp" namespace Aidge { /** @@ -87,7 +87,7 @@ public: // Runtime access with name template <typename R> - constexpr R& getAttr(const char* name) { + R& getAttr(const char* name) { for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { if (strcmp(EnumStrings<ATTRS_ENUM>::data[i], name) == 0) { return getAttr<R>(i); @@ -98,7 +98,7 @@ public: } template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> - constexpr typename std::enable_if<(SIZE > 0), R&>::type getAttr(std::size_t i) { + typename std::enable_if<(SIZE > 0), R&>::type getAttr(std::size_t i) { if (i == SIZE-1) { if (std::is_same<R, typename std::tuple_element<SIZE-1,std::tuple<T...>>::type>::value) { return reinterpret_cast<R&>(std::get<SIZE-1>(mAttrs)); @@ -113,7 +113,7 @@ public: } template <typename R, std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> - [[noreturn]] constexpr typename std::enable_if<(SIZE == 0), R&>::type getAttr(std::size_t /*i*/) { + [[noreturn]] typename std::enable_if<(SIZE == 0), R&>::type getAttr(std::size_t /*i*/) { AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute not found"); } @@ -128,7 +128,7 @@ public: } template <std::size_t SIZE = std::tuple_size<std::tuple<T...>>::value> - [[noreturn]] constexpr typename std::enable_if<(SIZE == 0), const std::type_info&>::type getAttrType(std::size_t /*i*/) const { + [[noreturn]] typename std::enable_if<(SIZE == 0), const std::type_info&>::type getAttrType(std::size_t /*i*/) const { AIDGE_THROW_OR_ABORT(std::runtime_error, "attribute not found"); } @@ -140,7 +140,7 @@ public: /// Generic Attributes API ////////////////////////////////////// // Runtime existance check with name - constexpr bool hasAttr(const std::string& name) const override final { + bool hasAttr(const std::string& name) const override final { for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { if (name == EnumStrings<ATTRS_ENUM>::data[i]) { return true; @@ -151,7 +151,7 @@ public: } // Runtime type access with name - constexpr std::string getAttrType(const std::string& name) const override final { + std::string getAttrType(const std::string& name) const override final { for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { if (name == EnumStrings<ATTRS_ENUM>::data[i]) { return getAttrType(i).name(); @@ -170,7 +170,7 @@ public: } #ifdef PYBIND - py::object getAttrPy(const std::string& name) const { + py::object getAttrPy(const std::string& name) const override { for (std::size_t i = 0; i < size(EnumStrings<ATTRS_ENUM>::data); ++i) { if (name == EnumStrings<ATTRS_ENUM>::data[i]) { // https://github.com/pybind/pybind11/blob/f3e0602802c7840992c97f4960515777cad6a5c7/include/pybind11/pytypes.h#L1119-L1138 diff --git a/include/aidge/utils/Any.hpp b/include/aidge/utils/future_std/any.hpp similarity index 98% rename from include/aidge/utils/Any.hpp rename to include/aidge/utils/future_std/any.hpp index 0e65710596d31920de60a35d600e7ae612ea2bc4..8d9bfe28d0497dc12c59aaed68a23d3a9563815e 100644 --- a/include/aidge/utils/Any.hpp +++ b/include/aidge/utils/future_std/any.hpp @@ -14,11 +14,11 @@ * Copyright (c) 2018 Claudio Fantacci * * Distributed under the Boost Software License, Version 1.0. - * (See accompanying file LICENSE.md or copy at http://www.boost.org/LICENSE_1_0.txt) + * (See copy at http://www.boost.org/LICENSE_1_0.txt) */ -#ifndef AIDGE_CORE_UTILS_ANY_H_ -#define AIDGE_CORE_UTILS_ANY_H_ +#ifndef AIDGE_CORE_UTILS_FUTURE_STD_ANY_H_ +#define AIDGE_CORE_UTILS_FUTURE_STD_ANY_H_ #include <stdexcept> #include <typeinfo> @@ -26,7 +26,7 @@ #include <utility> -namespace libany +namespace future_std { class bad_any_cast : public std::bad_cast @@ -549,4 +549,4 @@ inline void swap(any& lhs, any& rhs) noexcept } -#endif /* AIDGE_CORE_UTILS_ANY_H_ */ +#endif /* AIDGE_CORE_UTILS_FUTURE_STD_ANY_H_ */ diff --git a/include/aidge/utils/future_std/expected.hpp b/include/aidge/utils/future_std/expected.hpp new file mode 100644 index 0000000000000000000000000000000000000000..c271d0e8d8066c0bcd0358f28f8bcd711a8b6ba0 --- /dev/null +++ b/include/aidge/utils/future_std/expected.hpp @@ -0,0 +1,3487 @@ +// Origin: https://github.com/martinmoene/expected-lite +// +// This version targets C++11 and later. +// +// Copyright (C) 2016-2020 Martin Moene. +// +// Distributed under the Boost Software License, Version 1.0. +// (See copy at http://www.boost.org/LICENSE_1_0.txt) +// +// expected lite is based on: +// A proposal to add a utility class to represent expected monad +// by Vicente J. Botet Escriba and Pierre Talbot. http:://wg21.link/p0323 + +#ifndef AIDGE_CORE_UTILS_FUTURE_STD_EXPECTED_H_ +#define AIDGE_CORE_UTILS_FUTURE_STD_EXPECTED_H_ + +#define expected_lite_MAJOR 0 +#define expected_lite_MINOR 6 +#define expected_lite_PATCH 3 + +#define expected_lite_VERSION expected_STRINGIFY(expected_lite_MAJOR) "." expected_STRINGIFY(expected_lite_MINOR) "." expected_STRINGIFY(expected_lite_PATCH) + +#define expected_STRINGIFY( x ) expected_STRINGIFY_( x ) +#define expected_STRINGIFY_( x ) #x + +// expected-lite configuration: + +#define nsel_EXPECTED_DEFAULT 0 +#define nsel_EXPECTED_FUTURE_STD 1 +#define nsel_EXPECTED_STD 2 + +// tweak header support: + +#ifdef __has_include +# if __has_include(<future_std/expected.tweak.hpp>) +# include <future_std/expected.tweak.hpp> +# endif +#define expected_HAVE_TWEAK_HEADER 1 +#else +#define expected_HAVE_TWEAK_HEADER 0 +//# pragma message("expected.hpp: Note: Tweak header not supported.") +#endif + +// expected selection and configuration: + +#if !defined( nsel_CONFIG_SELECT_EXPECTED ) +# define nsel_CONFIG_SELECT_EXPECTED ( nsel_HAVE_STD_EXPECTED ? nsel_EXPECTED_STD : nsel_EXPECTED_FUTURE_STD ) +#endif + +// Proposal revisions: +// +// DXXXXR0: -- +// N4015 : -2 (2014-05-26) +// N4109 : -1 (2014-06-29) +// P0323R0: 0 (2016-05-28) +// P0323R1: 1 (2016-10-12) +// -------: +// P0323R2: 2 (2017-06-15) +// P0323R3: 3 (2017-10-15) +// P0323R4: 4 (2017-11-26) +// P0323R5: 5 (2018-02-08) +// P0323R6: 6 (2018-04-02) +// P0323R7: 7 (2018-06-22) * +// +// expected-lite uses 2 and higher + +#ifndef nsel_P0323R +# define nsel_P0323R 7 +#endif + +// Monadic operations proposal revisions: +// +// P2505R0: 0 (2021-12-12) +// P2505R1: 1 (2022-02-10) +// P2505R2: 2 (2022-04-15) +// P2505R3: 3 (2022-06-05) +// P2505R4: 4 (2022-06-15) +// P2505R5: 5 (2022-09-20) * +// +// expected-lite uses 5 + +#ifndef nsel_P2505R +# define nsel_P2505R 5 +#endif + +// Control presence of C++ exception handling (try and auto discover): + +#ifndef nsel_CONFIG_NO_EXCEPTIONS +# if defined(_MSC_VER) +# include <cstddef> // for _HAS_EXCEPTIONS +# endif +# if defined(__cpp_exceptions) || defined(__EXCEPTIONS) || (_HAS_EXCEPTIONS) +# define nsel_CONFIG_NO_EXCEPTIONS 0 +# else +# define nsel_CONFIG_NO_EXCEPTIONS 1 +# endif +#endif + +// at default use SEH with MSVC for no C++ exceptions + +#ifndef nsel_CONFIG_NO_EXCEPTIONS_SEH +# define nsel_CONFIG_NO_EXCEPTIONS_SEH ( nsel_CONFIG_NO_EXCEPTIONS && _MSC_VER ) +#endif + +// C++ language version detection (C++23 is speculative): +// Note: VC14.0/1900 (VS2015) lacks too much from C++14. + +#ifndef nsel_CPLUSPLUS +# if defined(_MSVC_LANG ) && !defined(__clang__) +# define nsel_CPLUSPLUS (_MSC_VER == 1900 ? 201103L : _MSVC_LANG ) +# else +# define nsel_CPLUSPLUS __cplusplus +# endif +#endif + +#define nsel_CPP98_OR_GREATER ( nsel_CPLUSPLUS >= 199711L ) +#define nsel_CPP11_OR_GREATER ( nsel_CPLUSPLUS >= 201103L ) +#define nsel_CPP14_OR_GREATER ( nsel_CPLUSPLUS >= 201402L ) +#define nsel_CPP17_OR_GREATER ( nsel_CPLUSPLUS >= 201703L ) +#define nsel_CPP20_OR_GREATER ( nsel_CPLUSPLUS >= 202002L ) +#define nsel_CPP23_OR_GREATER ( nsel_CPLUSPLUS >= 202300L ) + +// Use C++23 std::expected if available and requested: + +#if nsel_CPP23_OR_GREATER && defined(__has_include ) +# if __has_include( <expected> ) +# define nsel_HAVE_STD_EXPECTED 1 +# else +# define nsel_HAVE_STD_EXPECTED 0 +# endif +#else +# define nsel_HAVE_STD_EXPECTED 0 +#endif + +#define nsel_USES_STD_EXPECTED ( (nsel_CONFIG_SELECT_EXPECTED == nsel_EXPECTED_STD) || ((nsel_CONFIG_SELECT_EXPECTED == nsel_EXPECTED_DEFAULT) && nsel_HAVE_STD_EXPECTED) ) + +// +// in_place: code duplicated in any-lite, expected-lite, expected-lite, value-ptr-lite, variant-lite: +// + +#ifndef future_std_lite_HAVE_IN_PLACE_TYPES +#define future_std_lite_HAVE_IN_PLACE_TYPES 1 + +// C++17 std::in_place in <utility>: + +#if nsel_CPP17_OR_GREATER + +#include <utility> + +namespace future_std { + +using std::in_place; +using std::in_place_type; +using std::in_place_index; +using std::in_place_t; +using std::in_place_type_t; +using std::in_place_index_t; + +#define future_std_lite_in_place_t( T) std::in_place_t +#define future_std_lite_in_place_type_t( T) std::in_place_type_t<T> +#define future_std_lite_in_place_index_t(K) std::in_place_index_t<K> + +#define future_std_lite_in_place( T) std::in_place_t{} +#define future_std_lite_in_place_type( T) std::in_place_type_t<T>{} +#define future_std_lite_in_place_index(K) std::in_place_index_t<K>{} + +} // namespace future_std + +#else // nsel_CPP17_OR_GREATER + +#include <cstddef> + +namespace future_std { +namespace detail { + +template< class T > +struct in_place_type_tag {}; + +template< std::size_t K > +struct in_place_index_tag {}; + +} // namespace detail + +struct in_place_t {}; + +template< class T > +inline in_place_t in_place( detail::in_place_type_tag<T> = detail::in_place_type_tag<T>() ) +{ + return in_place_t(); +} + +template< std::size_t K > +inline in_place_t in_place( detail::in_place_index_tag<K> = detail::in_place_index_tag<K>() ) +{ + return in_place_t(); +} + +template< class T > +inline in_place_t in_place_type( detail::in_place_type_tag<T> = detail::in_place_type_tag<T>() ) +{ + return in_place_t(); +} + +template< std::size_t K > +inline in_place_t in_place_index( detail::in_place_index_tag<K> = detail::in_place_index_tag<K>() ) +{ + return in_place_t(); +} + +// mimic templated typedef: + +#define future_std_lite_in_place_t( T) future_std::in_place_t(&)( future_std::detail::in_place_type_tag<T> ) +#define future_std_lite_in_place_type_t( T) future_std::in_place_t(&)( future_std::detail::in_place_type_tag<T> ) +#define future_std_lite_in_place_index_t(K) future_std::in_place_t(&)( future_std::detail::in_place_index_tag<K> ) + +#define future_std_lite_in_place( T) future_std::in_place_type<T> +#define future_std_lite_in_place_type( T) future_std::in_place_type<T> +#define future_std_lite_in_place_index(K) future_std::in_place_index<K> + +} // namespace future_std + +#endif // nsel_CPP17_OR_GREATER +#endif // future_std_lite_HAVE_IN_PLACE_TYPES + +// +// Using std::expected: +// + +#if nsel_USES_STD_EXPECTED + +#include <expected> + +namespace future_std { + + using std::expected; +// ... +} + +#else // nsel_USES_STD_EXPECTED + +#include <cassert> +#include <exception> +#include <functional> +#include <initializer_list> +#include <memory> +#include <new> +#include <system_error> +#include <type_traits> +#include <utility> + +// additional includes: + +#if nsel_CONFIG_NO_EXCEPTIONS +# if nsel_CONFIG_NO_EXCEPTIONS_SEH +# include <windows.h> // for ExceptionCodes +# else +// already included: <cassert> +# endif +#else +# include <stdexcept> +#endif + +// C++ feature usage: + +#if nsel_CPP11_OR_GREATER +# define nsel_constexpr constexpr +#else +# define nsel_constexpr /*constexpr*/ +#endif + +#if nsel_CPP14_OR_GREATER +# define nsel_constexpr14 constexpr +#else +# define nsel_constexpr14 /*constexpr*/ +#endif + +#if nsel_CPP17_OR_GREATER +# define nsel_inline17 inline +#else +# define nsel_inline17 /*inline*/ +#endif + +// Compiler versions: +// +// MSVC++ 6.0 _MSC_VER == 1200 nsel_COMPILER_MSVC_VERSION == 60 (Visual Studio 6.0) +// MSVC++ 7.0 _MSC_VER == 1300 nsel_COMPILER_MSVC_VERSION == 70 (Visual Studio .NET 2002) +// MSVC++ 7.1 _MSC_VER == 1310 nsel_COMPILER_MSVC_VERSION == 71 (Visual Studio .NET 2003) +// MSVC++ 8.0 _MSC_VER == 1400 nsel_COMPILER_MSVC_VERSION == 80 (Visual Studio 2005) +// MSVC++ 9.0 _MSC_VER == 1500 nsel_COMPILER_MSVC_VERSION == 90 (Visual Studio 2008) +// MSVC++ 10.0 _MSC_VER == 1600 nsel_COMPILER_MSVC_VERSION == 100 (Visual Studio 2010) +// MSVC++ 11.0 _MSC_VER == 1700 nsel_COMPILER_MSVC_VERSION == 110 (Visual Studio 2012) +// MSVC++ 12.0 _MSC_VER == 1800 nsel_COMPILER_MSVC_VERSION == 120 (Visual Studio 2013) +// MSVC++ 14.0 _MSC_VER == 1900 nsel_COMPILER_MSVC_VERSION == 140 (Visual Studio 2015) +// MSVC++ 14.1 _MSC_VER >= 1910 nsel_COMPILER_MSVC_VERSION == 141 (Visual Studio 2017) +// MSVC++ 14.2 _MSC_VER >= 1920 nsel_COMPILER_MSVC_VERSION == 142 (Visual Studio 2019) + +#if defined(_MSC_VER) && !defined(__clang__) +# define nsel_COMPILER_MSVC_VER (_MSC_VER ) +# define nsel_COMPILER_MSVC_VERSION (_MSC_VER / 10 - 10 * ( 5 + (_MSC_VER < 1900)) ) +#else +# define nsel_COMPILER_MSVC_VER 0 +# define nsel_COMPILER_MSVC_VERSION 0 +#endif + +#define nsel_COMPILER_VERSION( major, minor, patch ) ( 10 * ( 10 * (major) + (minor) ) + (patch) ) + +#if defined(__clang__) +# define nsel_COMPILER_CLANG_VERSION nsel_COMPILER_VERSION(__clang_major__, __clang_minor__, __clang_patchlevel__) +#else +# define nsel_COMPILER_CLANG_VERSION 0 +#endif + +#if defined(__GNUC__) && !defined(__clang__) +# define nsel_COMPILER_GNUC_VERSION nsel_COMPILER_VERSION(__GNUC__, __GNUC_MINOR__, __GNUC_PATCHLEVEL__) +#else +# define nsel_COMPILER_GNUC_VERSION 0 +#endif + +// half-open range [lo..hi): +//#define nsel_BETWEEN( v, lo, hi ) ( (lo) <= (v) && (v) < (hi) ) + +// Method enabling + +#define nsel_REQUIRES_0(...) \ + template< bool B = (__VA_ARGS__), typename std::enable_if<B, int>::type = 0 > + +#define nsel_REQUIRES_T(...) \ + , typename std::enable_if< (__VA_ARGS__), int >::type = 0 + +#define nsel_REQUIRES_R(R, ...) \ + typename std::enable_if< (__VA_ARGS__), R>::type + +#define nsel_REQUIRES_A(...) \ + , typename std::enable_if< (__VA_ARGS__), void*>::type = nullptr + +// Presence of language and library features: + +#ifdef _HAS_CPP0X +# define nsel_HAS_CPP0X _HAS_CPP0X +#else +# define nsel_HAS_CPP0X 0 +#endif + +//#define nsel_CPP11_140 (nsel_CPP11_OR_GREATER || nsel_COMPILER_MSVC_VER >= 1900) + +// Clang, GNUC, MSVC warning suppression macros: + +#ifdef __clang__ +# pragma clang diagnostic push +#elif defined __GNUC__ +# pragma GCC diagnostic push +#endif // __clang__ + +#if nsel_COMPILER_MSVC_VERSION >= 140 +# pragma warning( push ) +# define nsel_DISABLE_MSVC_WARNINGS(codes) __pragma( warning(disable: codes) ) +#else +# define nsel_DISABLE_MSVC_WARNINGS(codes) +#endif + +#ifdef __clang__ +# define nsel_RESTORE_WARNINGS() _Pragma("clang diagnostic pop") +#elif defined __GNUC__ +# define nsel_RESTORE_WARNINGS() _Pragma("GCC diagnostic pop") +#elif nsel_COMPILER_MSVC_VERSION >= 140 +# define nsel_RESTORE_WARNINGS() __pragma( warning( pop ) ) +#else +# define nsel_RESTORE_WARNINGS() +#endif + +// Suppress the following MSVC (GSL) warnings: +// - C26409: Avoid calling new and delete explicitly, use std::make_unique<T> instead (r.11) + +nsel_DISABLE_MSVC_WARNINGS( 26409 ) + +// +// expected: +// + +namespace future_std { namespace expected_lite { + +// type traits C++17: + +namespace std17 { + +#if nsel_CPP17_OR_GREATER + +using std::conjunction; +using std::is_swappable; +using std::is_nothrow_swappable; + +#else // nsel_CPP17_OR_GREATER + +namespace detail { + +using std::swap; + +struct is_swappable +{ + template< typename T, typename = decltype( swap( std::declval<T&>(), std::declval<T&>() ) ) > + static std::true_type test( int /* unused */); + + template< typename > + static std::false_type test(...); +}; + +struct is_nothrow_swappable +{ + // wrap noexcept(expr) in separate function as work-around for VC140 (VS2015): + + template< typename T > + static constexpr bool satisfies() + { + return noexcept( swap( std::declval<T&>(), std::declval<T&>() ) ); + } + + template< typename T > + static auto test( int ) -> std::integral_constant<bool, satisfies<T>()>{} + + template< typename > + static auto test(...) -> std::false_type; +}; +} // namespace detail + +// is [nothrow] swappable: + +template< typename T > +struct is_swappable : decltype( detail::is_swappable::test<T>(0) ){}; + +template< typename T > +struct is_nothrow_swappable : decltype( detail::is_nothrow_swappable::test<T>(0) ){}; + +// conjunction: + +template< typename... > struct conjunction : std::true_type{}; +template< typename B1 > struct conjunction<B1> : B1{}; + +template< typename B1, typename... Bn > +struct conjunction<B1, Bn...> : std::conditional<bool(B1::value), conjunction<Bn...>, B1>::type{}; + +#endif // nsel_CPP17_OR_GREATER + +} // namespace std17 + +// type traits C++20: + +namespace std20 { + +#if defined(__cpp_lib_remove_cvref) + +using std::remove_cvref; + +#else + +template< typename T > +struct remove_cvref +{ + typedef typename std::remove_cv< typename std::remove_reference<T>::type >::type type; +}; + +#endif + +} // namespace std20 + +// forward declaration: + +template< typename T, typename E > +class expected; + +namespace detail { + +#if nsel_P2505R >= 3 +template< typename T > +struct is_expected : std::false_type {}; + +template< typename T, typename E > +struct is_expected< expected< T, E > > : std::true_type {}; +#endif // nsel_P2505R >= 3 + +/// discriminated union to hold value or 'error'. + +template< typename T, typename E > +class storage_t_noncopy_nonmove_impl +{ + template< typename, typename > friend class future_std::expected_lite::expected; + +public: + using value_type = T; + using error_type = E; + + // no-op construction + storage_t_noncopy_nonmove_impl() {} + ~storage_t_noncopy_nonmove_impl() {} + + explicit storage_t_noncopy_nonmove_impl( bool has_value ) + : m_has_value( has_value ) + {} + + void construct_value() + { + new( &m_value ) value_type(); + } + + // void construct_value( value_type const & e ) + // { + // new( &m_value ) value_type( e ); + // } + + // void construct_value( value_type && e ) + // { + // new( &m_value ) value_type( std::move( e ) ); + // } + + template< class... Args > + void emplace_value( Args&&... args ) + { + new( &m_value ) value_type( std::forward<Args>(args)...); + } + + template< class U, class... Args > + void emplace_value( std::initializer_list<U> il, Args&&... args ) + { + new( &m_value ) value_type( il, std::forward<Args>(args)... ); + } + + void destruct_value() + { + m_value.~value_type(); + } + + // void construct_error( error_type const & e ) + // { + // // new( &m_error ) error_type( e ); + // } + + // void construct_error( error_type && e ) + // { + // // new( &m_error ) error_type( std::move( e ) ); + // } + + template< class... Args > + void emplace_error( Args&&... args ) + { + new( &m_error ) error_type( std::forward<Args>(args)...); + } + + template< class U, class... Args > + void emplace_error( std::initializer_list<U> il, Args&&... args ) + { + new( &m_error ) error_type( il, std::forward<Args>(args)... ); + } + + void destruct_error() + { + m_error.~error_type(); + } + + constexpr value_type const & value() const & + { + return m_value; + } + + value_type & value() & + { + return m_value; + } + + constexpr value_type const && value() const && + { + return std::move( m_value ); + } + + nsel_constexpr14 value_type && value() && + { + return std::move( m_value ); + } + + value_type const * value_ptr() const + { + return &m_value; + } + + value_type * value_ptr() + { + return &m_value; + } + + error_type const & error() const & + { + return m_error; + } + + error_type & error() & + { + return m_error; + } + + constexpr error_type const && error() const && + { + return std::move( m_error ); + } + + nsel_constexpr14 error_type && error() && + { + return std::move( m_error ); + } + + bool has_value() const + { + return m_has_value; + } + + void set_has_value( bool v ) + { + m_has_value = v; + } + +private: + union + { + value_type m_value; + error_type m_error; + }; + + bool m_has_value = false; +}; + +template< typename T, typename E > +class storage_t_impl +{ + template< typename, typename > friend class future_std::expected_lite::expected; + +public: + using value_type = T; + using error_type = E; + + // no-op construction + storage_t_impl() {} + ~storage_t_impl() {} + + explicit storage_t_impl( bool has_value ) + : m_has_value( has_value ) + {} + + void construct_value() + { + new( &m_value ) value_type(); + } + + void construct_value( value_type const & e ) + { + new( &m_value ) value_type( e ); + } + + void construct_value( value_type && e ) + { + new( &m_value ) value_type( std::move( e ) ); + } + + template< class... Args > + void emplace_value( Args&&... args ) + { + new( &m_value ) value_type( std::forward<Args>(args)...); + } + + template< class U, class... Args > + void emplace_value( std::initializer_list<U> il, Args&&... args ) + { + new( &m_value ) value_type( il, std::forward<Args>(args)... ); + } + + void destruct_value() + { + m_value.~value_type(); + } + + void construct_error( error_type const & e ) + { + new( &m_error ) error_type( e ); + } + + void construct_error( error_type && e ) + { + new( &m_error ) error_type( std::move( e ) ); + } + + template< class... Args > + void emplace_error( Args&&... args ) + { + new( &m_error ) error_type( std::forward<Args>(args)...); + } + + template< class U, class... Args > + void emplace_error( std::initializer_list<U> il, Args&&... args ) + { + new( &m_error ) error_type( il, std::forward<Args>(args)... ); + } + + void destruct_error() + { + m_error.~error_type(); + } + + constexpr value_type const & value() const & + { + return m_value; + } + + value_type & value() & + { + return m_value; + } + + constexpr value_type const && value() const && + { + return std::move( m_value ); + } + + nsel_constexpr14 value_type && value() && + { + return std::move( m_value ); + } + + value_type const * value_ptr() const + { + return &m_value; + } + + value_type * value_ptr() + { + return &m_value; + } + + error_type const & error() const & + { + return m_error; + } + + error_type & error() & + { + return m_error; + } + + constexpr error_type const && error() const && + { + return std::move( m_error ); + } + + nsel_constexpr14 error_type && error() && + { + return std::move( m_error ); + } + + bool has_value() const + { + return m_has_value; + } + + void set_has_value( bool v ) + { + m_has_value = v; + } + +private: + union + { + value_type m_value; + error_type m_error; + }; + + bool m_has_value = false; +}; + +/// discriminated union to hold only 'error'. + +template< typename E > +struct storage_t_impl<void, E> +{ + template< typename, typename > friend class future_std::expected_lite::expected; + +public: + using value_type = void; + using error_type = E; + + // no-op construction + storage_t_impl() {} + ~storage_t_impl() {} + + explicit storage_t_impl( bool has_value ) + : m_has_value( has_value ) + {} + + void construct_error( error_type const & e ) + { + new( &m_error ) error_type( e ); + } + + void construct_error( error_type && e ) + { + new( &m_error ) error_type( std::move( e ) ); + } + + template< class... Args > + void emplace_error( Args&&... args ) + { + new( &m_error ) error_type( std::forward<Args>(args)...); + } + + template< class U, class... Args > + void emplace_error( std::initializer_list<U> il, Args&&... args ) + { + new( &m_error ) error_type( il, std::forward<Args>(args)... ); + } + + void destruct_error() + { + m_error.~error_type(); + } + + error_type const & error() const & + { + return m_error; + } + + error_type & error() & + { + return m_error; + } + + constexpr error_type const && error() const && + { + return std::move( m_error ); + } + + nsel_constexpr14 error_type && error() && + { + return std::move( m_error ); + } + + bool has_value() const + { + return m_has_value; + } + + void set_has_value( bool v ) + { + m_has_value = v; + } + +private: + union + { + char m_dummy; + error_type m_error; + }; + + bool m_has_value = false; +}; + +template< typename T, typename E, bool isConstructable, bool isMoveable > +class storage_t +{ +public: +}; + +template< typename T, typename E > +class storage_t<T, E, false, false> : public storage_t_noncopy_nonmove_impl<T, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_noncopy_nonmove_impl<T, E>( has_value ) + {} + + storage_t( storage_t const & other ) = delete; + storage_t( storage_t && other ) = delete; + +}; + +template< typename T, typename E > +class storage_t<T, E, true, true> : public storage_t_impl<T, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_impl<T, E>( has_value ) + {} + + storage_t( storage_t const & other ) + : storage_t_impl<T, E>( other.has_value() ) + { + if ( this->has_value() ) this->construct_value( other.value() ); + else this->construct_error( other.error() ); + } + + storage_t(storage_t && other ) + : storage_t_impl<T, E>( other.has_value() ) + { + if ( this->has_value() ) this->construct_value( std::move( other.value() ) ); + else this->construct_error( std::move( other.error() ) ); + } +}; + +template< typename E > +class storage_t<void, E, true, true> : public storage_t_impl<void, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_impl<void, E>( has_value ) + {} + + storage_t( storage_t const & other ) + : storage_t_impl<void, E>( other.has_value() ) + { + if ( this->has_value() ) ; + else this->construct_error( other.error() ); + } + + storage_t(storage_t && other ) + : storage_t_impl<void, E>( other.has_value() ) + { + if ( this->has_value() ) ; + else this->construct_error( std::move( other.error() ) ); + } +}; + +template< typename T, typename E > +class storage_t<T, E, true, false> : public storage_t_impl<T, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_impl<T, E>( has_value ) + {} + + storage_t( storage_t const & other ) + : storage_t_impl<T, E>(other.has_value()) + { + if ( this->has_value() ) this->construct_value( other.value() ); + else this->construct_error( other.error() ); + } + + storage_t( storage_t && other ) = delete; +}; + +template< typename E > +class storage_t<void, E, true, false> : public storage_t_impl<void, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_impl<void, E>( has_value ) + {} + + storage_t( storage_t const & other ) + : storage_t_impl<void, E>(other.has_value()) + { + if ( this->has_value() ) ; + else this->construct_error( other.error() ); + } + + storage_t( storage_t && other ) = delete; +}; + +template< typename T, typename E > +class storage_t<T, E, false, true> : public storage_t_impl<T, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_impl<T, E>( has_value ) + {} + + storage_t( storage_t const & other ) = delete; + + storage_t( storage_t && other ) + : storage_t_impl<T, E>( other.has_value() ) + { + if ( this->has_value() ) this->construct_value( std::move( other.value() ) ); + else this->construct_error( std::move( other.error() ) ); + } +}; + +template< typename E > +class storage_t<void, E, false, true> : public storage_t_impl<void, E> +{ +public: + storage_t() = default; + ~storage_t() = default; + + explicit storage_t( bool has_value ) + : storage_t_impl<void, E>( has_value ) + {} + + storage_t( storage_t const & other ) = delete; + + storage_t( storage_t && other ) + : storage_t_impl<void, E>( other.has_value() ) + { + if ( this->has_value() ) ; + else this->construct_error( std::move( other.error() ) ); + } +}; + +#if nsel_P2505R >= 3 +// C++11 invoke implementation +template< typename > +struct is_reference_wrapper : std::false_type {}; +template< typename T > +struct is_reference_wrapper< std::reference_wrapper< T > > : std::true_type {}; + +template< typename FnT, typename ClassT, typename ObjectT, typename... Args + nsel_REQUIRES_T( + std::is_function<FnT>::value + && ( std::is_same< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + || std::is_base_of< ClassT, typename std20::remove_cvref< ObjectT >::type >::value ) + ) +> +nsel_constexpr auto invoke_member_function_impl( FnT ClassT::* memfnptr, ObjectT && obj, Args && ... args ) + noexcept( noexcept( (std::forward< ObjectT >( obj ).*memfnptr)( std::forward< Args >( args )... ) ) ) + -> decltype( (std::forward< ObjectT >( obj ).*memfnptr)( std::forward< Args >( args )...) ) +{ + return (std::forward< ObjectT >( obj ).*memfnptr)( std::forward< Args >( args )... ); +} + +template< typename FnT, typename ClassT, typename ObjectT, typename... Args + nsel_REQUIRES_T( + std::is_function<FnT>::value + && is_reference_wrapper< typename std20::remove_cvref< ObjectT >::type >::value + ) +> +nsel_constexpr auto invoke_member_function_impl( FnT ClassT::* memfnptr, ObjectT && obj, Args && ... args ) + noexcept( noexcept( (obj.get().*memfnptr)( std::forward< Args >( args ) ... ) ) ) + -> decltype( (obj.get().*memfnptr)( std::forward< Args >( args ) ... ) ) +{ + return (obj.get().*memfnptr)( std::forward< Args >( args ) ... ); +} + +template< typename FnT, typename ClassT, typename ObjectT, typename... Args + nsel_REQUIRES_T( + std::is_function<FnT>::value + && !std::is_same< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + && !std::is_base_of< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + && !is_reference_wrapper< typename std20::remove_cvref< ObjectT >::type >::value + ) +> +nsel_constexpr auto invoke_member_function_impl( FnT ClassT::* memfnptr, ObjectT && obj, Args && ... args ) + noexcept( noexcept( ((*std::forward< ObjectT >( obj )).*memfnptr)( std::forward< Args >( args ) ... ) ) ) + -> decltype( ((*std::forward< ObjectT >( obj )).*memfnptr)( std::forward< Args >( args ) ... ) ) +{ + return ((*std::forward<ObjectT>(obj)).*memfnptr)( std::forward< Args >( args ) ... ); +} + +template< typename MemberT, typename ClassT, typename ObjectT + nsel_REQUIRES_T( + std::is_same< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + || std::is_base_of< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + ) +> +nsel_constexpr auto invoke_member_object_impl( MemberT ClassT::* memobjptr, ObjectT && obj ) + noexcept( noexcept( std::forward< ObjectT >( obj ).*memobjptr ) ) + -> decltype( std::forward< ObjectT >( obj ).*memobjptr ) +{ + return std::forward< ObjectT >( obj ).*memobjptr; +} + +template< typename MemberT, typename ClassT, typename ObjectT + nsel_REQUIRES_T( + is_reference_wrapper< typename std20::remove_cvref< ObjectT >::type >::value + ) +> +nsel_constexpr auto invoke_member_object_impl( MemberT ClassT::* memobjptr, ObjectT && obj ) + noexcept( noexcept( obj.get().*memobjptr ) ) + -> decltype( obj.get().*memobjptr ) +{ + return obj.get().*memobjptr; +} + +template< typename MemberT, typename ClassT, typename ObjectT + nsel_REQUIRES_T( + !std::is_same< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + && !std::is_base_of< ClassT, typename std20::remove_cvref< ObjectT >::type >::value + && !is_reference_wrapper< typename std20::remove_cvref< ObjectT >::type >::value + ) +> +nsel_constexpr auto invoke_member_object_impl( MemberT ClassT::* memobjptr, ObjectT && obj ) + noexcept( noexcept( (*std::forward< ObjectT >( obj )).*memobjptr ) ) + -> decltype( (*std::forward< ObjectT >( obj )).*memobjptr ) +{ + return (*std::forward< ObjectT >( obj )).*memobjptr; +} + +template< typename F, typename... Args + nsel_REQUIRES_T( + std::is_member_function_pointer< typename std20::remove_cvref< F >::type >::value + ) +> +nsel_constexpr auto invoke( F && f, Args && ... args ) + noexcept( noexcept( invoke_member_function_impl( std::forward< F >( f ), std::forward< Args >( args ) ... ) ) ) + -> decltype( invoke_member_function_impl( std::forward< F >( f ), std::forward< Args >( args ) ... ) ) +{ + return invoke_member_function_impl( std::forward< F >( f ), std::forward< Args >( args ) ... ); +} + +template< typename F, typename... Args + nsel_REQUIRES_T( + std::is_member_object_pointer< typename std20::remove_cvref< F >::type >::value + ) +> +nsel_constexpr auto invoke( F && f, Args && ... args ) + noexcept( noexcept( invoke_member_object_impl( std::forward< F >( f ), std::forward< Args >( args ) ... ) ) ) + -> decltype( invoke_member_object_impl( std::forward< F >( f ), std::forward< Args >( args ) ... ) ) +{ + return invoke_member_object_impl( std::forward< F >( f ), std::forward< Args >( args ) ... ); +} + +template< typename F, typename... Args + nsel_REQUIRES_T( + !std::is_member_function_pointer< typename std20::remove_cvref< F >::type >::value + && !std::is_member_object_pointer< typename std20::remove_cvref< F >::type >::value + ) +> +nsel_constexpr auto invoke( F && f, Args && ... args ) + noexcept( noexcept( std::forward< F >( f )( std::forward< Args >( args ) ... ) ) ) + -> decltype( std::forward< F >( f )( std::forward< Args >( args ) ... ) ) +{ + return std::forward< F >( f )( std::forward< Args >( args ) ... ); +} + +template< typename F, typename ... Args > +using invoke_result_nocvref_t = typename std20::remove_cvref< decltype( invoke( std::declval< F >(), std::declval< Args >()... ) ) >::type; + +#if nsel_P2505R >= 5 +template< typename F, typename ... Args > +using transform_invoke_result_t = typename std::remove_cv< decltype( invoke( std::declval< F >(), std::declval< Args >()... ) ) >::type; +#else +template< typename F, typename ... Args > +using transform_invoke_result_t = invoke_result_nocvref_t +#endif // nsel_P2505R >= 5 + +template< typename T > +struct valid_expected_value_type : std::integral_constant< bool, std::is_destructible< T >::value && !std::is_reference< T >::value && !std::is_array< T >::value > {}; + +#endif // nsel_P2505R >= 3 +} // namespace detail + +/// x.x.5 Unexpected object type; unexpected_type; C++17 and later can also use aliased type unexpected. + +#if nsel_P0323R <= 2 +template< typename E = std::exception_ptr > +class unexpected_type +#else +template< typename E > +class unexpected_type +#endif // nsel_P0323R +{ +public: + using error_type = E; + + // x.x.5.2.1 Constructors + +// unexpected_type() = delete; + + constexpr unexpected_type( unexpected_type const & ) = default; + constexpr unexpected_type( unexpected_type && ) = default; + + template< typename... Args + nsel_REQUIRES_T( + std::is_constructible<E, Args&&...>::value + ) + > + constexpr explicit unexpected_type( future_std_lite_in_place_t(E), Args &&... args ) + : m_error( std::forward<Args>( args )...) + {} + + template< typename U, typename... Args + nsel_REQUIRES_T( + std::is_constructible<E, std::initializer_list<U>, Args&&...>::value + ) + > + constexpr explicit unexpected_type( future_std_lite_in_place_t(E), std::initializer_list<U> il, Args &&... args ) + : m_error( il, std::forward<Args>( args )...) + {} + + template< typename E2 + nsel_REQUIRES_T( + std::is_constructible<E,E2>::value + && !std::is_same< typename std20::remove_cvref<E2>::type, future_std_lite_in_place_t(E2) >::value + && !std::is_same< typename std20::remove_cvref<E2>::type, unexpected_type >::value + ) + > + constexpr explicit unexpected_type( E2 && error ) + : m_error( std::forward<E2>( error ) ) + {} + + template< typename E2 + nsel_REQUIRES_T( + std::is_constructible< E, E2>::value + && !std::is_constructible<E, unexpected_type<E2> & >::value + && !std::is_constructible<E, unexpected_type<E2> >::value + && !std::is_constructible<E, unexpected_type<E2> const & >::value + && !std::is_constructible<E, unexpected_type<E2> const >::value + && !std::is_convertible< unexpected_type<E2> &, E>::value + && !std::is_convertible< unexpected_type<E2> , E>::value + && !std::is_convertible< unexpected_type<E2> const &, E>::value + && !std::is_convertible< unexpected_type<E2> const , E>::value + && !std::is_convertible< E2 const &, E>::value /*=> explicit */ + ) + > + constexpr explicit unexpected_type( unexpected_type<E2> const & error ) + : m_error( E{ error.value() } ) + {} + + template< typename E2 + nsel_REQUIRES_T( + std::is_constructible< E, E2>::value + && !std::is_constructible<E, unexpected_type<E2> & >::value + && !std::is_constructible<E, unexpected_type<E2> >::value + && !std::is_constructible<E, unexpected_type<E2> const & >::value + && !std::is_constructible<E, unexpected_type<E2> const >::value + && !std::is_convertible< unexpected_type<E2> &, E>::value + && !std::is_convertible< unexpected_type<E2> , E>::value + && !std::is_convertible< unexpected_type<E2> const &, E>::value + && !std::is_convertible< unexpected_type<E2> const , E>::value + && std::is_convertible< E2 const &, E>::value /*=> explicit */ + ) + > + constexpr /*non-explicit*/ unexpected_type( unexpected_type<E2> const & error ) + : m_error( error.value() ) + {} + + template< typename E2 + nsel_REQUIRES_T( + std::is_constructible< E, E2>::value + && !std::is_constructible<E, unexpected_type<E2> & >::value + && !std::is_constructible<E, unexpected_type<E2> >::value + && !std::is_constructible<E, unexpected_type<E2> const & >::value + && !std::is_constructible<E, unexpected_type<E2> const >::value + && !std::is_convertible< unexpected_type<E2> &, E>::value + && !std::is_convertible< unexpected_type<E2> , E>::value + && !std::is_convertible< unexpected_type<E2> const &, E>::value + && !std::is_convertible< unexpected_type<E2> const , E>::value + && !std::is_convertible< E2 const &, E>::value /*=> explicit */ + ) + > + constexpr explicit unexpected_type( unexpected_type<E2> && error ) + : m_error( E{ std::move( error.value() ) } ) + {} + + template< typename E2 + nsel_REQUIRES_T( + std::is_constructible< E, E2>::value + && !std::is_constructible<E, unexpected_type<E2> & >::value + && !std::is_constructible<E, unexpected_type<E2> >::value + && !std::is_constructible<E, unexpected_type<E2> const & >::value + && !std::is_constructible<E, unexpected_type<E2> const >::value + && !std::is_convertible< unexpected_type<E2> &, E>::value + && !std::is_convertible< unexpected_type<E2> , E>::value + && !std::is_convertible< unexpected_type<E2> const &, E>::value + && !std::is_convertible< unexpected_type<E2> const , E>::value + && std::is_convertible< E2 const &, E>::value /*=> non-explicit */ + ) + > + constexpr /*non-explicit*/ unexpected_type( unexpected_type<E2> && error ) + : m_error( std::move( error.value() ) ) + {} + + // x.x.5.2.2 Assignment + + nsel_constexpr14 unexpected_type& operator=( unexpected_type const & ) = default; + nsel_constexpr14 unexpected_type& operator=( unexpected_type && ) = default; + + template< typename E2 = E > + nsel_constexpr14 unexpected_type & operator=( unexpected_type<E2> const & other ) + { + unexpected_type{ other.value() }.swap( *this ); + return *this; + } + + template< typename E2 = E > + nsel_constexpr14 unexpected_type & operator=( unexpected_type<E2> && other ) + { + unexpected_type{ std::move( other.value() ) }.swap( *this ); + return *this; + } + + // x.x.5.2.3 Observers + + nsel_constexpr14 E & value() & noexcept + { + return m_error; + } + + constexpr E const & value() const & noexcept + { + return m_error; + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + + nsel_constexpr14 E && value() && noexcept + { + return std::move( m_error ); + } + + constexpr E const && value() const && noexcept + { + return std::move( m_error ); + } + +#endif + + // x.x.5.2.4 Swap + + template< typename U=E > + nsel_REQUIRES_R( void, + std17::is_swappable<U>::value + ) + swap( unexpected_type & other ) noexcept ( + std17::is_nothrow_swappable<U>::value + ) + { + using std::swap; + swap( m_error, other.m_error ); + } + + // TODO: ??? unexpected_type: in-class friend operator==, != + +private: + error_type m_error; +}; + +#if nsel_CPP17_OR_GREATER + +/// template deduction guide: + +template< typename E > +unexpected_type( E ) -> unexpected_type< E >; + +#endif + +/// class unexpected_type, std::exception_ptr specialization (P0323R2) + +#if !nsel_CONFIG_NO_EXCEPTIONS +#if nsel_P0323R <= 2 + +// TODO: Should expected be specialized for particular E types such as exception_ptr and how? +// See p0323r7 2.1. Ergonomics, http://wg21.link/p0323 +template<> +class unexpected_type< std::exception_ptr > +{ +public: + using error_type = std::exception_ptr; + + unexpected_type() = delete; + + ~unexpected_type(){} + + explicit unexpected_type( std::exception_ptr const & error ) + : m_error( error ) + {} + + explicit unexpected_type(std::exception_ptr && error ) + : m_error( std::move( error ) ) + {} + + template< typename E > + explicit unexpected_type( E error ) + : m_error( std::make_exception_ptr( error ) ) + {} + + std::exception_ptr const & value() const + { + return m_error; + } + + std::exception_ptr & value() + { + return m_error; + } + +private: + std::exception_ptr m_error; +}; + +#endif // nsel_P0323R +#endif // !nsel_CONFIG_NO_EXCEPTIONS + +/// x.x.4, Unexpected equality operators + +template< typename E1, typename E2 > +constexpr bool operator==( unexpected_type<E1> const & x, unexpected_type<E2> const & y ) +{ + return x.value() == y.value(); +} + +template< typename E1, typename E2 > +constexpr bool operator!=( unexpected_type<E1> const & x, unexpected_type<E2> const & y ) +{ + return ! ( x == y ); +} + +#if nsel_P0323R <= 2 + +template< typename E > +constexpr bool operator<( unexpected_type<E> const & x, unexpected_type<E> const & y ) +{ + return x.value() < y.value(); +} + +template< typename E > +constexpr bool operator>( unexpected_type<E> const & x, unexpected_type<E> const & y ) +{ + return ( y < x ); +} + +template< typename E > +constexpr bool operator<=( unexpected_type<E> const & x, unexpected_type<E> const & y ) +{ + return ! ( y < x ); +} + +template< typename E > +constexpr bool operator>=( unexpected_type<E> const & x, unexpected_type<E> const & y ) +{ + return ! ( x < y ); +} + +#endif // nsel_P0323R + +/// x.x.5 Specialized algorithms + +template< typename E + nsel_REQUIRES_T( + std17::is_swappable<E>::value + ) +> +void swap( unexpected_type<E> & x, unexpected_type<E> & y) noexcept ( noexcept ( x.swap(y) ) ) +{ + x.swap( y ); +} + +#if nsel_P0323R <= 2 + +// unexpected: relational operators for std::exception_ptr: + +inline constexpr bool operator<( unexpected_type<std::exception_ptr> const & /*x*/, unexpected_type<std::exception_ptr> const & /*y*/ ) +{ + return false; +} + +inline constexpr bool operator>( unexpected_type<std::exception_ptr> const & /*x*/, unexpected_type<std::exception_ptr> const & /*y*/ ) +{ + return false; +} + +inline constexpr bool operator<=( unexpected_type<std::exception_ptr> const & x, unexpected_type<std::exception_ptr> const & y ) +{ + return ( x == y ); +} + +inline constexpr bool operator>=( unexpected_type<std::exception_ptr> const & x, unexpected_type<std::exception_ptr> const & y ) +{ + return ( x == y ); +} + +#endif // nsel_P0323R + +// unexpected: traits + +#if nsel_P0323R <= 3 + +template< typename E> +struct is_unexpected : std::false_type {}; + +template< typename E> +struct is_unexpected< unexpected_type<E> > : std::true_type {}; + +#endif // nsel_P0323R + +// unexpected: factory + +// keep make_unexpected() removed in p0323r2 for pre-C++17: + +template< typename E> +nsel_constexpr14 auto +make_unexpected( E && value ) -> unexpected_type< typename std::decay<E>::type > +{ + return unexpected_type< typename std::decay<E>::type >( std::forward<E>(value) ); +} + +#if nsel_P0323R <= 3 + +/*nsel_constexpr14*/ auto inline +make_unexpected_from_current_exception() -> unexpected_type< std::exception_ptr > +{ + return unexpected_type< std::exception_ptr >( std::current_exception() ); +} + +#endif // nsel_P0323R + +/// x.x.6, x.x.7 expected access error + +template< typename E > +class bad_expected_access; + +/// x.x.7 bad_expected_access<void>: expected access error + +template <> +class bad_expected_access< void > : public std::exception +{ +public: + explicit bad_expected_access() + : std::exception() + {} +}; + +/// x.x.6 bad_expected_access: expected access error + +#if !nsel_CONFIG_NO_EXCEPTIONS + +template< typename E > +class bad_expected_access : public bad_expected_access< void > +{ +public: + using error_type = E; + + explicit bad_expected_access( error_type error ) + : m_error( error ) + {} + + virtual char const * what() const noexcept override + { + return "bad_expected_access"; + } + + nsel_constexpr14 error_type & error() & + { + return m_error; + } + + constexpr error_type const & error() const & + { + return m_error; + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + + nsel_constexpr14 error_type && error() && + { + return std::move( m_error ); + } + + constexpr error_type const && error() const && + { + return std::move( m_error ); + } + +#endif + +private: + error_type m_error; +}; + +#endif // nsel_CONFIG_NO_EXCEPTIONS + +/// x.x.8 unexpect tag, in_place_unexpected tag: construct an error + +struct unexpect_t{}; +using in_place_unexpected_t = unexpect_t; + +nsel_inline17 constexpr unexpect_t unexpect{}; +nsel_inline17 constexpr unexpect_t in_place_unexpected{}; + +/// class error_traits + +#if nsel_CONFIG_NO_EXCEPTIONS + +namespace detail { + inline bool text( char const * /*text*/ ) { return true; } +} + +template< typename Error > +struct error_traits +{ + static void rethrow( Error const & /*e*/ ) + { +#if nsel_CONFIG_NO_EXCEPTIONS_SEH + RaiseException( EXCEPTION_ACCESS_VIOLATION, EXCEPTION_NONCONTINUABLE, 0, NULL ); +#else + assert( false && detail::text("throw bad_expected_access<Error>{ e };") ); +#endif + } +}; + +template<> +struct error_traits< std::exception_ptr > +{ + static void rethrow( std::exception_ptr const & /*e*/ ) + { +#if nsel_CONFIG_NO_EXCEPTIONS_SEH + RaiseException( EXCEPTION_ACCESS_VIOLATION, EXCEPTION_NONCONTINUABLE, 0, NULL ); +#else + assert( false && detail::text("throw bad_expected_access<std::exception_ptr>{ e };") ); +#endif + } +}; + +template<> +struct error_traits< std::error_code > +{ + static void rethrow( std::error_code const & /*e*/ ) + { +#if nsel_CONFIG_NO_EXCEPTIONS_SEH + RaiseException( EXCEPTION_ACCESS_VIOLATION, EXCEPTION_NONCONTINUABLE, 0, NULL ); +#else + assert( false && detail::text("throw std::system_error( e );") ); +#endif + } +}; + +#else // nsel_CONFIG_NO_EXCEPTIONS + +template< typename Error > +struct error_traits +{ + static void rethrow( Error const & e ) + { + throw bad_expected_access<Error>{ e }; + } +}; + +template<> +struct error_traits< std::exception_ptr > +{ + static void rethrow( std::exception_ptr const & e ) + { + std::rethrow_exception( e ); + } +}; + +template<> +struct error_traits< std::error_code > +{ + static void rethrow( std::error_code const & e ) + { + throw std::system_error( e ); + } +}; + +#endif // nsel_CONFIG_NO_EXCEPTIONS + +#if nsel_P2505R >= 3 +namespace detail { + +// from https://en.cppreference.com/w/cpp/utility/expected/unexpected: +// "the type of the unexpected value. The type must not be an array type, a non-object type, a specialization of std::unexpected, or a cv-qualified type." +template< typename T > +struct valid_unexpected_type : std::integral_constant< bool, + std::is_same< T, typename std20::remove_cvref< T >::type >::value + && std::is_object< T >::value + && !std::is_array< T >::value +> {}; + +template< typename T > +struct valid_unexpected_type< unexpected_type< T > > : std::false_type {}; + +} // namespace detail +#endif // nsel_P2505R >= 3 + +} // namespace expected_lite + +// provide future_std::unexpected_type: + +using expected_lite::unexpected_type; + +namespace expected_lite { + +/// class expected + +#if nsel_P0323R <= 2 +template< typename T, typename E = std::exception_ptr > +class expected +#else +template< typename T, typename E > +class expected +#endif // nsel_P0323R +{ +private: + template< typename, typename > friend class expected; + +public: + using value_type = T; + using error_type = E; + using unexpected_type = future_std::unexpected_type<E>; + + template< typename U > + struct rebind + { + using type = expected<U, error_type>; + }; + + // x.x.4.1 constructors + + nsel_REQUIRES_0( + std::is_default_constructible<T>::value + ) + nsel_constexpr14 expected() + : contained( true ) + { + contained.construct_value(); + } + + nsel_constexpr14 expected( expected const & ) = default; + nsel_constexpr14 expected( expected && ) = default; + + template< typename U, typename G + nsel_REQUIRES_T( + std::is_constructible< T, U const &>::value + && std::is_constructible<E, G const &>::value + && !std::is_constructible<T, expected<U, G> & >::value + && !std::is_constructible<T, expected<U, G> && >::value + && !std::is_constructible<T, expected<U, G> const & >::value + && !std::is_constructible<T, expected<U, G> const && >::value + && !std::is_convertible< expected<U, G> & , T>::value + && !std::is_convertible< expected<U, G> &&, T>::value + && !std::is_convertible< expected<U, G> const & , T>::value + && !std::is_convertible< expected<U, G> const &&, T>::value + && (!std::is_convertible<U const &, T>::value || !std::is_convertible<G const &, E>::value ) /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( expected<U, G> const & other ) + : contained( other.has_value() ) + { + if ( has_value() ) contained.construct_value( T{ other.contained.value() } ); + else contained.construct_error( E{ other.contained.error() } ); + } + + template< typename U, typename G + nsel_REQUIRES_T( + std::is_constructible< T, U const &>::value + && std::is_constructible<E, G const &>::value + && !std::is_constructible<T, expected<U, G> & >::value + && !std::is_constructible<T, expected<U, G> && >::value + && !std::is_constructible<T, expected<U, G> const & >::value + && !std::is_constructible<T, expected<U, G> const && >::value + && !std::is_convertible< expected<U, G> & , T>::value + && !std::is_convertible< expected<U, G> &&, T>::value + && !std::is_convertible< expected<U, G> const &, T>::value + && !std::is_convertible< expected<U, G> const &&, T>::value + && !(!std::is_convertible<U const &, T>::value || !std::is_convertible<G const &, E>::value ) /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( expected<U, G> const & other ) + : contained( other.has_value() ) + { + if ( has_value() ) contained.construct_value( other.contained.value() ); + else contained.construct_error( other.contained.error() ); + } + + template< typename U, typename G + nsel_REQUIRES_T( + std::is_constructible< T, U>::value + && std::is_constructible<E, G>::value + && !std::is_constructible<T, expected<U, G> & >::value + && !std::is_constructible<T, expected<U, G> && >::value + && !std::is_constructible<T, expected<U, G> const & >::value + && !std::is_constructible<T, expected<U, G> const && >::value + && !std::is_convertible< expected<U, G> & , T>::value + && !std::is_convertible< expected<U, G> &&, T>::value + && !std::is_convertible< expected<U, G> const & , T>::value + && !std::is_convertible< expected<U, G> const &&, T>::value + && (!std::is_convertible<U, T>::value || !std::is_convertible<G, E>::value ) /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( expected<U, G> && other ) + : contained( other.has_value() ) + { + if ( has_value() ) contained.construct_value( T{ std::move( other.contained.value() ) } ); + else contained.construct_error( E{ std::move( other.contained.error() ) } ); + } + + template< typename U, typename G + nsel_REQUIRES_T( + std::is_constructible< T, U>::value + && std::is_constructible<E, G>::value + && !std::is_constructible<T, expected<U, G> & >::value + && !std::is_constructible<T, expected<U, G> && >::value + && !std::is_constructible<T, expected<U, G> const & >::value + && !std::is_constructible<T, expected<U, G> const && >::value + && !std::is_convertible< expected<U, G> & , T>::value + && !std::is_convertible< expected<U, G> &&, T>::value + && !std::is_convertible< expected<U, G> const & , T>::value + && !std::is_convertible< expected<U, G> const &&, T>::value + && !(!std::is_convertible<U, T>::value || !std::is_convertible<G, E>::value ) /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( expected<U, G> && other ) + : contained( other.has_value() ) + { + if ( has_value() ) contained.construct_value( std::move( other.contained.value() ) ); + else contained.construct_error( std::move( other.contained.error() ) ); + } + + template< typename U = T + nsel_REQUIRES_T( + std::is_copy_constructible<U>::value + ) + > + nsel_constexpr14 expected( value_type const & value ) + : contained( true ) + { + contained.construct_value( value ); + } + + template< typename U = T + nsel_REQUIRES_T( + std::is_constructible<T,U&&>::value + && !std::is_same<typename std20::remove_cvref<U>::type, future_std_lite_in_place_t(U)>::value + && !std::is_same< expected<T,E> , typename std20::remove_cvref<U>::type>::value + && !std::is_same<future_std::unexpected_type<E>, typename std20::remove_cvref<U>::type>::value + && !std::is_convertible<U&&,T>::value /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( U && value ) noexcept + ( + std::is_nothrow_move_constructible<U>::value && + std::is_nothrow_move_constructible<E>::value + ) + : contained( true ) + { + contained.construct_value( T{ std::forward<U>( value ) } ); + } + + template< typename U = T + nsel_REQUIRES_T( + std::is_constructible<T,U&&>::value + && !std::is_same<typename std20::remove_cvref<U>::type, future_std_lite_in_place_t(U)>::value + && !std::is_same< expected<T,E> , typename std20::remove_cvref<U>::type>::value + && !std::is_same<future_std::unexpected_type<E>, typename std20::remove_cvref<U>::type>::value + && std::is_convertible<U&&,T>::value /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( U && value ) noexcept + ( + std::is_nothrow_move_constructible<U>::value && + std::is_nothrow_move_constructible<E>::value + ) + : contained( true ) + { + contained.construct_value( std::forward<U>( value ) ); + } + + // construct error: + + template< typename G = E + nsel_REQUIRES_T( + std::is_constructible<E, G const & >::value + && !std::is_convertible< G const &, E>::value /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( future_std::unexpected_type<G> const & error ) + : contained( false ) + { + contained.construct_error( E{ error.value() } ); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_constructible<E, G const & >::value + && std::is_convertible< G const &, E>::value /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( future_std::unexpected_type<G> const & error ) + : contained( false ) + { + contained.construct_error( error.value() ); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_constructible<E, G&& >::value + && !std::is_convertible< G&&, E>::value /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( future_std::unexpected_type<G> && error ) + : contained( false ) + { + contained.construct_error( E{ std::move( error.value() ) } ); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_constructible<E, G&& >::value + && std::is_convertible< G&&, E>::value /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( future_std::unexpected_type<G> && error ) + : contained( false ) + { + contained.construct_error( std::move( error.value() ) ); + } + + // in-place construction, value + + template< typename... Args + nsel_REQUIRES_T( + std::is_constructible<T, Args&&...>::value + ) + > + nsel_constexpr14 explicit expected( future_std_lite_in_place_t(T), Args&&... args ) + : contained( true ) + { + contained.emplace_value( std::forward<Args>( args )... ); + } + + template< typename U, typename... Args + nsel_REQUIRES_T( + std::is_constructible<T, std::initializer_list<U>, Args&&...>::value + ) + > + nsel_constexpr14 explicit expected( future_std_lite_in_place_t(T), std::initializer_list<U> il, Args&&... args ) + : contained( true ) + { + contained.emplace_value( il, std::forward<Args>( args )... ); + } + + // in-place construction, error + + template< typename... Args + nsel_REQUIRES_T( + std::is_constructible<E, Args&&...>::value + ) + > + nsel_constexpr14 explicit expected( unexpect_t, Args&&... args ) + : contained( false ) + { + contained.emplace_error( std::forward<Args>( args )... ); + } + + template< typename U, typename... Args + nsel_REQUIRES_T( + std::is_constructible<E, std::initializer_list<U>, Args&&...>::value + ) + > + nsel_constexpr14 explicit expected( unexpect_t, std::initializer_list<U> il, Args&&... args ) + : contained( false ) + { + contained.emplace_error( il, std::forward<Args>( args )... ); + } + + // x.x.4.2 destructor + + // TODO: ~expected: triviality + // Effects: If T is not cv void and is_trivially_destructible_v<T> is false and bool(*this), calls val.~T(). If is_trivially_destructible_v<E> is false and !bool(*this), calls unexpect.~unexpected<E>(). + // Remarks: If either T is cv void or is_trivially_destructible_v<T> is true, and is_trivially_destructible_v<E> is true, then this destructor shall be a trivial destructor. + + ~expected() + { + if ( has_value() ) contained.destruct_value(); + else contained.destruct_error(); + } + + // x.x.4.3 assignment + + expected & operator=( expected const & other ) + { + expected( other ).swap( *this ); + return *this; + } + + expected & operator=( expected && other ) noexcept + ( + std::is_nothrow_move_constructible< T>::value + && std::is_nothrow_move_assignable< T>::value + && std::is_nothrow_move_constructible<E>::value // added for missing + && std::is_nothrow_move_assignable< E>::value ) // nothrow above + { + expected( std::move( other ) ).swap( *this ); + return *this; + } + + template< typename U + nsel_REQUIRES_T( + !std::is_same<expected<T,E>, typename std20::remove_cvref<U>::type>::value + && std17::conjunction<std::is_scalar<T>, std::is_same<T, std::decay<U>> >::value + && std::is_constructible<T ,U>::value + && std::is_assignable< T&,U>::value + && std::is_nothrow_move_constructible<E>::value ) + > + expected & operator=( U && value ) + { + expected( std::forward<U>( value ) ).swap( *this ); + return *this; + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_constructible<E, G const&>::value && + std::is_copy_constructible<G>::value // TODO: std::is_nothrow_copy_constructible<G> + && std::is_copy_assignable<G>::value + ) + > + expected & operator=( future_std::unexpected_type<G> const & error ) + { + expected( unexpect, error.value() ).swap( *this ); + return *this; + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_constructible<E, G&&>::value && + std::is_move_constructible<G>::value // TODO: std::is_nothrow_move_constructible<G> + && std::is_move_assignable<G>::value + ) + > + expected & operator=( future_std::unexpected_type<G> && error ) + { + expected( unexpect, std::move( error.value() ) ).swap( *this ); + return *this; + } + + template< typename... Args + nsel_REQUIRES_T( + std::is_nothrow_constructible<T, Args&&...>::value + ) + > + value_type & emplace( Args &&... args ) + { + expected( future_std_lite_in_place(T), std::forward<Args>(args)... ).swap( *this ); + return value(); + } + + template< typename U, typename... Args + nsel_REQUIRES_T( + std::is_nothrow_constructible<T, std::initializer_list<U>&, Args&&...>::value + ) + > + value_type & emplace( std::initializer_list<U> il, Args &&... args ) + { + expected( future_std_lite_in_place(T), il, std::forward<Args>(args)... ).swap( *this ); + return value(); + } + + // x.x.4.4 swap + + template< typename U=T, typename G=E > + nsel_REQUIRES_R( void, + std17::is_swappable< U>::value + && std17::is_swappable<G>::value + && ( std::is_move_constructible<U>::value || std::is_move_constructible<G>::value ) + ) + swap( expected & other ) noexcept + ( + std::is_nothrow_move_constructible<T>::value && std17::is_nothrow_swappable<T&>::value && + std::is_nothrow_move_constructible<E>::value && std17::is_nothrow_swappable<E&>::value + ) + { + using std::swap; + + if ( bool(*this) && bool(other) ) { swap( contained.value(), other.contained.value() ); } + else if ( ! bool(*this) && ! bool(other) ) { swap( contained.error(), other.contained.error() ); } + else if ( bool(*this) && ! bool(other) ) { error_type t( std::move( other.error() ) ); + other.contained.destruct_error(); + other.contained.construct_value( std::move( contained.value() ) ); + contained.destruct_value(); + contained.construct_error( std::move( t ) ); + bool has_value = contained.has_value(); + bool other_has_value = other.has_value(); + other.contained.set_has_value(has_value); + contained.set_has_value(other_has_value); + } + else if ( ! bool(*this) && bool(other) ) { other.swap( *this ); } + } + + // x.x.4.5 observers + + constexpr value_type const * operator ->() const + { + return assert( has_value() ), contained.value_ptr(); + } + + value_type * operator ->() + { + return assert( has_value() ), contained.value_ptr(); + } + + constexpr value_type const & operator *() const & + { + return assert( has_value() ), contained.value(); + } + + value_type & operator *() & + { + return assert( has_value() ), contained.value(); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + + constexpr value_type const && operator *() const && + { + return std::move( ( assert( has_value() ), contained.value() ) ); + } + + nsel_constexpr14 value_type && operator *() && + { + return std::move( ( assert( has_value() ), contained.value() ) ); + } + +#endif + + constexpr explicit operator bool() const noexcept + { + return has_value(); + } + + constexpr bool has_value() const noexcept + { + return contained.has_value(); + } + + constexpr value_type const & value() const & + { + return has_value() + ? ( contained.value() ) + : ( error_traits<error_type>::rethrow( contained.error() ), contained.value() ); + } + + value_type & value() & + { + return has_value() + ? ( contained.value() ) + : ( error_traits<error_type>::rethrow( contained.error() ), contained.value() ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + + constexpr value_type const && value() const && + { + return std::move( has_value() + ? ( contained.value() ) + : ( error_traits<error_type>::rethrow( contained.error() ), contained.value() ) ); + } + + nsel_constexpr14 value_type && value() && + { + return std::move( has_value() + ? ( contained.value() ) + : ( error_traits<error_type>::rethrow( contained.error() ), contained.value() ) ); + } + +#endif + + constexpr error_type const & error() const & + { + return assert( ! has_value() ), contained.error(); + } + + error_type & error() & + { + return assert( ! has_value() ), contained.error(); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + + constexpr error_type const && error() const && + { + return std::move( ( assert( ! has_value() ), contained.error() ) ); + } + + error_type && error() && + { + return std::move( ( assert( ! has_value() ), contained.error() ) ); + } + +#endif + + constexpr unexpected_type get_unexpected() const + { + return make_unexpected( contained.error() ); + } + + template< typename Ex > + bool has_exception() const + { + using ContainedEx = typename std::remove_reference< decltype( get_unexpected().value() ) >::type; + return ! has_value() && std::is_base_of< Ex, ContainedEx>::value; + } + + template< typename U + nsel_REQUIRES_T( + std::is_copy_constructible< T>::value + && std::is_convertible<U&&, T>::value + ) + > + value_type value_or( U && v ) const & + { + return has_value() + ? contained.value() + : static_cast<T>( std::forward<U>( v ) ); + } + + template< typename U + nsel_REQUIRES_T( + std::is_move_constructible< T>::value + && std::is_convertible<U&&, T>::value + ) + > + value_type value_or( U && v ) && + { + return has_value() + ? std::move( contained.value() ) + : static_cast<T>( std::forward<U>( v ) ); + } + +#if nsel_P2505R >= 4 + template< typename G = E + nsel_REQUIRES_T( + std::is_copy_constructible< E >::value + && std::is_convertible< G, E >::value + ) + > + nsel_constexpr error_type error_or( G && e ) const & + { + return has_value() + ? static_cast< E >( std::forward< G >( e ) ) + : contained.error(); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_move_constructible< E >::value + && std::is_convertible< G, E >::value + ) + > + nsel_constexpr14 error_type error_or( G && e ) && + { + return has_value() + ? static_cast< E >( std::forward< G >( e ) ) + : std::move( contained.error() ); + } +#endif // nsel_P2505R >= 4 + +#if nsel_P2505R >= 3 + // Monadic operations (P2505) + template< typename F + nsel_REQUIRES_T( + detail::is_expected < detail::invoke_result_nocvref_t< F, value_type & > > ::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, value_type & >::error_type, error_type >::value + && std::is_constructible< error_type, error_type & >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F, value_type & > and_then( F && f ) & + { + return has_value() + ? detail::invoke_result_nocvref_t< F, value_type & >( detail::invoke( std::forward< F >( f ), value() ) ) + : detail::invoke_result_nocvref_t< F, value_type & >( unexpect, error() ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, const value_type & > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, const value_type & >::error_type, error_type >::value + && std::is_constructible< error_type, const error_type & >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F, const value_type & > and_then( F && f ) const & + { + return has_value() + ? detail::invoke_result_nocvref_t< F, const value_type & >( detail::invoke( std::forward< F >( f ), value() ) ) + : detail::invoke_result_nocvref_t< F, const value_type & >( unexpect, error() ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, value_type && > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, value_type && >::error_type, error_type >::value + && std::is_constructible< error_type, error_type && >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F, value_type && > and_then( F && f ) && + { + return has_value() + ? detail::invoke_result_nocvref_t< F, value_type && >( detail::invoke( std::forward< F >( f ), std::move( value() ) ) ) + : detail::invoke_result_nocvref_t< F, value_type && >( unexpect, std::move( error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, const value_type && > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, const value_type & >::error_type, error_type >::value + && std::is_constructible< error_type, const error_type && >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F, const value_type && > and_then( F && f ) const && + { + return has_value() + ? detail::invoke_result_nocvref_t< F, const value_type && >( detail::invoke( std::forward< F >( f ), std::move( value() ) ) ) + : detail::invoke_result_nocvref_t< F, const value_type && >( unexpect, std::move( error() ) ); + } +#endif + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, error_type & > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, error_type & >::value_type, value_type >::value + && std::is_constructible< value_type, value_type & >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F, error_type & > or_else( F && f ) & + { + return has_value() + ? detail::invoke_result_nocvref_t< F, error_type & >( value() ) + : detail::invoke_result_nocvref_t< F, error_type & >( detail::invoke( std::forward< F >( f ), error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, const error_type & > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, const error_type & >::value_type, value_type >::value + && std::is_constructible< value_type, const value_type & >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F, const error_type & > or_else( F && f ) const & + { + return has_value() + ? detail::invoke_result_nocvref_t< F, const error_type & >( value() ) + : detail::invoke_result_nocvref_t< F, const error_type & >( detail::invoke( std::forward< F >( f ), error() ) ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, error_type && > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, error_type && >::value_type, value_type >::value + && std::is_constructible< value_type, value_type && >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F, error_type && > or_else( F && f ) && + { + return has_value() + ? detail::invoke_result_nocvref_t< F, error_type && >( std::move( value() ) ) + : detail::invoke_result_nocvref_t< F, error_type && >( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, const error_type && > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F, const error_type && >::value_type, value_type >::value + && std::is_constructible< value_type, const value_type && >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F, const error_type && > or_else( F && f ) const && + { + return has_value() + ? detail::invoke_result_nocvref_t< F, const error_type && >( std::move( value() ) ) + : detail::invoke_result_nocvref_t< F, const error_type && >( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } +#endif + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type & >::value + && !std::is_void< detail::transform_invoke_result_t< F, value_type & > >::value + && detail::valid_expected_value_type< detail::transform_invoke_result_t< F, value_type & > >::value + ) + > + nsel_constexpr14 expected< detail::transform_invoke_result_t< F, value_type & >, error_type > transform( F && f ) & + { + return has_value() + ? expected< detail::transform_invoke_result_t< F, value_type & >, error_type >( detail::invoke( std::forward< F >( f ), **this ) ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type & >::value + && std::is_void< detail::transform_invoke_result_t< F, value_type & > >::value + ) + > + nsel_constexpr14 expected< void, error_type > transform( F && f ) & + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ), **this ), expected< void, error_type >() ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type & >::value + && !std::is_void< detail::transform_invoke_result_t< F, const value_type & > >::value + && detail::valid_expected_value_type< detail::transform_invoke_result_t< F, const value_type & > >::value + ) + > + nsel_constexpr expected< detail::transform_invoke_result_t< F, const value_type & >, error_type > transform( F && f ) const & + { + return has_value() + ? expected< detail::transform_invoke_result_t< F, const value_type & >, error_type >( detail::invoke( std::forward< F >( f ), **this ) ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type & >::value + && std::is_void< detail::transform_invoke_result_t< F, const value_type & > >::value + ) + > + nsel_constexpr expected< void, error_type > transform( F && f ) const & + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ), **this ), expected< void, error_type >() ) + : make_unexpected( error() ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type && >::value + && !std::is_void< detail::transform_invoke_result_t< F, value_type && > >::value + && detail::valid_expected_value_type< detail::transform_invoke_result_t< F, value_type && > >::value + ) + > + nsel_constexpr14 expected< detail::transform_invoke_result_t< F, value_type && >, error_type > transform( F && f ) && + { + return has_value() + ? expected< detail::transform_invoke_result_t< F, value_type && >, error_type >( detail::invoke( std::forward< F >( f ), std::move( **this ) ) ) + : make_unexpected( std::move( error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type && >::value + && std::is_void< detail::transform_invoke_result_t< F, value_type && > >::value + ) + > + nsel_constexpr14 expected< void, error_type > transform( F && f ) && + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ), **this ), expected< void, error_type >() ) + : make_unexpected( std::move( error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type && >::value + && !std::is_void< detail::transform_invoke_result_t< F, const value_type && > >::value + && detail::valid_expected_value_type< detail::transform_invoke_result_t< F, const value_type && > >::value + ) + > + nsel_constexpr expected< detail::transform_invoke_result_t< F, const value_type && >, error_type > transform( F && f ) const && + { + return has_value() + ? expected< detail::transform_invoke_result_t< F, const value_type && >, error_type >( detail::invoke( std::forward< F >( f ), std::move( **this ) ) ) + : make_unexpected( std::move( error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type && >::value + && std::is_void< detail::transform_invoke_result_t< F, const value_type && > >::value + ) + > + nsel_constexpr expected< void, error_type > transform( F && f ) const && + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ), **this ), expected< void, error_type >() ) + : make_unexpected( std::move( error() ) ); + } +#endif + + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, error_type & > >::value + && std::is_constructible< value_type, value_type & >::value + ) + > + nsel_constexpr14 expected< value_type, detail::transform_invoke_result_t< F, error_type & > > transform_error( F && f ) & + { + return has_value() + ? expected< value_type, detail::transform_invoke_result_t< F, error_type & > >( in_place, **this ) + : make_unexpected( detail::invoke( std::forward< F >( f ), error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, const error_type & > >::value + && std::is_constructible< value_type, const value_type & >::value + ) + > + nsel_constexpr expected< value_type, detail::transform_invoke_result_t< F, const error_type & > > transform_error( F && f ) const & + { + return has_value() + ? expected< value_type, detail::transform_invoke_result_t< F, const error_type & > >( in_place, **this ) + : make_unexpected( detail::invoke( std::forward< F >( f ), error() ) ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, error_type && > >::value + && std::is_constructible< value_type, value_type && >::value + ) + > + nsel_constexpr14 expected< value_type, detail::transform_invoke_result_t< F, error_type && > > transform_error( F && f ) && + { + return has_value() + ? expected< value_type, detail::transform_invoke_result_t< F, error_type && > >( in_place, std::move( **this ) ) + : make_unexpected( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, const error_type && > >::value + && std::is_constructible< value_type, const value_type && >::value + ) + > + nsel_constexpr expected< value_type, detail::transform_invoke_result_t< F, const error_type && > > transform_error( F && f ) const && + { + return has_value() + ? expected< value_type, detail::transform_invoke_result_t< F, const error_type && > >( in_place, std::move( **this ) ) + : make_unexpected( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } +#endif +#endif // nsel_P2505R >= 3 + // unwrap() + +// template <class U, class E> +// constexpr expected<U,E> expected<expected<U,E>,E>::unwrap() const&; + +// template <class T, class E> +// constexpr expected<T,E> expected<T,E>::unwrap() const&; + +// template <class U, class E> +// expected<U,E> expected<expected<U,E>, E>::unwrap() &&; + +// template <class T, class E> +// template expected<T,E> expected<T,E>::unwrap() &&; + + // factories + +// template< typename Ex, typename F> +// expected<T,E> catch_exception(F&& f); + +// template< typename F> +// expected<decltype(func(declval<T>())),E> map(F&& func) ; + +// template< typename F> +// 'see below' bind(F&& func); + +// template< typename F> +// expected<T,E> catch_error(F&& f); + +// template< typename F> +// 'see below' then(F&& func); + +private: + detail::storage_t + < + T + ,E + , std::is_copy_constructible<T>::value && std::is_copy_constructible<E>::value + , std::is_move_constructible<T>::value && std::is_move_constructible<E>::value + > + contained; +}; + +/// class expected, void specialization + +template< typename E > +class expected<void, E> +{ +private: + template< typename, typename > friend class expected; + +public: + using value_type = void; + using error_type = E; + using unexpected_type = future_std::unexpected_type<E>; + + // x.x.4.1 constructors + + constexpr expected() noexcept + : contained( true ) + {} + + nsel_constexpr14 expected( expected const & other ) = default; + nsel_constexpr14 expected( expected && other ) = default; + + constexpr explicit expected( future_std_lite_in_place_t(void) ) + : contained( true ) + {} + + template< typename G = E + nsel_REQUIRES_T( + !std::is_convertible<G const &, E>::value /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( future_std::unexpected_type<G> const & error ) + : contained( false ) + { + contained.construct_error( E{ error.value() } ); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_convertible<G const &, E>::value /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( future_std::unexpected_type<G> const & error ) + : contained( false ) + { + contained.construct_error( error.value() ); + } + + template< typename G = E + nsel_REQUIRES_T( + !std::is_convertible<G&&, E>::value /*=> explicit */ + ) + > + nsel_constexpr14 explicit expected( future_std::unexpected_type<G> && error ) + : contained( false ) + { + contained.construct_error( E{ std::move( error.value() ) } ); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_convertible<G&&, E>::value /*=> non-explicit */ + ) + > + nsel_constexpr14 /*non-explicit*/ expected( future_std::unexpected_type<G> && error ) + : contained( false ) + { + contained.construct_error( std::move( error.value() ) ); + } + + template< typename... Args + nsel_REQUIRES_T( + std::is_constructible<E, Args&&...>::value + ) + > + nsel_constexpr14 explicit expected( unexpect_t, Args&&... args ) + : contained( false ) + { + contained.emplace_error( std::forward<Args>( args )... ); + } + + template< typename U, typename... Args + nsel_REQUIRES_T( + std::is_constructible<E, std::initializer_list<U>, Args&&...>::value + ) + > + nsel_constexpr14 explicit expected( unexpect_t, std::initializer_list<U> il, Args&&... args ) + : contained( false ) + { + contained.emplace_error( il, std::forward<Args>( args )... ); + } + + // destructor + + ~expected() + { + if ( ! has_value() ) + { + contained.destruct_error(); + } + } + + // x.x.4.3 assignment + + expected & operator=( expected const & other ) + { + expected( other ).swap( *this ); + return *this; + } + + expected & operator=( expected && other ) noexcept + ( + std::is_nothrow_move_assignable<E>::value && + std::is_nothrow_move_constructible<E>::value ) + { + expected( std::move( other ) ).swap( *this ); + return *this; + } + + void emplace() + { + expected().swap( *this ); + } + + // x.x.4.4 swap + + template< typename G = E > + nsel_REQUIRES_R( void, + std17::is_swappable<G>::value + && std::is_move_constructible<G>::value + ) + swap( expected & other ) noexcept + ( + std::is_nothrow_move_constructible<E>::value && std17::is_nothrow_swappable<E&>::value + ) + { + using std::swap; + + if ( ! bool(*this) && ! bool(other) ) { swap( contained.error(), other.contained.error() ); } + else if ( bool(*this) && ! bool(other) ) { contained.construct_error( std::move( other.error() ) ); + bool has_value = contained.has_value(); + bool other_has_value = other.has_value(); + other.contained.set_has_value(has_value); + contained.set_has_value(other_has_value); + } + else if ( ! bool(*this) && bool(other) ) { other.swap( *this ); } + } + + // x.x.4.5 observers + + constexpr explicit operator bool() const noexcept + { + return has_value(); + } + + constexpr bool has_value() const noexcept + { + return contained.has_value(); + } + + void value() const + { + if ( ! has_value() ) + { + error_traits<error_type>::rethrow( contained.error() ); + } + } + + constexpr error_type const & error() const & + { + return assert( ! has_value() ), contained.error(); + } + + error_type & error() & + { + return assert( ! has_value() ), contained.error(); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + + constexpr error_type const && error() const && + { + return std::move( ( assert( ! has_value() ), contained.error() ) ); + } + + error_type && error() && + { + return std::move( ( assert( ! has_value() ), contained.error() ) ); + } + +#endif + + constexpr unexpected_type get_unexpected() const + { + return make_unexpected( contained.error() ); + } + + template< typename Ex > + bool has_exception() const + { + using ContainedEx = typename std::remove_reference< decltype( get_unexpected().value() ) >::type; + return ! has_value() && std::is_base_of< Ex, ContainedEx>::value; + } + +#if nsel_P2505R >= 4 + template< typename G = E + nsel_REQUIRES_T( + std::is_copy_constructible< E >::value + && std::is_convertible< G, E >::value + ) + > + nsel_constexpr error_type error_or( G && e ) const & + { + return has_value() + ? static_cast< E >( std::forward< G >( e ) ) + : contained.error(); + } + + template< typename G = E + nsel_REQUIRES_T( + std::is_move_constructible< E >::value + && std::is_convertible< G, E >::value + ) + > + nsel_constexpr14 error_type error_or( G && e ) && + { + return has_value() + ? static_cast< E >( std::forward< G >( e ) ) + : std::move( contained.error() ); + } +#endif // nsel_P2505R >= 4 + +#if nsel_P2505R >= 3 + // Monadic operations (P2505) + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F >::error_type, error_type >::value + && std::is_constructible< error_type, error_type & >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F > and_then( F && f ) & + { + return has_value() + ? detail::invoke_result_nocvref_t< F >( detail::invoke( std::forward< F >( f ) ) ) + : detail::invoke_result_nocvref_t< F >( unexpect, error() ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F >::error_type, error_type >::value + && std::is_constructible< error_type, const error_type & >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F > and_then( F && f ) const & + { + return has_value() + ? detail::invoke_result_nocvref_t< F >( detail::invoke( std::forward< F >( f ) ) ) + : detail::invoke_result_nocvref_t< F >( unexpect, error() ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F >::error_type, error_type >::value + && std::is_constructible< error_type, error_type && >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F > and_then( F && f ) && + { + return has_value() + ? detail::invoke_result_nocvref_t< F >( detail::invoke( std::forward< F >( f ) ) ) + : detail::invoke_result_nocvref_t< F >( unexpect, std::move( error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F > >::value + && std::is_same< typename detail::invoke_result_nocvref_t< F >::error_type, error_type >::value + && std::is_constructible< error_type, const error_type && >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F > and_then( F && f ) const && + { + return has_value() + ? detail::invoke_result_nocvref_t< F >( detail::invoke( std::forward< F >( f ) ) ) + : detail::invoke_result_nocvref_t< F >( unexpect, std::move( error() ) ); + } +#endif + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, error_type & > >::value + && std::is_void< typename detail::invoke_result_nocvref_t< F, error_type & >::value_type >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F, error_type & > or_else( F && f ) & + { + return has_value() + ? detail::invoke_result_nocvref_t< F, error_type & >() + : detail::invoke_result_nocvref_t< F, error_type & >( detail::invoke( std::forward< F >( f ), error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, const error_type & > >::value + && std::is_void< typename detail::invoke_result_nocvref_t< F, const error_type & >::value_type >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F, const error_type & > or_else( F && f ) const & + { + return has_value() + ? detail::invoke_result_nocvref_t< F, const error_type & >() + : detail::invoke_result_nocvref_t< F, const error_type & >( detail::invoke( std::forward< F >( f ), error() ) ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, error_type && > >::value + && std::is_void< typename detail::invoke_result_nocvref_t< F, error_type && >::value_type >::value + ) + > + nsel_constexpr14 detail::invoke_result_nocvref_t< F, error_type && > or_else( F && f ) && + { + return has_value() + ? detail::invoke_result_nocvref_t< F, error_type && >() + : detail::invoke_result_nocvref_t< F, error_type && >( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::is_expected< detail::invoke_result_nocvref_t< F, const error_type && > >::value + && std::is_void< typename detail::invoke_result_nocvref_t< F, const error_type && >::value_type >::value + ) + > + nsel_constexpr detail::invoke_result_nocvref_t< F, const error_type && > or_else( F && f ) const && + { + return has_value() + ? detail::invoke_result_nocvref_t< F, const error_type && >() + : detail::invoke_result_nocvref_t< F, const error_type && >( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } +#endif + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type & >::value + && !std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr14 expected< detail::transform_invoke_result_t< F >, error_type > transform( F && f ) & + { + return has_value() + ? expected< detail::transform_invoke_result_t< F >, error_type >( detail::invoke( std::forward< F >( f ) ) ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type & >::value + && std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr14 expected< void, error_type > transform( F && f ) & + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ) ), expected< void, error_type >() ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type & >::value + && !std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr expected< detail::transform_invoke_result_t< F >, error_type > transform( F && f ) const & + { + return has_value() + ? expected< detail::transform_invoke_result_t< F >, error_type >( detail::invoke( std::forward< F >( f ) ) ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type & >::value + && std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr expected< void, error_type > transform( F && f ) const & + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ) ), expected< void, error_type >() ) + : make_unexpected( error() ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type && >::value + && !std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr14 expected< detail::transform_invoke_result_t< F >, error_type > transform( F && f ) && + { + return has_value() + ? expected< detail::transform_invoke_result_t< F >, error_type >( detail::invoke( std::forward< F >( f ) ) ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, error_type && >::value + && std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr14 expected< void, error_type > transform( F && f ) && + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ) ), expected< void, error_type >() ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type && >::value + && !std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr expected< detail::transform_invoke_result_t< F >, error_type > transform( F && f ) const && + { + return has_value() + ? expected< detail::transform_invoke_result_t< F >, error_type >( detail::invoke( std::forward< F >( f ) ) ) + : make_unexpected( error() ); + } + + template<typename F + nsel_REQUIRES_T( + std::is_constructible< error_type, const error_type && >::value + && std::is_void< detail::transform_invoke_result_t< F > >::value + ) + > + nsel_constexpr expected< void, error_type > transform( F && f ) const && + { + return has_value() + ? ( detail::invoke( std::forward< F >( f ) ), expected< void, error_type >() ) + : make_unexpected( error() ); + } +#endif + + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, error_type & > >::value + ) + > + nsel_constexpr14 expected< void, detail::transform_invoke_result_t< F, error_type & > > transform_error( F && f ) & + { + return has_value() + ? expected< void, detail::transform_invoke_result_t< F, error_type & > >() + : make_unexpected( detail::invoke( std::forward< F >( f ), error() ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, const error_type & > >::value + ) + > + nsel_constexpr expected< void, detail::transform_invoke_result_t< F, const error_type & > > transform_error( F && f ) const & + { + return has_value() + ? expected< void, detail::transform_invoke_result_t< F, const error_type & > >() + : make_unexpected( detail::invoke( std::forward< F >( f ), error() ) ); + } + +#if !nsel_COMPILER_GNUC_VERSION || nsel_COMPILER_GNUC_VERSION >= 490 + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, error_type && > >::value + ) + > + nsel_constexpr14 expected< void, detail::transform_invoke_result_t< F, error_type && > > transform_error( F && f ) && + { + return has_value() + ? expected< void, detail::transform_invoke_result_t< F, error_type && > >() + : make_unexpected( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } + + template<typename F + nsel_REQUIRES_T( + detail::valid_unexpected_type< detail::transform_invoke_result_t< F, const error_type && > >::value + ) + > + nsel_constexpr expected< void, detail::transform_invoke_result_t< F, const error_type && > > transform_error( F && f ) const && + { + return has_value() + ? expected< void, detail::transform_invoke_result_t< F, const error_type && > >() + : make_unexpected( detail::invoke( std::forward< F >( f ), std::move( error() ) ) ); + } +#endif +#endif // nsel_P2505R >= 3 + +// template constexpr 'see below' unwrap() const&; +// +// template 'see below' unwrap() &&; + + // factories + +// template< typename Ex, typename F> +// expected<void,E> catch_exception(F&& f); +// +// template< typename F> +// expected<decltype(func()), E> map(F&& func) ; +// +// template< typename F> +// 'see below' bind(F&& func) ; +// +// template< typename F> +// expected<void,E> catch_error(F&& f); +// +// template< typename F> +// 'see below' then(F&& func); + +private: + detail::storage_t + < + void + , E + , std::is_copy_constructible<E>::value + , std::is_move_constructible<E>::value + > + contained; +}; + +// x.x.4.6 expected<>: comparison operators + +template< typename T1, typename E1, typename T2, typename E2 + nsel_REQUIRES_T( + !std::is_void<T1>::value && !std::is_void<T2>::value + ) +> +constexpr bool operator==( expected<T1,E1> const & x, expected<T2,E2> const & y ) +{ + return bool(x) != bool(y) ? false : bool(x) ? *x == *y : x.error() == y.error(); +} + +template< typename T1, typename E1, typename T2, typename E2 + nsel_REQUIRES_T( + std::is_void<T1>::value && std::is_void<T2>::value + ) +> +constexpr bool operator==( expected<T1,E1> const & x, expected<T2,E2> const & y ) +{ + return bool(x) != bool(y) ? false : bool(x) || static_cast<bool>( x.error() == y.error() ); +} + +template< typename T1, typename E1, typename T2, typename E2 > +constexpr bool operator!=( expected<T1,E1> const & x, expected<T2,E2> const & y ) +{ + return !(x == y); +} + +#if nsel_P0323R <= 2 + +template< typename T, typename E > +constexpr bool operator<( expected<T,E> const & x, expected<T,E> const & y ) +{ + return (!y) ? false : (!x) ? true : *x < *y; +} + +template< typename T, typename E > +constexpr bool operator>( expected<T,E> const & x, expected<T,E> const & y ) +{ + return (y < x); +} + +template< typename T, typename E > +constexpr bool operator<=( expected<T,E> const & x, expected<T,E> const & y ) +{ + return !(y < x); +} + +template< typename T, typename E > +constexpr bool operator>=( expected<T,E> const & x, expected<T,E> const & y ) +{ + return !(x < y); +} + +#endif + +// x.x.4.7 expected: comparison with T + +template< typename T1, typename E1, typename T2 + nsel_REQUIRES_T( + !std::is_void<T1>::value + ) +> +constexpr bool operator==( expected<T1,E1> const & x, T2 const & v ) +{ + return bool(x) ? *x == v : false; +} + +template< typename T1, typename E1, typename T2 + nsel_REQUIRES_T( + !std::is_void<T1>::value + ) +> +constexpr bool operator==(T2 const & v, expected<T1,E1> const & x ) +{ + return bool(x) ? v == *x : false; +} + +template< typename T1, typename E1, typename T2 > +constexpr bool operator!=( expected<T1,E1> const & x, T2 const & v ) +{ + return bool(x) ? *x != v : true; +} + +template< typename T1, typename E1, typename T2 > +constexpr bool operator!=( T2 const & v, expected<T1,E1> const & x ) +{ + return bool(x) ? v != *x : true; +} + +#if nsel_P0323R <= 2 + +template< typename T, typename E > +constexpr bool operator<( expected<T,E> const & x, T const & v ) +{ + return bool(x) ? *x < v : true; +} + +template< typename T, typename E > +constexpr bool operator<( T const & v, expected<T,E> const & x ) +{ + return bool(x) ? v < *x : false; +} + +template< typename T, typename E > +constexpr bool operator>( T const & v, expected<T,E> const & x ) +{ + return bool(x) ? *x < v : false; +} + +template< typename T, typename E > +constexpr bool operator>( expected<T,E> const & x, T const & v ) +{ + return bool(x) ? v < *x : false; +} + +template< typename T, typename E > +constexpr bool operator<=( T const & v, expected<T,E> const & x ) +{ + return bool(x) ? ! ( *x < v ) : false; +} + +template< typename T, typename E > +constexpr bool operator<=( expected<T,E> const & x, T const & v ) +{ + return bool(x) ? ! ( v < *x ) : true; +} + +template< typename T, typename E > +constexpr bool operator>=( expected<T,E> const & x, T const & v ) +{ + return bool(x) ? ! ( *x < v ) : false; +} + +template< typename T, typename E > +constexpr bool operator>=( T const & v, expected<T,E> const & x ) +{ + return bool(x) ? ! ( v < *x ) : true; +} + +#endif // nsel_P0323R + +// x.x.4.8 expected: comparison with unexpected_type + +template< typename T1, typename E1 , typename E2 > +constexpr bool operator==( expected<T1,E1> const & x, unexpected_type<E2> const & u ) +{ + return (!x) ? x.get_unexpected() == u : false; +} + +template< typename T1, typename E1 , typename E2 > +constexpr bool operator==( unexpected_type<E2> const & u, expected<T1,E1> const & x ) +{ + return ( x == u ); +} + +template< typename T1, typename E1 , typename E2 > +constexpr bool operator!=( expected<T1,E1> const & x, unexpected_type<E2> const & u ) +{ + return ! ( x == u ); +} + +template< typename T1, typename E1 , typename E2 > +constexpr bool operator!=( unexpected_type<E2> const & u, expected<T1,E1> const & x ) +{ + return ! ( x == u ); +} + +#if nsel_P0323R <= 2 + +template< typename T, typename E > +constexpr bool operator<( expected<T,E> const & x, unexpected_type<E> const & u ) +{ + return (!x) ? ( x.get_unexpected() < u ) : false; +} + +template< typename T, typename E > +constexpr bool operator<( unexpected_type<E> const & u, expected<T,E> const & x ) +{ + return (!x) ? ( u < x.get_unexpected() ) : true ; +} + +template< typename T, typename E > +constexpr bool operator>( expected<T,E> const & x, unexpected_type<E> const & u ) +{ + return ( u < x ); +} + +template< typename T, typename E > +constexpr bool operator>( unexpected_type<E> const & u, expected<T,E> const & x ) +{ + return ( x < u ); +} + +template< typename T, typename E > +constexpr bool operator<=( expected<T,E> const & x, unexpected_type<E> const & u ) +{ + return ! ( u < x ); +} + +template< typename T, typename E > +constexpr bool operator<=( unexpected_type<E> const & u, expected<T,E> const & x) +{ + return ! ( x < u ); +} + +template< typename T, typename E > +constexpr bool operator>=( expected<T,E> const & x, unexpected_type<E> const & u ) +{ + return ! ( u > x ); +} + +template< typename T, typename E > +constexpr bool operator>=( unexpected_type<E> const & u, expected<T,E> const & x ) +{ + return ! ( x > u ); +} + +#endif // nsel_P0323R + +/// x.x.x Specialized algorithms + +template< typename T, typename E + nsel_REQUIRES_T( + ( std::is_void<T>::value || std::is_move_constructible<T>::value ) + && std::is_move_constructible<E>::value + && std17::is_swappable<T>::value + && std17::is_swappable<E>::value ) +> +void swap( expected<T,E> & x, expected<T,E> & y ) noexcept ( noexcept ( x.swap(y) ) ) +{ + x.swap( y ); +} + +#if nsel_P0323R <= 3 + +template< typename T > +constexpr auto make_expected( T && v ) -> expected< typename std::decay<T>::type > +{ + return expected< typename std::decay<T>::type >( std::forward<T>( v ) ); +} + +// expected<void> specialization: + +auto inline make_expected() -> expected<void> +{ + return expected<void>( in_place ); +} + +template< typename T > +constexpr auto make_expected_from_current_exception() -> expected<T> +{ + return expected<T>( make_unexpected_from_current_exception() ); +} + +template< typename T > +auto make_expected_from_exception( std::exception_ptr v ) -> expected<T> +{ + return expected<T>( unexpected_type<std::exception_ptr>( std::forward<std::exception_ptr>( v ) ) ); +} + +template< typename T, typename E > +constexpr auto make_expected_from_error( E e ) -> expected<T, typename std::decay<E>::type> +{ + return expected<T, typename std::decay<E>::type>( make_unexpected( e ) ); +} + +template< typename F + nsel_REQUIRES_T( ! std::is_same<typename std::result_of<F()>::type, void>::value ) +> +/*nsel_constexpr14*/ +auto make_expected_from_call( F f ) -> expected< typename std::result_of<F()>::type > +{ + try + { + return make_expected( f() ); + } + catch (...) + { + return make_unexpected_from_current_exception(); + } +} + +template< typename F + nsel_REQUIRES_T( std::is_same<typename std::result_of<F()>::type, void>::value ) +> +/*nsel_constexpr14*/ +auto make_expected_from_call( F f ) -> expected<void> +{ + try + { + f(); + return make_expected(); + } + catch (...) + { + return make_unexpected_from_current_exception(); + } +} + +#endif // nsel_P0323R + +} // namespace expected_lite + +using namespace expected_lite; + +// using expected_lite::expected; +// using ... + +} // namespace future_std + +namespace std { + +// expected: hash support + +template< typename T, typename E > +struct hash< future_std::expected<T,E> > +{ + using result_type = std::size_t; + using argument_type = future_std::expected<T,E>; + + constexpr result_type operator()(argument_type const & arg) const + { + return arg ? std::hash<T>{}(*arg) : result_type{}; + } +}; + +// TBD - ?? remove? see spec. +template< typename T, typename E > +struct hash< future_std::expected<T&,E> > +{ + using result_type = std::size_t; + using argument_type = future_std::expected<T&,E>; + + constexpr result_type operator()(argument_type const & arg) const + { + return arg ? std::hash<T>{}(*arg) : result_type{}; + } +}; + +// TBD - implement +// bool(e), hash<expected<void,E>>()(e) shall evaluate to the hashing true; +// otherwise it evaluates to an unspecified value if E is exception_ptr or +// a combination of hashing false and hash<E>()(e.error()). + +template< typename E > +struct hash< future_std::expected<void,E> > +{ +}; + +} // namespace std + +namespace future_std { + +// void unexpected() is deprecated && removed in C++17 + +#if nsel_CPP17_OR_GREATER || nsel_COMPILER_MSVC_VERSION > 141 +template< typename E > +using unexpected = unexpected_type<E>; +#endif + +} // namespace future_std + +#undef nsel_REQUIRES +#undef nsel_REQUIRES_0 +#undef nsel_REQUIRES_T + +nsel_RESTORE_WARNINGS() + +#endif // nsel_USES_STD_EXPECTED + +#endif // AIDGE_CORE_UTILS_FUTURE_STD_EXPECTED_H_ diff --git a/include/aidge/utilsParsing/AstNode.hpp b/include/aidge/utilsParsing/AstNode.hpp index 1158ae148a22993476adb00ecbf8ebd24101830c..bf4f73236fb65b88da309e71ba55997b5342df41 100644 --- a/include/aidge/utilsParsing/AstNode.hpp +++ b/include/aidge/utilsParsing/AstNode.hpp @@ -1,7 +1,7 @@ -#ifndef _AIDGE_AST_NODE_H_ -#define _AIDGE_AST_NODE_H_ +#ifndef AIDGE_CORE_AST_NODE_H_ +#define AIDGE_CORE_AST_NODE_H_ #include <string> #include <type_traits> @@ -12,11 +12,11 @@ namespace Aidge{ template <typename EnumType> - class AstNode: public std::enable_shared_from_this<AstNode> + class AstNode: public std::enable_shared_from_this<AstNode<EnumType>> { static_assert(std::is_enum<EnumType>::value, "AstNode EnumType must be an enum type"); public: - AstNode(std::shared_ptr<ParsingToken<EnumType>> token,std::vector<std::shared_ptr<AstNode>> child ={}):mToken(token),mChild(child){} + AstNode(std::shared_ptr<ParsingToken<EnumType>> token,std::vector<std::shared_ptr<AstNode<EnumType>>> child ={}):mToken(token),mChild(child){} /** * @brief get the type of the token * @return the type @@ -41,7 +41,7 @@ namespace Aidge{ } /** * @brief test if the node is a leaf in the tree - * @return true if a leaf + * @return true if a leaf */ bool isLeaf() const { return mChild.size() == 0; @@ -66,4 +66,4 @@ namespace Aidge{ }; } -#endif //_AIDGE_AST_NODE_H_ +#endif //AIDGE_CORE_AST_NODE_H_ diff --git a/include/aidge/utilsParsing/ParsingToken.hpp b/include/aidge/utilsParsing/ParsingToken.hpp index 78045cf3085a18bfd0565354fd34aef02ef395bd..e303a5eabe6f7710873468f8edc8f3e844f4175f 100644 --- a/include/aidge/utilsParsing/ParsingToken.hpp +++ b/include/aidge/utilsParsing/ParsingToken.hpp @@ -1,13 +1,15 @@ -#ifndef _AIDGE_PARSING_TOKEN_H_ -#define _AIDGE_PARSING_TOKEN_H_ +#ifndef AIDGE_CORE_PARSING_TOKEN_H_ +#define AIDGE_CORE_PARSING_TOKEN_H_ #include <string> #include <type_traits> +#include <sstream> // Include the necessary header namespace Aidge{ + template <typename EnumType> - class ParsingToken: public std::enable_shared_from_this<ParsingToken> + class ParsingToken: public std::enable_shared_from_this<ParsingToken<EnumType>> { static_assert(std::is_enum<EnumType>::value, "ParsingToken EnumType must be an enum type"); public: @@ -16,11 +18,11 @@ namespace Aidge{ * @param type one of the token type * @param lexeme String representing aditional information of the token */ - ParsingToken(const EnumType type , const std::string lexeme )mLexeme(lexeme),mType(type){} + ParsingToken(const EnumType type , const std::string lexeme ):mLexeme(lexeme),mType(type){} /** * @brief get the lexeme - * @return std::string + * @return std::string */ const std::string getLexeme(void){ return mLexeme; @@ -28,8 +30,8 @@ namespace Aidge{ /** * @brief get the token type - * - * @return ParsingToken + * + * @return ParsingToken */ const EnumType getType(void){ return mType; @@ -39,7 +41,10 @@ namespace Aidge{ * @brief copy the token * @return deep copy of the token */ - std::shared_ptr<Aidge::ParsingToken> copy(); + std::shared_ptr<ParsingToken> copy(){ + auto newToken = std::make_shared<ParsingToken<EnumType>>(mType,mLexeme); + return newToken; + } //TODO std::ostringstream rep(void){ @@ -47,6 +52,7 @@ namespace Aidge{ out << " Token (" << mLexeme <<")" << "\n"; return out; } + private: /** @@ -63,4 +69,4 @@ namespace Aidge{ }; } -#endif //_AIDGE_PARSING_TOKEN_H_ \ No newline at end of file +#endif //AIDGE_CORE_PARSING_TOKEN_H_ diff --git a/python_binding/operator/pybind_Add.cpp b/python_binding/operator/pybind_Add.cpp index ab8b4cf7b91d5eea2db5245a8c5122ab004b4766..0b2323c5cfb660415ec3ae009beaa7aa78afca0b 100644 --- a/python_binding/operator/pybind_Add.cpp +++ b/python_binding/operator/pybind_Add.cpp @@ -20,7 +20,9 @@ namespace py = pybind11; namespace Aidge { template <std::size_t NUM> void declare_Add(py::module &m) { - py::class_<Add_Op<NUM>, std::shared_ptr<Add_Op<NUM>>, Operator>(m, "Add_Op", py::multiple_inheritance()); + py::class_<Add_Op<NUM>, std::shared_ptr<Add_Op<NUM>>, Operator>(m, "AddOp", py::multiple_inheritance()) + .def("get_inputs_name", &Add_Op<NUM>::getInputsName) + .def("get_outputs_name", &Add_Op<NUM>::getOutputsName); m.def("Add", &Add<NUM>, py::arg("name") = ""); } diff --git a/python_binding/operator/pybind_AvgPooling.cpp b/python_binding/operator/pybind_AvgPooling.cpp index 372afebdd3e1626cd0af88e335b78ec7fd73a5f4..fe67fcb7a26f6ea1f05577b47444df5cb271110a 100644 --- a/python_binding/operator/pybind_AvgPooling.cpp +++ b/python_binding/operator/pybind_AvgPooling.cpp @@ -8,7 +8,7 @@ * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ -#ifdef PYBIND + #include <pybind11/pybind11.h> #include <pybind11/stl.h> @@ -30,48 +30,23 @@ template <DimIdx_t DIM> void declare_AvgPoolingOp(py::module &m) { m, ("AvgPoolingOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) .def(py::init<const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, (DIM<<1)> &>(), + const std::array<DimSize_t, DIM> &>(), py::arg("kernel_dims"), - py::arg("stride_dims"), - py::arg("padding_dims")); - - m.def(("AvgPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + py::arg("stride_dims")) + .def("get_inputs_name", &AvgPooling_Op<DIM>::getInputsName) + .def("get_outputs_name", &AvgPooling_Op<DIM>::getOutputsName); + + m.def(("AvgPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, const std::string& name, - const std::vector<DimSize_t> &stride_dims, - const std::vector<DimSize_t> &padding_dims) { - // Lambda function wrapper because PyBind fails to convert const array. - // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. - if (kernel_dims.size() != DIM) { - throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (stride_dims.size() != DIM) { - throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (padding_dims.size() != (DIM<<1)) { - throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); - } - DimSize_t tmp_kernel_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_kernel_dims_array[i] = kernel_dims[i]; - } - DimSize_t tmp_stride_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_stride_dims_array[i] = stride_dims[i]; - } - DimSize_t tmp_padding_dims_array[DIM<<1]; - for (size_t i = 0; i < (DIM<<1); ++i) { - tmp_padding_dims_array[i] = padding_dims[i]; - } - const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; - const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; - const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; - return AvgPooling<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array)); + const std::vector<DimSize_t> &stride_dims) { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + + return AvgPooling<DIM>(to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin())); }, py::arg("kernel_dims"), py::arg("name") = "", - py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), - py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0)); - + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1)); + } @@ -79,10 +54,9 @@ void init_AvgPooling(py::module &m) { declare_AvgPoolingOp<1>(m); declare_AvgPoolingOp<2>(m); declare_AvgPoolingOp<3>(m); - + // FIXME: // m.def("AvgPooling1D", static_cast<NodeAPI(*)(const char*, int, int, int const // (&)[1])>(&AvgPooling)); } } // namespace Aidge -#endif \ No newline at end of file diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp index f43381fecc689a292e166c4da40ea0cb4842c9e6..cabaa2edd7053718160fa5013492d1914ee4cf16 100644 --- a/python_binding/operator/pybind_BatchNorm.cpp +++ b/python_binding/operator/pybind_BatchNorm.cpp @@ -21,7 +21,9 @@ namespace Aidge { template <DimSize_t DIM> void declare_BatchNormOp(py::module& m) { - py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, Operator, Attributes>(m, ("BatchNorm_Op" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()); + py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, Operator, Attributes>(m, ("BatchNormOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) + .def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName) + .def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName); m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); } diff --git a/python_binding/operator/pybind_Conv.cpp b/python_binding/operator/pybind_Conv.cpp index 6807520c57696de4ae74c414a71843a499fabc03..f4f7946c6ecc180f83e4bf58eee16102752f0c6e 100644 --- a/python_binding/operator/pybind_Conv.cpp +++ b/python_binding/operator/pybind_Conv.cpp @@ -32,65 +32,33 @@ template <DimIdx_t DIM> void declare_ConvOp(py::module &m) { DimSize_t, const std::array<DimSize_t, DIM> &, const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, (DIM<<1)> &, const std::array<DimSize_t, DIM> &>(), py::arg("in_channels"), py::arg("out_channels"), py::arg("kernel_dims"), py::arg("stride_dims"), - py::arg("padding_dims"), - py::arg("dilation_dims")); + py::arg("dilation_dims")) + .def("get_inputs_name", &Conv_Op<DIM>::getInputsName) + .def("get_outputs_name", &Conv_Op<DIM>::getOutputsName) + ; m.def(("Conv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels, DimSize_t out_channels, const std::vector<DimSize_t>& kernel_dims, const std::string& name, const std::vector<DimSize_t> &stride_dims, - const std::vector<DimSize_t> &padding_dims, const std::vector<DimSize_t> &dilation_dims) { - // Lambda function wrapper because PyBind fails to convert const array. - // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. - if (kernel_dims.size() != DIM) { - throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (stride_dims.size() != DIM) { - throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (padding_dims.size() != (DIM<<1)) { - throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); - } - if (dilation_dims.size() != DIM) { - throw std::runtime_error("dilation_dims size [" + std::to_string(dilation_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - DimSize_t tmp_kernel_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_kernel_dims_array[i] = kernel_dims[i]; - } - DimSize_t tmp_stride_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_stride_dims_array[i] = stride_dims[i]; - } - DimSize_t tmp_padding_dims_array[DIM<<1]; - for (size_t i = 0; i < (DIM<<1); ++i) { - tmp_padding_dims_array[i] = padding_dims[i]; - } - DimSize_t tmp_dilation_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_dilation_dims_array[i] = dilation_dims[i]; - } - const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; - const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; - const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; - const DimSize_t (&dilation_dims_array)[DIM] = tmp_dilation_dims_array; - return Conv<DIM>(in_channels, out_channels, to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array), to_array(dilation_dims_array)); + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + AIDGE_ASSERT(dilation_dims.size() == DIM, "dilation_dims size [%ld] does not match DIM [%d]", dilation_dims.size(), DIM); + + return Conv<DIM>(in_channels, out_channels, to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin()), to_array<DIM>(dilation_dims.begin())); }, py::arg("in_channels"), py::arg("out_channels"), py::arg("kernel_dims"), py::arg("name") = "", py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), - py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0), py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); - } diff --git a/python_binding/operator/pybind_ConvDepthWise.cpp b/python_binding/operator/pybind_ConvDepthWise.cpp index 3f48c50f7ffdb44450c0e2a155d85dcbf9f73fd9..4745ef345264763f1a890d566235be072c8e50d8 100644 --- a/python_binding/operator/pybind_ConvDepthWise.cpp +++ b/python_binding/operator/pybind_ConvDepthWise.cpp @@ -31,59 +31,27 @@ template <DimIdx_t DIM> void declare_ConvDepthWiseOp(py::module &m) { py::multiple_inheritance()) .def(py::init<const std::array<DimSize_t, DIM> &, const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, (DIM<<1)> &, const std::array<DimSize_t, DIM> &>(), py::arg("kernel_dims"), py::arg("stride_dims"), - py::arg("padding_dims"), - py::arg("dilation_dims")); - - m.def(("ConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + py::arg("dilation_dims")) + .def("get_inputs_name", &ConvDepthWise_Op<DIM>::getInputsName) + .def("get_outputs_name", &ConvDepthWise_Op<DIM>::getOutputsName); + + m.def(("ConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, const std::string& name, const std::vector<DimSize_t> &stride_dims, - const std::vector<DimSize_t> &padding_dims, const std::vector<DimSize_t> &dilation_dims) { - // Lambda function wrapper because PyBind fails to convert const array. - // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. - if (kernel_dims.size() != DIM) { - throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (stride_dims.size() != DIM) { - throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (padding_dims.size() != (DIM<<1)) { - throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); - } - if (dilation_dims.size() != DIM) { - throw std::runtime_error("dilation_dims size [" + std::to_string(dilation_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - DimSize_t tmp_kernel_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_kernel_dims_array[i] = kernel_dims[i]; - } - DimSize_t tmp_stride_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_stride_dims_array[i] = stride_dims[i]; - } - DimSize_t tmp_padding_dims_array[DIM<<1]; - for (size_t i = 0; i < (DIM<<1); ++i) { - tmp_padding_dims_array[i] = padding_dims[i]; - } - DimSize_t tmp_dilation_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_dilation_dims_array[i] = dilation_dims[i]; - } - const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; - const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; - const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; - const DimSize_t (&dilation_dims_array)[DIM] = tmp_dilation_dims_array; - return ConvDepthWise<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array), to_array(dilation_dims_array)); + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + AIDGE_ASSERT(dilation_dims.size() == DIM, "dilation_dims size [%ld] does not match DIM [%d]", dilation_dims.size(), DIM); + + return ConvDepthWise<DIM>(to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin()), to_array<DIM>(dilation_dims.begin())); }, py::arg("kernel_dims"), py::arg("name") = "", py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), - py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0), py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); - + } @@ -91,7 +59,7 @@ void init_ConvDepthWise(py::module &m) { declare_ConvDepthWiseOp<1>(m); declare_ConvDepthWiseOp<2>(m); declare_ConvDepthWiseOp<3>(m); - + // FIXME: // m.def("ConvDepthWise1D", static_cast<NodeAPI(*)(const char*, int, int, int const // (&)[1])>(&ConvDepthWise)); diff --git a/python_binding/operator/pybind_FC.cpp b/python_binding/operator/pybind_FC.cpp index 4b9d61d082ebed4d426b41efa071d3943f83d231..c6a1c70000e3e6d604a6652716667efa1c18e956 100644 --- a/python_binding/operator/pybind_FC.cpp +++ b/python_binding/operator/pybind_FC.cpp @@ -20,7 +20,9 @@ namespace py = pybind11; namespace Aidge { void declare_FC(py::module &m) { - py::class_<FC_Op, std::shared_ptr<FC_Op>, Operator, Attributes>(m, "FC_Op", py::multiple_inheritance()); + py::class_<FC_Op, std::shared_ptr<FC_Op>, Operator, Attributes>(m, "FCOp", py::multiple_inheritance()) + .def("get_inputs_name", &FC_Op::getInputsName) + .def("get_outputs_name", &FC_Op::getOutputsName); m.def("FC", &FC, py::arg("out_channels"), py::arg("nobias") = false, py::arg("name") = ""); } diff --git a/python_binding/operator/pybind_LeakyReLU.cpp b/python_binding/operator/pybind_LeakyReLU.cpp index cae8a88bab7b59189dfbc6528cd653f1c97cb73a..af7689f0e64dd4ca8f798dcb34ea968972ace464 100644 --- a/python_binding/operator/pybind_LeakyReLU.cpp +++ b/python_binding/operator/pybind_LeakyReLU.cpp @@ -18,7 +18,9 @@ namespace py = pybind11; namespace Aidge { void init_LeakyReLU(py::module& m) { - py::class_<LeakyReLU_Op, std::shared_ptr<LeakyReLU_Op>, Operator, Attributes>(m, "LeakyReLU_Op", py::multiple_inheritance()); + py::class_<LeakyReLU_Op, std::shared_ptr<LeakyReLU_Op>, Operator, Attributes>(m, "LeakyReLUOp", py::multiple_inheritance()) + .def("get_inputs_name", &LeakyReLU_Op::getInputsName) + .def("get_outputs_name", &LeakyReLU_Op::getOutputsName); m.def("LeakyReLU", &LeakyReLU, py::arg("negative_slope") = 0.0f, py::arg("name") = ""); } diff --git a/python_binding/operator/pybind_Matmul.cpp b/python_binding/operator/pybind_Matmul.cpp index 2f738550041bcdb1ae809d68fa24fdf5a72e9164..fdb51b24a87ce358c1e7808873ebc569ca2227c8 100644 --- a/python_binding/operator/pybind_Matmul.cpp +++ b/python_binding/operator/pybind_Matmul.cpp @@ -20,7 +20,9 @@ namespace py = pybind11; namespace Aidge { void declare_MatMul(py::module &m) { - py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, Operator, Attributes>(m, "MatMul_Op", py::multiple_inheritance()); + py::class_<MatMul_Op, std::shared_ptr<MatMul_Op>, Operator, Attributes>(m, "MatMulOp", py::multiple_inheritance()) + .def("get_inputs_name", &MatMul_Op::getInputsName) + .def("get_outputs_name", &MatMul_Op::getOutputsName); m.def("MatMul", &MatMul, py::arg("out_channels"), py::arg("name") = ""); } diff --git a/python_binding/operator/pybind_MaxPooling.cpp b/python_binding/operator/pybind_MaxPooling.cpp index 2efd18c816c2d588e574872b3d3776a3409dc4ba..c83dfaa3639f05af345bd9214460f95fd661cd31 100644 --- a/python_binding/operator/pybind_MaxPooling.cpp +++ b/python_binding/operator/pybind_MaxPooling.cpp @@ -8,7 +8,7 @@ * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ -#ifdef PYBIND + #include <pybind11/pybind11.h> #include <pybind11/stl.h> @@ -30,48 +30,23 @@ template <DimIdx_t DIM> void declare_MaxPoolingOp(py::module &m) { m, ("MaxPoolingOp" + std::to_string(DIM) + "D").c_str(), py::multiple_inheritance()) .def(py::init<const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, DIM> &, - const std::array<DimSize_t, (DIM<<1)> &>(), + const std::array<DimSize_t, DIM> &>(), py::arg("kernel_dims"), - py::arg("stride_dims"), - py::arg("padding_dims")); - - m.def(("MaxPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + py::arg("stride_dims")) + .def("get_inputs_name", &MaxPooling_Op<DIM>::getInputsName) + .def("get_outputs_name", &MaxPooling_Op<DIM>::getOutputsName); + + m.def(("MaxPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, const std::string& name, - const std::vector<DimSize_t> &stride_dims, - const std::vector<DimSize_t> &padding_dims) { - // Lambda function wrapper because PyBind fails to convert const array. - // So we use a vector that we convert in this function to a const DimeSize_t [DIM] array. - if (kernel_dims.size() != DIM) { - throw std::runtime_error("kernel_dims size [" + std::to_string(kernel_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (stride_dims.size() != DIM) { - throw std::runtime_error("stride_dims size [" + std::to_string(stride_dims.size()) + "] does not match DIM [" + std::to_string(DIM) +"]"); - } - if (padding_dims.size() != (DIM<<1)) { - throw std::runtime_error("padding_dims size [" + std::to_string(padding_dims.size()) + "] does not match DIM [" + std::to_string(DIM<<1) +"]"); - } - DimSize_t tmp_kernel_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_kernel_dims_array[i] = kernel_dims[i]; - } - DimSize_t tmp_stride_dims_array[DIM]; - for (size_t i = 0; i < DIM; ++i) { - tmp_stride_dims_array[i] = stride_dims[i]; - } - DimSize_t tmp_padding_dims_array[DIM<<1]; - for (size_t i = 0; i < (DIM<<1); ++i) { - tmp_padding_dims_array[i] = padding_dims[i]; - } - const DimSize_t (&kernel_dims_array)[DIM] = tmp_kernel_dims_array; - const DimSize_t (&stride_dims_array)[DIM] = tmp_stride_dims_array; - const DimSize_t (&padding_dims_array)[DIM<<1] = tmp_padding_dims_array; - return MaxPooling<DIM>(to_array(kernel_dims_array), name, to_array(stride_dims_array), to_array(padding_dims_array)); + const std::vector<DimSize_t> &stride_dims) { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + + return MaxPooling<DIM>(to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin())); }, py::arg("kernel_dims"), py::arg("name") = "", - py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), - py::arg("padding_dims") = std::vector<DimSize_t>(DIM<<1,0)); - + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1)); + } @@ -79,10 +54,9 @@ void init_MaxPooling(py::module &m) { declare_MaxPoolingOp<1>(m); declare_MaxPoolingOp<2>(m); declare_MaxPoolingOp<3>(m); - + // FIXME: // m.def("MaxPooling1D", static_cast<NodeAPI(*)(const char*, int, int, int const // (&)[1])>(&MaxPooling)); } } // namespace Aidge -#endif \ No newline at end of file diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3372d50e14be9e0d24ba5d9171766255ab49f23b --- /dev/null +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -0,0 +1,126 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> + +#include <string> +#include <vector> +#include <array> + +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" +#include "aidge/operator/Operator.hpp" +#include "aidge/utils/Types.h" + +namespace py = pybind11; +namespace Aidge { + +template <DimIdx_t DIM> void declare_PaddedConvOp(py::module &m) { + m.def(("PaddedConv" + std::to_string(DIM) + "D").c_str(), [](DimSize_t in_channels, + DimSize_t out_channels, + const std::vector<DimSize_t>& kernel_dims, + const std::string& name, + const std::vector<DimSize_t> &stride_dims, + const std::vector<DimSize_t> &padding_dims, + const std::vector<DimSize_t> &dilation_dims) + { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + AIDGE_ASSERT(padding_dims.size() == 2*DIM, "padding_dims size [%ld] does not match DIM [%d]", padding_dims.size(), 2*DIM); + AIDGE_ASSERT(dilation_dims.size() == DIM, "dilation_dims size [%ld] does not match DIM [%d]", dilation_dims.size(), DIM); + + return PaddedConv<DIM>(in_channels, out_channels, to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin()), to_array<2*DIM>(padding_dims.begin()), to_array<DIM>(dilation_dims.begin())); + }, py::arg("in_channels"), + py::arg("out_channels"), + py::arg("kernel_dims"), + py::arg("name") = "", + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), + py::arg("padding_dims") = std::vector<DimSize_t>(2*DIM,0), + py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); +} + +template <DimIdx_t DIM> void declare_PaddedConvDepthWiseOp(py::module &m) { + m.def(("PaddedConvDepthWise" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + const std::string& name, + const std::vector<DimSize_t> &stride_dims, + const std::vector<DimSize_t> &padding_dims, + const std::vector<DimSize_t> &dilation_dims) + { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + AIDGE_ASSERT(padding_dims.size() == 2*DIM, "padding_dims size [%ld] does not match DIM [%d]", padding_dims.size(), 2*DIM); + AIDGE_ASSERT(dilation_dims.size() == DIM, "dilation_dims size [%ld] does not match DIM [%d]", dilation_dims.size(), DIM); + + return PaddedConvDepthWise<DIM>(to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin()), to_array<2*DIM>(padding_dims.begin()), to_array<DIM>(dilation_dims.begin())); + }, py::arg("kernel_dims"), + py::arg("name") = "", + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), + py::arg("padding_dims") = std::vector<DimSize_t>(2*DIM,0), + py::arg("dilation_dims") = std::vector<DimSize_t>(DIM,1)); + +} + +template <DimIdx_t DIM> void declare_PaddedAvgPoolingOp(py::module &m) { + m.def(("PaddedAvgPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + const std::string& name, + const std::vector<DimSize_t> &stride_dims, + const std::vector<DimSize_t> &padding_dims) + { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + AIDGE_ASSERT(padding_dims.size() == 2*DIM, "padding_dims size [%ld] does not match DIM [%d]", padding_dims.size(), 2*DIM); + + return PaddedAvgPooling<DIM>(to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin()), to_array<2*DIM>(padding_dims.begin())); + }, py::arg("kernel_dims"), + py::arg("name") = "", + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), + py::arg("padding_dims") = std::vector<DimSize_t>(2*DIM,0)); + +} + +template <DimIdx_t DIM> void declare_PaddedMaxPoolingOp(py::module &m) { + m.def(("PaddedMaxPooling" + std::to_string(DIM) + "D").c_str(), [](const std::vector<DimSize_t>& kernel_dims, + const std::string& name, + const std::vector<DimSize_t> &stride_dims, + const std::vector<DimSize_t> &padding_dims) + { + AIDGE_ASSERT(kernel_dims.size() == DIM, "kernel_dims size [%ld] does not match DIM [%d]", kernel_dims.size(), DIM); + AIDGE_ASSERT(stride_dims.size() == DIM, "stride_dims size [%ld] does not match DIM [%d]", stride_dims.size(), DIM); + AIDGE_ASSERT(padding_dims.size() == 2*DIM, "padding_dims size [%ld] does not match DIM [%d]", padding_dims.size(), 2*DIM); + + return PaddedMaxPooling<DIM>(to_array<DIM>(kernel_dims.begin()), name, to_array<DIM>(stride_dims.begin()), to_array<2*DIM>(padding_dims.begin())); + }, py::arg("kernel_dims"), + py::arg("name") = "", + py::arg("stride_dims") = std::vector<DimSize_t>(DIM,1), + py::arg("padding_dims") = std::vector<DimSize_t>(2*DIM,0)); + +} + +void init_MetaOperatorDefs(py::module &m) { + declare_PaddedConvOp<1>(m); + declare_PaddedConvOp<2>(m); + declare_PaddedConvOp<3>(m); + declare_PaddedConvDepthWiseOp<1>(m); + declare_PaddedConvDepthWiseOp<2>(m); + declare_PaddedConvDepthWiseOp<3>(m); + declare_PaddedAvgPoolingOp<1>(m); + declare_PaddedAvgPoolingOp<2>(m); + declare_PaddedAvgPoolingOp<3>(m); + declare_PaddedMaxPoolingOp<1>(m); + declare_PaddedMaxPoolingOp<2>(m); + declare_PaddedMaxPoolingOp<3>(m); + + // FIXME: + // m.def("Conv1D", static_cast<NodeAPI(*)(const char*, int, int, int const + // (&)[1])>(&Conv)); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index b0866970c57798f11fe2efff6777e11af53dd37e..6b535e8cf3293b26aaa64f95ca2f9a394768935f 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -20,6 +20,7 @@ void init_Operator(py::module& m){ py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") .def("output", &Operator::output, py::arg("outputIdx")) .def("input", &Operator::input, py::arg("inputIdx")) + .def("nb_data_inputs", &Operator::nbDataInputs) .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDatatype, py::arg("datatype")) .def("set_backend", &Operator::setBackend, py::arg("name")) diff --git a/python_binding/operator/pybind_Producer.cpp b/python_binding/operator/pybind_Producer.cpp index 1c62cd0adf6b8712073ec0674754ce7c8c2014a5..107b7ba00e4077d9f7c215257bf7fd46629481c1 100644 --- a/python_binding/operator/pybind_Producer.cpp +++ b/python_binding/operator/pybind_Producer.cpp @@ -35,7 +35,9 @@ void init_Producer(py::module &m) { "ProducerOp", py::multiple_inheritance()) .def("dims", &Producer_Op::dims) - .def("set_output_tensor", &Producer_Op::setOutputTensor); + .def("set_output_tensor", &Producer_Op::setOutputTensor) + .def("get_inputs_name", &Producer_Op::getInputsName) + .def("get_outputs_name", &Producer_Op::getOutputsName); m.def("Producer", static_cast<std::shared_ptr<Node>(*)(const std::shared_ptr<Tensor>, const std::string&)>(&Producer), py::arg("tensor"), py::arg("name") = ""); declare_Producer<1>(m); diff --git a/python_binding/operator/pybind_ReLU.cpp b/python_binding/operator/pybind_ReLU.cpp index 820589d76507b39ca65ac2397614aabd1221fe3e..dbcb483e8089373bc8599c2d09fed00049e2a2ac 100644 --- a/python_binding/operator/pybind_ReLU.cpp +++ b/python_binding/operator/pybind_ReLU.cpp @@ -18,7 +18,9 @@ namespace py = pybind11; namespace Aidge { void init_ReLU(py::module& m) { - py::class_<ReLU_Op, std::shared_ptr<ReLU_Op>, Operator>(m, "ReLU_Op", py::multiple_inheritance()); + py::class_<ReLU_Op, std::shared_ptr<ReLU_Op>, Operator>(m, "ReLUOp", py::multiple_inheritance()) + .def("get_inputs_name", &ReLU_Op::getInputsName) + .def("get_outputs_name", &ReLU_Op::getOutputsName); m.def("ReLU", &ReLU, py::arg("name") = ""); } diff --git a/python_binding/operator/pybind_Softmax.cpp b/python_binding/operator/pybind_Softmax.cpp index 72ac1107181c1d7e2f578e31a965636dbb5c111b..8e50ab7c83bf43285b357cb803c0ce3eb42f4cc7 100644 --- a/python_binding/operator/pybind_Softmax.cpp +++ b/python_binding/operator/pybind_Softmax.cpp @@ -19,7 +19,9 @@ namespace py = pybind11; namespace Aidge { void init_Softmax(py::module& m) { - py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator>(m, "Softmax_Op", py::multiple_inheritance()); + py::class_<Softmax_Op, std::shared_ptr<Softmax_Op>, Operator>(m, "SoftmaxOp", py::multiple_inheritance()) + .def("get_inputs_name", &Softmax_Op::getInputsName) + .def("get_outputs_name", &Softmax_Op::getOutputsName); m.def("Softmax", &Softmax, py::arg("name") = ""); } diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 8bd2c51d25561957165dae36b166dee34f2b16e2..04e39b11e58718dfcc5f9faef24b140132367700 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -30,6 +30,7 @@ void init_GenericOperator(py::module&); void init_LeakyReLU(py::module&); void init_MatMul(py::module&); void init_MaxPooling(py::module&); +void init_MetaOperatorDefs(py::module&); void init_Producer(py::module&); void init_ReLU(py::module&); void init_Softmax(py::module&); @@ -71,6 +72,7 @@ void init_Aidge(py::module& m){ init_LeakyReLU(m); init_MatMul(m); init_MaxPooling(m); + init_MetaOperatorDefs(m); init_ReLU(m); init_Softmax(m); diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..24ce15ab7ead32f98c7ac3edcd34bb2010ff4326 --- /dev/null +++ b/requirements.txt @@ -0,0 +1 @@ +numpy diff --git a/setup.py b/setup.py index 16305afdfdfa5de2e328460d9e96c77eb96a9d98..60807df560510ad4cfacfdd2b178aca957306439 100644 --- a/setup.py +++ b/setup.py @@ -66,11 +66,13 @@ class CMakeBuild(build_ext): # used to launch setup.py to setup PythonInterp param_py = "-DPYTHON_EXECUTABLE=" + sys.executable + compile_type = 'Debug' install_path = os.path.join(sys.prefix, "lib", "libAidge") if "AIDGE_INSTALL" not in os.environ else os.environ["AIDGE_INSTALL"] - self.spawn(['cmake', str(cwd), param_py, '-DTEST=OFF', f'-DCMAKE_INSTALL_PREFIX:PATH={install_path}']) + self.spawn(['cmake', str(cwd), param_py, '-DTEST=OFF', f'-DCMAKE_INSTALL_PREFIX:PATH={install_path}', f'-DCMAKE_BUILD_TYPE={compile_type}']) if not self.dry_run: - self.spawn(['make', 'all', 'install', '-j', max_jobs]) + self.spawn(['cmake', '--build', '.', '--config', compile_type, '-j', max_jobs]) + self.spawn(['cmake', '--install', '.', '--config', compile_type]) os.chdir(str(cwd)) aidge_package = build_lib / (get_project_name()) @@ -81,7 +83,7 @@ class CMakeBuild(build_ext): # Copy all shared object files from build_temp/lib to aidge_package for root, _, files in os.walk(build_temp.absolute()): for file in files: - if file.endswith('.so') and (root != str(aidge_package.absolute())): + if (file.endswith('.so') or file.endswith('.pyd')) and (root != str(aidge_package.absolute())): currentFile=os.path.join(root, file) shutil.copy(currentFile, str(aidge_package.absolute())) @@ -100,7 +102,6 @@ if __name__ == '__main__': long_description_content_type="text/markdown", long_description="\n".join(DOCLINES[2:]), classifiers=[c for c in CLASSIFIERS.split('\n') if c], - platforms=["Linux"], packages=find_packages(where="."), include_package_data=True, ext_modules=[CMakeExtension(get_project_name())], diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 03b2a9adb439eb00d0ba59a13fead4f25d617b36..8f8f51c89bbcc380963f355f781e8fda940dcffc 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -125,21 +125,17 @@ Aidge::IOIndex_t Aidge::GraphView::getNbFreeDataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::dataInputs() const { - IOIndex_t nbDataIn = 0U; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - nbDataIn += inputNode->nbDataInputs(); - } - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbDataIn); - nbDataIn = 0U; + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; + for (const std::shared_ptr<Node>& inputNode : mInputNodes) { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->dataInputs(); - std::move(inputNodeinputs.begin(), inputNodeinputs.end(), - res.begin() + nbDataIn); - nbDataIn += inputNode->nbDataInputs(); - // res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode -> - // inputs()).end()); + + for (const auto& input : inputNodeinputs) { + if (mNodes.find(input.first) == mNodes.end()) { + res.push_back(input); + } + } } return res; } @@ -147,21 +143,17 @@ Aidge::GraphView::dataInputs() const { std::vector<std::pair<std::shared_ptr<Aidge::Node>, Aidge::IOIndex_t>> Aidge::GraphView::inputs() const { - std::size_t nbIn = 0U; - for (const std::shared_ptr<Node>& inputNode : mInputNodes) { - nbIn += inputNode->nbInputs(); - } - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res = - std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>(nbIn); - nbIn = 0U; + std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> res; + for (const std::shared_ptr<Node>& inputNode : mInputNodes) { std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = inputNode->inputs(); - std::move(inputNodeinputs.begin(), inputNodeinputs.end(), - res.begin() + nbIn); - nbIn += inputNode->nbInputs(); - // res.insert(res.end(), (inputNode -> inputs()).begin(), (inputNode -> - // inputs()).end()); + + for (const auto& input : inputNodeinputs) { + if (mNodes.find(input.first) == mNodes.end()) { + res.push_back(input); + } + } } return res; } diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 54fdac808642f3ae603e237737e265ba394fccbd..e6a53c871f5312c68f40dc5c9a2777729470298b 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -341,6 +341,35 @@ Aidge::NodePtr Aidge::Node::clone() const { return std::make_shared<Node>(mOperator->clone(), mName); } + +std::set<Aidge::NodePtr> Aidge::Node::getNodeDelta(int delta,std::set<Aidge::NodePtr> nodeSee){ + + std::set<Aidge::NodePtr> out; + nodeSee.insert(shared_from_this()); + + if(delta == 0) { + out.insert(shared_from_this()); + + }else if (delta > 0){ + for (const NodePtr& node : getChildren()) { + if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance + for (const NodePtr& ch : node->getNodeDelta(delta-1,nodeSee)){ + out.insert(ch); + } + } + } + }else{ + for (const NodePtr& node : getParents()) { + if(nodeSee.find(node) == nodeSee.end()){ //loop avoidance + for (const NodePtr& pr : node->getNodeDelta(delta+1,nodeSee)){ + out.insert(pr); + } + } + } + } + + return out; +} ///////////////////////////////////////////////////////////////////////////////////////////// // private diff --git a/src/graphRegex/GraphFsmInterpreter.cpp b/src/graphRegex/GraphFsmInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..2984ab4fb3864244c9e32dbfcda9ef2ae080acf0 --- /dev/null +++ b/src/graphRegex/GraphFsmInterpreter.cpp @@ -0,0 +1,182 @@ +#include "aidge/graphRegex/GraphFsmInterpreter.hpp" + +using namespace Aidge; + + +GraphFsmInterpreter::GraphFsmInterpreter(const std::string graphMatchExpr,std::map<std::string,std::shared_ptr<ConditionalInterpreter>> nodesCondition):mParser(graphMatchExpr){ + mActGroupe = 0; + mNodesCondition = nodesCondition; +} +std::shared_ptr<FsmGraph> GraphFsmInterpreter::interpret(void){ + mActGroupe = 0; + std::shared_ptr<AstNode<gRegexTokenTypes>> tree = mParser.parse(); + return visit(tree); +} +std::shared_ptr<FsmGraph> GraphFsmInterpreter::visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree){ + + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>> nextAstNodes = AstTree->getChilds(); + + if(AstTree->getType() == gRegexTokenTypes::SEP){ + return sepF(visit(nextAstNodes[0]),visit(nextAstNodes[1])); + }else if(AstTree->getType() == gRegexTokenTypes::NEXT){ + return nextF(visit(nextAstNodes[0]),visit(nextAstNodes[1])); + }else if(AstTree->getType() == gRegexTokenTypes::QOM){ + return qomF(visit(nextAstNodes[0])); + }else if(AstTree->getType() == gRegexTokenTypes::QZM){ + return qzmF(visit(nextAstNodes[0])); + }else if(AstTree->getType() == gRegexTokenTypes::KEY || AstTree->getType() == gRegexTokenTypes::CKEY){ + return keyF(AstTree); + }else if(AstTree->getType() == gRegexTokenTypes::LPAREN){ + mActGroupe += 1; + std::shared_ptr<FsmGraph> out = visit(nextAstNodes[0]); + mActGroupe -= 1; + return out; + }else{ + throw std::logic_error("visit Bad token type" ); + } +} + + + + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::keyF(std::shared_ptr<AstNode<gRegexTokenTypes>> AstNode){ + + + std::shared_ptr<FsmNode> start = std::make_shared<FsmNode>(false,true); + std::shared_ptr<FsmNode> valid = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + std::shared_ptr<FsmEdge> edge; + + + if(AstNode->getType() == gRegexTokenTypes::CKEY){ + edge = FsmEdgeFactory::make(start,valid,FsmEdgeTypes::COMMON,mNodesCondition,AstNode->getValue()); + }else if (AstNode->getType() == gRegexTokenTypes::KEY) + { + edge = FsmEdgeFactory::make(start,valid,FsmEdgeTypes::UNIQUE,mNodesCondition,AstNode->getValue()); + }else{ + + throw std::logic_error("keyF Bad in AST" ); + } + + graph->addEdge(edge); + graph->setGroupe(mActGroupe); + return graph; +} + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::sepF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm){ + + size_t idxLeft = leftFsm->getNbSubFsm(); + rigthFsm->incOrigineAllNodeBy(idxLeft); + leftFsm->unionG(rigthFsm); + //the rigthFsm is no longer usfull + return leftFsm; +} + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::nextF(std::shared_ptr<FsmGraph> leftFsm,std::shared_ptr<FsmGraph> rigthFsm){ + /* + combine the 2 Graph + all valid node of A are merge with Start B, Start B is un Start + update the relative reference + + A B + SA -> VA + SB -> VB + A B + SA -> q -> VB + */ + leftFsm->mergeOneStartOneValid(rigthFsm); + //the rigthFsm is no longer usfull + return leftFsm; +} + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::qomF(std::shared_ptr<FsmGraph> fsm){ + /* + + + valid node is connect to the child of Start with the same edge condition + A + S -> V + + A + S -> V + (E|R) + V -> S + */ + + std::vector<std::shared_ptr<FsmNode>> allStart = fsm->getStartNodes(); + std::set<std::shared_ptr<FsmNode>> allValid = fsm->getValidNodes(); + std::shared_ptr<FsmEdge> edge; + + if(allStart.size() != 1){ + throw std::logic_error("qomF Bad in AST" ); + } + + for(auto start : allStart ){ + for(auto edgeStart :start->getEdges() ){ + if (auto sharedEdge = edgeStart.lock()) { + + const std::map<size_t, int> commonRef = sharedEdge->getRelative(); + bool haveCommon = !commonRef.empty(); + + for(auto valid : allValid){ + if(haveCommon){ + /* + the // quantif case + get the go back and make a lexeme id(number) + we need to go back to the ref delta min #TODO + */ + bool hasMinRef = false; + std::pair<size_t, int> minRef; + for (const auto& entry : commonRef) { + if (!hasMinRef || std::abs(minRef.second) > std::abs(entry.second)) { + hasMinRef = true; + minRef = entry; + } + } + std::stringstream lexem; + lexem << "(" << minRef.first << ", " << minRef.second << ")"; + edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::REF,mNodesCondition, lexem.str()); + }else{ + /* + the sequensial quantif case + no reference to common + */ + edge = FsmEdgeFactory::make(valid,start,FsmEdgeTypes::EMPTY,mNodesCondition,""); + + } + fsm->addEdge(edge); + } + }else{ + throw std::runtime_error("edgeStart weak pointer is expired" ); + } + } + + } + return fsm; + +} + +std::shared_ptr<FsmGraph> GraphFsmInterpreter::qzmF(std::shared_ptr<FsmGraph> fsm){ + /* + qomf and a bypass empty start to valide + */ + fsm = qomF(fsm); + + std::vector<std::shared_ptr<FsmNode>> allStart = fsm->getStartNodes(); + std::set<std::shared_ptr<FsmNode>> allValid = fsm->getValidNodes(); + std::shared_ptr<FsmEdge> edge; + + if(allStart.size() != 1){ + throw std::logic_error("qzmF Bad in AST" ); + } + + for(auto start : allStart ){ + + for(auto valid : allValid){ + edge = FsmEdgeFactory::make(start,valid,FsmEdgeTypes::EMPTY,mNodesCondition,""); + fsm->addEdge(edge); + } + } + + return fsm; + + +} \ No newline at end of file diff --git a/src/graphRegex/GraphLexer.cpp b/src/graphRegex/GraphLexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..61214f96a090fef5d28cb0ce1a009644d9570880 --- /dev/null +++ b/src/graphRegex/GraphLexer.cpp @@ -0,0 +1,155 @@ + +#include "aidge/graphRegex/GraphLexer.hpp" + +using namespace Aidge; + + +GraphLexer::GraphLexer( const std::string gRegexExpressions ): +mRegularExpressions(gRegexExpressions){ + mPosition = 0; +} + +std::shared_ptr<ParsingToken<gRegexTokenTypes>> GraphLexer::getNextToken(void){ + std::string currentChars = ""; + while (mPosition < mRegularExpressions.length()) + { + //erase all space + if (mRegularExpressions[mPosition] != ' ') + { + currentChars += mRegularExpressions[mPosition]; + } + else + { + mPosition++; + continue; + } + + ///// + // const lent token + ///// + + if (std::regex_match(currentChars,std::regex("\\->")))// the next TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::NEXT,""); + } + else if (std::regex_match(currentChars,std::regex("\\*")))// the * TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::QZM,""); + } + else if (std::regex_match(currentChars,std::regex("\\+")))// the + TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::QOM,""); + } + else if (std::regex_match(currentChars,std::regex("\\(")))// the LPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::LPAREN,""); + } + else if (std::regex_match(currentChars,std::regex("\\)")))// the RPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::RPAREN,""); + } + + // + else if (std::regex_match(currentChars,std::regex(";")))// the SEP TOKEN + { + //test if the last sep + //std::string subStr = mRegularExpressions.substr(mPosition); + mPosition++; + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::SEP,""); + } + + ///// + //unconst lent token + ///// + + else if (std::regex_match(currentChars,std::regex("[A-Za-z_0-9]")))// the KEY or CKEY + { + + //read all the key + bool isCKey = false; + std::regex keyRegex("[A-Za-z_0-9]+"); + std::regex cKeyRegex("[A-Za-z_0-9]+\\#[0-9]*"); + + while ( mPosition < mRegularExpressions.length()) { + + if(!std::regex_match(currentChars,keyRegex) && !std::regex_match(currentChars,cKeyRegex)) + { + currentChars.pop_back(); //the last char is the problemes + break; + } + else if (std::regex_match(currentChars,cKeyRegex)){ + isCKey = true; + } + mPosition++; + if (mPosition < mRegularExpressions.length()) currentChars += mRegularExpressions[mPosition]; + + } + //we end the match 2 posibility + //we are at the end of the mConditionalExpressions and we need to ensure the match + //we are not we can continu + if (mPosition == mRegularExpressions.length()-1) + { + if (!std::regex_match(currentChars,keyRegex) && !std::regex_match(currentChars,cKeyRegex)) + { + throw badTokenError(currentChars,mPosition); + } + } + + + if (isCKey){ + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::CKEY,currentChars); + } else{ + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::KEY,currentChars); + } + } + + mPosition++; + } + + + //no more to find no one match the currentChars + if (currentChars.empty()) { + return std::make_shared<ParsingToken<gRegexTokenTypes>>(gRegexTokenTypes::STOP,""); // Null shared pointer ; + }else{ + throw badTokenError(currentChars,mPosition); + } + +} + +void GraphLexer::rstPosition(void){ + if (isEnd()){ + mPosition = 0; + }else{ + throw badTokenError("end rst",mPosition); + } +} + +bool GraphLexer::isEnd(void){ + return mPosition >= mRegularExpressions.length(); +} + +std::runtime_error GraphLexer::badTokenError(const std::string& currentChars,std::size_t position){ + std::ostringstream errorMessage; + errorMessage << "\nBad syntax " << currentChars << " :\n" << mRegularExpressions << "\n"; + for (std::size_t i = 0; i < position; i++) { + errorMessage << ' '; + } + errorMessage << "^\n"; + + return std::runtime_error(errorMessage.str()); +} + + const std::string GraphLexer::rep(){ + std::string out = mRegularExpressions; + out += "\n"; + for (std::size_t i = 0; i < mPosition; i++) { + out += ' '; + } + out += "^\n"; + return out ; + } \ No newline at end of file diff --git a/src/graphRegex/GraphParser.cpp b/src/graphRegex/GraphParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5aa653c482dae82c2e9fa02bfc36b2ffc821785f --- /dev/null +++ b/src/graphRegex/GraphParser.cpp @@ -0,0 +1,181 @@ +#include "aidge/graphRegex/GraphParser.hpp" + +using namespace Aidge; + +GraphParser::GraphParser(const std::string gRegexExpressions): +mLexer(gRegexExpressions) +{ + mCurrentToken = mLexer.getNextToken(); +} + + +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::parse(void){ + + std::shared_ptr<AstNode<gRegexTokenTypes>> astTree = constructAstAllExpr(); + rstParser(); + return astTree; +} + + +void GraphParser::rstParser(void){ + mLexer.rstPosition(); + mCurrentToken = mLexer.getNextToken(); +} + + +void GraphParser::ackToken(gRegexTokenTypes tokenType){ + + if(mCurrentToken->getType() == tokenType ){ + try { + mCurrentToken = mLexer.getNextToken(); + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "Graph Lexer error in Parser :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } + }else{ + std::ostringstream errorMessage; + errorMessage << "Bad syntax GraphParser " << static_cast<int>(mCurrentToken->getType()) <<"!="<< static_cast<int>(tokenType) << "\n"; + errorMessage << mLexer.rep(); + throw std::runtime_error(errorMessage.str()); + } +} + +/* +exp : KEY(QOM | QZM)? | CKEY | domain +*/ +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::constructAstExp(void) +{ + + try{ + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = mCurrentToken->copy(); + std::shared_ptr<AstNode<gRegexTokenTypes>> node = std::make_shared<AstNode<gRegexTokenTypes>>(token); + + if (mCurrentToken->getType() == gRegexTokenTypes::KEY ){ + ackToken(gRegexTokenTypes::KEY ); + if (mCurrentToken->getType() == gRegexTokenTypes::QOM ){ + token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::QOM ); + std::shared_ptr<AstNode<gRegexTokenTypes>> newNode = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{node}); + return newNode; + }else if (mCurrentToken->getType() == gRegexTokenTypes::QZM ){ + token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::QZM ); + std::shared_ptr<AstNode<gRegexTokenTypes>> newNode = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{node}); + return newNode; + } + return node; + }else if (mCurrentToken->getType() == gRegexTokenTypes::CKEY){ + ackToken(gRegexTokenTypes::CKEY ); + return node; + }else{ + return constructAstDomain(); + } + + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "GraphParser constructAstExp :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } +} + +/* +seq :exp (NEXT seq)* +*/ +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::constructAstSeq(void) +{ + + try{ + + std::shared_ptr<AstNode<gRegexTokenTypes>> left = constructAstExp(); + if(mCurrentToken->getType() == gRegexTokenTypes::NEXT ) + { + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::NEXT); + std::shared_ptr<AstNode<gRegexTokenTypes>> newNode = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{left,constructAstSeq()}); + left = newNode; + } + return left; + + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "GraphParser constructAstSeq :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } + +} + + +/* +LPAREN seq RPAREN (QOM | QZM) +*/ +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::constructAstDomain(void) +{ + + try{ + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token ; + std::shared_ptr<AstNode<gRegexTokenTypes>> node ; + + token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::LPAREN); + node = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{constructAstSeq()}); + ackToken(gRegexTokenTypes::RPAREN); + //(QOM | QZM) + + token = mCurrentToken->copy(); + if (mCurrentToken->getType() == gRegexTokenTypes::QOM){ + ackToken(gRegexTokenTypes::QOM); + node = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{node}); + }else if (mCurrentToken->getType() == gRegexTokenTypes::QZM){ + ackToken(gRegexTokenTypes::QZM); + node = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{node}); + }else{ + std::ostringstream errorMessage; + errorMessage << "Bad syntax constructAstDomain must have quantifier \n"; + throw std::runtime_error(errorMessage.str()); + } + + return node; + + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "GraphParser constructAstDomain :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } +} + +/* + allExpr: seq (SEP allExpr)* | STOP +*/ +std::shared_ptr<AstNode<gRegexTokenTypes>> GraphParser::constructAstAllExpr(void) +{ + + try{ + std::shared_ptr<AstNode<gRegexTokenTypes>> left = constructAstSeq(); + if(mCurrentToken->getType() == gRegexTokenTypes::SEP ) + { + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = mCurrentToken->copy(); + ackToken(gRegexTokenTypes::SEP); + + if(mCurrentToken->getType() == gRegexTokenTypes::STOP ) + { + return left; + } + std::shared_ptr<AstNode<gRegexTokenTypes>> newNode = std::make_shared<AstNode<gRegexTokenTypes>>(token, + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>>{left,constructAstAllExpr()}); + left = newNode; + } + return left; + + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "GraphParser constructAstDomain :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } +} diff --git a/src/graphRegex/GraphStrInterpreter.cpp b/src/graphRegex/GraphStrInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8ad24b5b9b0fee5fba34dd7397132bec2410fd23 --- /dev/null +++ b/src/graphRegex/GraphStrInterpreter.cpp @@ -0,0 +1,38 @@ +#include "aidge/graphRegex/GraphStrInterpreter.hpp" + +using namespace Aidge; + +GraphStrInterpreter::GraphStrInterpreter(const std::string graphMatchExpr):mParser(graphMatchExpr){ + mToTest = graphMatchExpr; + mToTest.erase(std::remove_if(mToTest.begin(), mToTest.end(), ::isspace), mToTest.end()); +} + + +std::string GraphStrInterpreter::visit(std::shared_ptr<AstNode<gRegexTokenTypes>> AstTree){ + + std::vector<std::shared_ptr<AstNode<gRegexTokenTypes>>> nextAstNodes = AstTree->getChilds(); + + if(AstTree->getType() == gRegexTokenTypes::SEP){ + return visit(nextAstNodes[0])+";"+visit(nextAstNodes[1]); + }else if(AstTree->getType() == gRegexTokenTypes::NEXT){ + return visit(nextAstNodes[0])+"->"+visit(nextAstNodes[1]); + }else if(AstTree->getType() == gRegexTokenTypes::QOM){ + return visit(nextAstNodes[0])+"+"; + }else if(AstTree->getType() == gRegexTokenTypes::QZM){ + return visit(nextAstNodes[0])+"*"; + }else if(AstTree->getType() == gRegexTokenTypes::KEY || AstTree->getType() == gRegexTokenTypes::CKEY){ + return AstTree->getValue(); + }else if(AstTree->getType() == gRegexTokenTypes::LPAREN){ + return "("+visit(nextAstNodes[0])+")"; + }else{ + throw std::logic_error("visit Bad token type" ); + } + + +} + + +std::string GraphStrInterpreter::interpret(void){ + std::shared_ptr<AstNode<gRegexTokenTypes>> tree = mParser.parse(); + return visit(tree); +} \ No newline at end of file diff --git a/src/graphRegex/matchFsm/FsmEdge.cpp b/src/graphRegex/matchFsm/FsmEdge.cpp new file mode 100644 index 0000000000000000000000000000000000000000..593da06abe18576d435ae55718d379aa5b682d60 --- /dev/null +++ b/src/graphRegex/matchFsm/FsmEdge.cpp @@ -0,0 +1,277 @@ +#include "aidge/graphRegex/matchFsm/FsmEdge.hpp" +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" + +using namespace Aidge; + +std::map<std::string,int> FsmEdgeCommon::mCommonIdxMap; + +bool FsmEdge::isCommon(void){ + return false; +} + +size_t FsmEdge::getCommonIdx(void){ + return std::numeric_limits<std::size_t>::max(); +} +const std::map<size_t,int>& FsmEdge::getRelative(void){ + return mRelativePos; +} +void FsmEdge::updateRelative( const std::map<size_t,int>& relativePos ){ + for (const auto& kvp : relativePos) { + mRelativePos.insert(kvp); + } +} +std::shared_ptr<FsmNode> FsmEdge::getSourceNode(void){ + return mNodeSource; +} +void FsmEdge::reSetSouceNode(const std::shared_ptr<FsmNode>& newSource){ + mNodeSource->rmEdge(shared_from_this()); + mNodeSource = newSource; + mNodeSource->addEdge(shared_from_this()); + propagateRelativePos(); + +} +std::shared_ptr<FsmNode> FsmEdge::getDestNode(void){ + return mNodeDest; +} +void FsmEdge::reSetDestNode(const std::shared_ptr<FsmNode>& newDest){ + mNodeDest->rmParent(mNodeSource); + mNodeDest = newDest; + mNodeDest->addParent(mNodeSource); + propagateRelativePos(); +} +void FsmEdge::propagateRelativePos(void){ + + std::set<int> myRelativeID; + for (const auto& kvp : mRelativePos) { + myRelativeID.insert(kvp.first); + } + + for (const auto& nextWeakEdge : mNodeDest->getEdges()){ + + if (auto nextEdge = nextWeakEdge.lock()) { + + if(this == nextEdge.get()){ + continue; + } + + + std::set<int> nextRelativeID; + for (const auto& kvp : nextEdge->getRelative()) { + nextRelativeID.insert(kvp.first); + } + + // Find elements in myRelativeID but not in nextRelativeID + std::set<int> idxsToPush; + std::set_difference(myRelativeID.begin(), myRelativeID.end(), + nextRelativeID.begin(), nextRelativeID.end(), + std::inserter(idxsToPush, idxsToPush.begin())); + + // Find elements in nextRelativeID but not in myRelativeID + std::set<int> idxsToGet; + std::set_difference(nextRelativeID.begin(), nextRelativeID.end(), + myRelativeID.begin(), myRelativeID.end(), + std::inserter(idxsToGet, idxsToGet.begin())); + + // test for integrity we look if 2 edge refert to the samme + // ref and are link the ref dif is one + // not working for common node + // we can go deeper by find the all pass to a ref and see if the delta is good + + // Find elements present in both myRelativeID and nextRelativeID + std::set<int> idxsTotest; + for (int idx : nextRelativeID){ + if (myRelativeID.find(idx) != myRelativeID.end()){ + if (std::abs(getRelative().at(idx) - nextEdge->getRelative().at(idx)) != 1) { + throw std::runtime_error("Bad relative"); + } + } + } + + + + // this egde have more relative info than the next + std::map<size_t,int> tmpRelative; + // we push this info to the next + for( auto idxToPush :idxsToPush ){ + tmpRelative.insert( std::make_pair(idxToPush, getRelative().at(idxToPush) +1)); + } + if(tmpRelative.size() != 0){ + nextEdge->updateRelative(tmpRelative); + nextEdge->propagateRelativePos(); + } + tmpRelative.clear(); + + + // the next node have more info than me i need to get it + for( auto idxToGet :idxsToGet ){ + tmpRelative.insert( std::make_pair(idxToGet, nextEdge->getRelative().at(idxToGet) -1)); + } + if(tmpRelative.size() != 0){ + updateRelative(tmpRelative); + + for(auto weakParent : getSourceNode()->getParentNodes()){ + if (auto parent = weakParent.lock()) { + for(auto weakPEdge : parent->getEdges()){ + if (auto pEdge = weakPEdge.lock()) { + pEdge->propagateRelativePos(); + }else{ + throw std::runtime_error("propagateRelativePos parent edge weak pointer is expired" ); + } + } + }else{ + throw std::runtime_error("propagateRelativePos parent weak pointer is expired" ); + } + } + } + tmpRelative.clear(); + }else{ + throw std::runtime_error("propagateRelativePos edge weak pointer is expired" ); + } + } +} + +void FsmEdge::updateWeak(void){ + mNodeSource->addEdge(shared_from_this()); + mNodeDest->addParent(mNodeSource); +} + +FsmEdge::FsmEdge(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest) +:mToTest(toTest) +{ + mNodeSource = source; + mNodeDest = dest; + // wen i make the edge I init the nodes + // mNodeSource->addEdge(shared_from_this()); + // mNodeDest->addParent(mNodeSource); +} + + +/////surchage + +FsmEdgeUnique::FsmEdgeUnique(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest) +:FsmEdge(source,dest,toTest) +{ +} +const EdgeTestResult FsmEdgeUnique::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ + auto opNode = stmContext->getActNode(); + + if(opNode == nullptr){ + return {false,std::set<NodePtr>()};//none + } + + if(mToTest->test(opNode) && opNode->getChildren().size() <= 1){ + stmContext->setValid(opNode,mToTest); + return {true,opNode->getChildren()} ; + }else{ + stmContext->addRejectedNode(opNode); + return {false,std::set<NodePtr>()}; + } +} +///////////////////// +FsmEdgeCommon::FsmEdgeCommon(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const std::shared_ptr<ConditionalInterpreter> toTest, const std::string commonKey) +:FsmEdge(source,dest,toTest) +{ + //make a uid for common node + if(mCommonIdxMap.find(commonKey) == mCommonIdxMap.end()){ + mCommonIdxMap.insert(std::make_pair(commonKey, mCommonIdxMap.size())); + } + mCommonIdx = mCommonIdxMap[commonKey]; + propagateRelativePos(); +} + + +const EdgeTestResult FsmEdgeCommon::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ + + auto opNode = stmContext->getActNode(); + + if(opNode == nullptr){ + return {false,std::set<NodePtr>()};//none + } + if(mToTest->test(opNode)){ + stmContext->setCommon(opNode,mCommonIdx); + stmContext->setValid(opNode,mToTest); + return {true,opNode->getChildren()} ; + }else{ + stmContext->addRejectedNode(opNode); + return {false,std::set<NodePtr>()}; + } +} +bool FsmEdgeCommon::isCommon(void){ + return true; + } +//////////////////// TODO FsmEdgeEmpty must be size_t +FsmEdgeRef::FsmEdgeRef(std::shared_ptr<FsmNode>& source,std::shared_ptr<FsmNode>& dest, const size_t refCommonIdx,const int deltaCommonIdx) +:FsmEdge(source,dest,nullptr),mRefCommonIdx(refCommonIdx),mdeltaCommonIdx(deltaCommonIdx) +{ + +} +const EdgeTestResult FsmEdgeRef::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ + + NodePtr refNode = stmContext->getCommonNodeFromIdx(mRefCommonIdx); + if (refNode){ + std::set<std::shared_ptr<Node>> see; + return {true,refNode->getNodeDelta(mdeltaCommonIdx,see)}; + } + return {false,std::set<NodePtr>()}; +} +//////////////////// +FsmEdgeEmpty::FsmEdgeEmpty(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest) +:FsmEdge(source,dest,nullptr) +{} +const EdgeTestResult FsmEdgeEmpty::test(const std::shared_ptr<FsmRunTimeContext> stmContext){ + auto opNode = stmContext->getActNode(); + if(opNode == nullptr){ + return {false,std::set<NodePtr>()}; + } + return {true,std::set<NodePtr>({opNode})};//none +} + +/// factory +std::shared_ptr<FsmEdge> FsmEdgeFactory::make( +std::shared_ptr<FsmNode> source, +std::shared_ptr<FsmNode> dest, FsmEdgeTypes type, +std::map<std::string, std::shared_ptr<ConditionalInterpreter>> allTest, +const std::string lexeme) +{ + if (type == FsmEdgeTypes::EMPTY) { + if (lexeme.empty()) { + return std::make_shared<FsmEdgeEmpty>(source, dest); + } else { + throw std::invalid_argument("error lexem EMPTY"); + } + } else if (type == FsmEdgeTypes::REF) { + std::smatch m; + std::regex refRegex("\\s*\\(\\s*(\\d+)\\s*,\\s*(-?\\d+)\\s*\\)\\s*"); + if (std::regex_match(lexeme, m, refRegex)) { + int refCommonIdx = std::stoi(m[1]); + int deltaCommonIdx = std::stoi(m[2]); + return std::make_shared<FsmEdgeRef>(source, dest, refCommonIdx, deltaCommonIdx); + } else { + throw std::invalid_argument("error lexem REF " + lexeme); + } + } else if (type == FsmEdgeTypes::COMMON) { + std::smatch m; + std::regex commonRegex("\\s*(\\w+)#(\\d*)"); + if (std::regex_match(lexeme, m, commonRegex)) { + std::string edgeType = m[1]; + std::string commonId = m[2]; + size_t commonIdx = commonId.empty() ? 0 : std::stoi(commonId) + 1; + std::string commonKey = edgeType + std::to_string(commonIdx); + return std::make_shared<FsmEdgeCommon> (source, dest, allTest.at(edgeType), commonKey); + } else { + throw std::invalid_argument("error lexem COMMON " + lexeme); + } + } else if (type == FsmEdgeTypes::UNIQUE) { + std::regex uniqueRegex("\\s*(\\w+)"); + std::smatch m; + if (std::regex_match(lexeme, m, uniqueRegex)) { + std::string edgeType = m[1]; + return std::make_shared<FsmEdgeUnique>(source, dest, allTest.at(edgeType)); + } else { + throw std::invalid_argument("error lexem UNIQUE \"" + std::string(lexeme) +" eee\""); + } + } else { + throw std::invalid_argument("Bad edge Type"); + } + } \ No newline at end of file diff --git a/src/graphRegex/matchFsm/FsmGraph.cpp b/src/graphRegex/matchFsm/FsmGraph.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5a9f00d728cd2cd9f58c2228361f8393de2a3d9d --- /dev/null +++ b/src/graphRegex/matchFsm/FsmGraph.cpp @@ -0,0 +1,201 @@ +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" + +using namespace Aidge; + + + +FsmGraph::FsmGraph(/* args */){ + +} + +//TODO + std::shared_ptr<MatchResult> FsmGraph::test(std::vector<NodePtr>& startNodes){ + std::vector<std::shared_ptr<Aidge::FsmNode>> startNodesFsm = getStartNodes(); + if(startNodes.size() != startNodesFsm.size()){ + throw std::runtime_error("bad number of Start nodes"); + } + + std::vector<std::shared_ptr<FsmRunTimeContext>> walks; + for(std::size_t i = 0; i < startNodes.size(); i++){ + walks.push_back(std::make_shared<FsmRunTimeContext>(startNodesFsm[i],startNodes[i])); + } + std::vector<std::shared_ptr<FsmRunTimeContext>> nextWalks; + + std::vector<std::shared_ptr<FsmRunTimeContext>> allValidContext; + std::vector<std::shared_ptr<FsmRunTimeContext>> allContextSee; + + + + + while (!walks.empty()) + { + for(auto fsmContext : walks){ + allContextSee.push_back(fsmContext); + //if we are in a valid st we save it + //it's one solution of the posible solution of the matching + if(fsmContext->isOnValidState()){ + //not save 2 time the same end point + if(!std::any_of(allValidContext.begin(), allValidContext.end(), + [&](std::shared_ptr<Aidge::FsmRunTimeContext> oldValid) { + return fsmContext->areEqual(oldValid); + })){ + allValidContext.push_back(fsmContext); + } + + } + + //dont test 2 time a fsmContext + std::vector<std::shared_ptr<FsmRunTimeContext>> tmpNextWalks = fsmContext->getActState()->test(fsmContext); + for(auto PotentialFsmContext : tmpNextWalks){ + + if(!std::any_of(allContextSee.begin(), allContextSee.end(), + [&](std::shared_ptr<Aidge::FsmRunTimeContext> oldSee) { + return PotentialFsmContext->areEqual(oldSee); + })){ + nextWalks.push_back(PotentialFsmContext); + } + } + + } + walks.swap(nextWalks); + nextWalks.clear(); + } + + + return std::make_shared<MatchResult>(allValidContext,getNbSubFsm()); + +} + + +/////////////// +// FSM construction +/////////////// +const std::set<std::shared_ptr<FsmEdge>>& FsmGraph::getEdge(void){ + return mEdges; +} + +void FsmGraph::addEdge(std::shared_ptr<FsmEdge>& edge){ + edge->updateWeak(); + mEdges.insert(edge); + mAllOrigine.insert(edge->getDestNode()->getOrigine()); + mAllOrigine.insert(edge->getSourceNode()->getOrigine()); +} + +const std::vector<std::shared_ptr<FsmNode>> FsmGraph::getStartNodes(void){ + std::set<std::shared_ptr<FsmNode>> nodes = getNodes(); + std::vector<std::shared_ptr<FsmNode>> startNodes; + for(auto node :nodes){ + if(node->isStart()){ + startNodes.push_back(node); + } + } + return startNodes; +} + +const std::set<std::shared_ptr<FsmNode>> FsmGraph::getValidNodes(void){ + std::set<std::shared_ptr<FsmNode>> nodes = getNodes(); + std::set<std::shared_ptr<FsmNode>> ValidNodes; + for(auto node :nodes){ + if(node->isValid()){ + ValidNodes.insert(node); + } + } + //may short + return ValidNodes; +} + +const std::set<std::shared_ptr<FsmNode>> FsmGraph::getNodes(void){ + std::set<std::shared_ptr<FsmNode>> nodes; + for(auto edge : mEdges){ + nodes.insert(edge->getDestNode()); + nodes.insert(edge->getSourceNode()); + } + return nodes; +} + +void FsmGraph::setGroupe(std::size_t groupeIdx){ + std::set<std::shared_ptr<FsmNode>> nodes = getNodes(); + for(auto node :nodes){ + node->setGroupe(groupeIdx); + } +} + +void FsmGraph::unionG(const std::shared_ptr<FsmGraph> fsmGraph){ + + for(auto edge : fsmGraph->getEdge()){ + addEdge(edge); + } +} + +void FsmGraph::mergeOneStartOneValid(const std::shared_ptr<FsmGraph> fsmGraph){ + std::set<std::shared_ptr<FsmNode>> validNodes = getValidNodes(); + std::vector<std::shared_ptr<FsmNode>> startNodes = fsmGraph->getStartNodes(); + + if (startNodes.size() != 1 || validNodes.size() != 1){ + + std::ostringstream errorMessage; + errorMessage <<"mergeOneStartOneValid start size: " << startNodes.size() << " valide size : " << validNodes.size() + <<" can only merge FSM 1 start 1 valide"; + throw std::runtime_error(errorMessage.str()); + } + + unionG(fsmGraph); + //for loop useless but for future merge it's coudl be used + for(auto valid : validNodes){ + valid->unValid(); + for(auto start : startNodes){ + start->unStart(); + _mergeNode(start,valid); + } + } +} + +std::size_t FsmGraph::getNbSubFsm(void){ + return mAllOrigine.size(); +} + +void FsmGraph::incOrigineAllNodeBy(std::size_t incr){ + std::set<std::shared_ptr<FsmNode>> nodes = getNodes(); + for(auto node :nodes){ + node->incOrigine(incr); + } + std::set<std::size_t> updatedOrigin; + for(auto origin : mAllOrigine){ + updatedOrigin.insert(origin + incr); + } + mAllOrigine.swap(updatedOrigin); +} + +void FsmGraph::_mergeNode(std::shared_ptr<FsmNode> source,std::shared_ptr<FsmNode> dest){ + std::set<std::shared_ptr<FsmNode>> nodes = getNodes(); + + if(nodes.find(source) == nodes.end() || nodes.find(dest) == nodes.end()){ + throw std::runtime_error("FsmGraph can not merge node not in the graph"); + } + nodes.clear(); + + //probagate source attribut + if(source->isValid()){ + dest->valid(); + } + if(source->isStart()){ + dest->start(); + } + + //merge source to dest by replace source by dest in all EDGE + for(auto edge : mEdges){ + if(edge->getDestNode() == source ){ + edge->reSetDestNode(dest); + }else if(edge->getSourceNode() == source ){ + edge->reSetSouceNode(dest); + } + + } + //check is source is not in graph + nodes = getNodes(); + if(nodes.find(source) != nodes.end() ){ + throw std::runtime_error("FsmGraph merge node not effective"); + } + nodes.clear(); + +} diff --git a/src/graphRegex/matchFsm/FsmNode.cpp b/src/graphRegex/matchFsm/FsmNode.cpp new file mode 100644 index 0000000000000000000000000000000000000000..84b4a0c3fdbe0730a12a2a62db9158e2538d646f --- /dev/null +++ b/src/graphRegex/matchFsm/FsmNode.cpp @@ -0,0 +1,132 @@ +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" +#include "aidge/graphRegex/matchFsm/FsmEdge.hpp" +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" + +using namespace Aidge; + + + +FsmNode::FsmNode(bool isAValid,bool isAStart ){ + mIsAStart =isAStart; + mIsAValid =isAValid; + +} +const std::vector<std::shared_ptr<FsmRunTimeContext>> FsmNode::test( std::shared_ptr<FsmRunTimeContext> fsmContext){ + + + std::vector<std::shared_ptr<FsmRunTimeContext>> out; + + for(auto edge : mEdges){ + if (auto sharedEdge = edge.lock()) { + + std::shared_ptr<FsmNode> nextState = sharedEdge->getDestNode(); + + //make copy of the fsmContext + std::shared_ptr<FsmRunTimeContext> newFsmContext = std::make_shared<FsmRunTimeContext>(fsmContext); + + EdgeTestResult edgeRes = sharedEdge->test(newFsmContext); + + if(edgeRes.success){ + if(edgeRes.node.size() != 0){ + for(auto nextNode :edgeRes.node ){ + if(!newFsmContext->isAlreadyValid(nextNode)|| newFsmContext->isCommonDefined(nextNode) ){ + out.push_back( std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nextNode)); + + }else{ + out.push_back( std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nullptr)); + } + + } + }else{ + out.push_back( std::make_shared<FsmRunTimeContext>(newFsmContext,nextState,nullptr)); + } + } + newFsmContext.reset(); + + }else{ + throw std::runtime_error("test FsmNode weak pointer is expired" ); + } + + } + return out; +} + + + +std::size_t FsmNode::getOrigine(void){ + return mOrigineStm; +} +void FsmNode::incOrigine(std::size_t inc){ + mOrigineStm += inc; +} +void FsmNode::rmEdge(std::shared_ptr<FsmEdge> edge){ + mEdges.erase(edge); +} + +void FsmNode::addEdge(std::shared_ptr<FsmEdge> edge){ + std::weak_ptr<FsmEdge> edgeW(edge); + if (!edgeW.expired()) { + mEdges.insert(edgeW); + }else{ + throw std::runtime_error("addEdge FsmNode weak pointer is expired" ); + } +} + +// const std::set<std::shared_ptr<FsmNode>> FsmNode::getChildNodes(void){ +// std::set<std::shared_ptr<FsmNode>> children; +// for(auto edge : mEdges){ +// if (auto sharedEdge = edge.lock()) { +// children.insert(sharedEdge->getDestNode()); +// }else{ +// throw std::runtime_error("getChildNodes FsmNode weak pointer is expired" ); +// } +// } +// return children; +// } + + +const std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>>& FsmNode::getParentNodes(void){ + return mParents; +} +const std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>>& FsmNode::getEdges(void){ + return mEdges; +} + +void FsmNode::setGroupe(std::size_t groupeIdx){ + mGroupeStm = groupeIdx; + +} + +bool FsmNode::isValid(void){ + return mIsAValid; +} +bool FsmNode::isStart(void){ + return mIsAStart; +} +void FsmNode::unValid(void){ + mIsAValid =false; +} +void FsmNode::valid(void){ + mIsAValid =true; +} +void FsmNode::unStart(void){ + mIsAStart =false; +} +void FsmNode::start(void){ + mIsAStart =true; +} + + + +void FsmNode::addParent(std::shared_ptr<FsmNode> node){ + + std::weak_ptr<FsmNode> nodeW(node); + if (!nodeW.expired()) { + mParents.insert(nodeW); + }else{ + throw std::runtime_error("addParent FsmNode weak pointer is expired" ); + } +} +void FsmNode::rmParent(std::shared_ptr<FsmNode> node){ + mParents.erase(node); +} \ No newline at end of file diff --git a/src/graphRegex/matchFsm/FsmRunTimeContext.cpp b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..787cf2322a5b8e7001cdc59325345000dbb61553 --- /dev/null +++ b/src/graphRegex/matchFsm/FsmRunTimeContext.cpp @@ -0,0 +1,226 @@ +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" + +using namespace Aidge; + +std::vector<std::set<NodePtr>> FsmRunTimeContext::mRejectedNodes; + +FsmRunTimeContext::FsmRunTimeContext(std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ,std::size_t idxRejeced ){ + mActOpNode = actOpNode; + mActState = actState; + + //not define case + if(idxRejeced == std::numeric_limits<std::size_t>::max()){ + mLocalIdxRejeced = mRejectedNodes.size(); + mRejectedNodes.push_back(std::set<NodePtr>()); + }else{ + if(idxRejeced > mRejectedNodes.size()-1 ){ + throw std::runtime_error("FsmRunTimeContext idxRejeced"); + } + mLocalIdxRejeced =idxRejeced; + } +} + + + +FsmRunTimeContext::FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime){ + mActOpNode = fsmRunTime->mActOpNode; + mActState = fsmRunTime->mActState; + mCommonNodes = fsmRunTime->mCommonNodes; + mValidNodes = fsmRunTime->mValidNodes; + mLocalIdxRejeced = fsmRunTime->mLocalIdxRejeced; +} +FsmRunTimeContext::FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime,std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ){ + mActOpNode = actOpNode; + mActState = actState; + mCommonNodes = fsmRunTime->mCommonNodes; + mValidNodes = fsmRunTime->mValidNodes; + mLocalIdxRejeced = fsmRunTime->mLocalIdxRejeced; +} + +void FsmRunTimeContext::addRejectedNode(NodePtr node){ + mRejectedNodes[mLocalIdxRejeced].insert(node); +} + +std::set<NodePtr> FsmRunTimeContext::getRejectedNodes(void){ + return mRejectedNodes[mLocalIdxRejeced]; +} + +bool FsmRunTimeContext::isOnValidState(void){ + return mActState->isValid(); +} + +bool FsmRunTimeContext::isCommonDefined(NodePtr node){ + //return mCommonNodes.find(node) != mCommonNodes.end(); + + std::set<NodePtr> nodes = getCommonNodes(); + for(const auto& nodeC : nodes){ + if(nodeC.get() == node.get()){ + return true; + } + } + return false; +} + +bool FsmRunTimeContext::isAlreadyValid(NodePtr node){ + + std::set<NodePtr> nodes = getValidNodes(); + for(const auto& nodeV : nodes){ + if(nodeV.get() == node.get()){ + return true; + } + } + return false; + + //return getValidNodes().find(node) != getValidNodes().end(); +} + +bool FsmRunTimeContext::areCompatible(std::shared_ptr<FsmRunTimeContext> fsmContext){ + /* + see if 2 context can be merge + it need to have different mValidNodes exept for common + and the same idx for the common + */ + + //common node + + for (const auto& ref : getCommon()) { + for (const auto& test : fsmContext->getCommon()) { + //same index + if(ref.second == test.second){ + if(ref.first != test.first){ + return false; + } + } + } + } + + //valid nodes + std::set<NodePtr> commonElements; + std::set<NodePtr> A = getValidNodesNoCommon(); + std::set<NodePtr> B = fsmContext->getValidNodesNoCommon(); + std::set_intersection( + A.begin(),A.end(), + B.begin(), B.end(), + std::inserter(commonElements, commonElements.end()) + ); + + + if (!commonElements.empty()) { + return false; + } + + return true; +} + +bool FsmRunTimeContext::areEqual(std::shared_ptr<FsmRunTimeContext> fsmContext){ + if(getActNode() != fsmContext->getActNode()){ + return false; + } + if (getActState() != fsmContext->getActState()){ + return false; + } + if (getValidNodes() != fsmContext->getValidNodes()){ + return false; + } + if (getCommon() != fsmContext->getCommon()){ + return false; + } + + + return true; +} + +void FsmRunTimeContext::setCommon(NodePtr node,std::size_t commonIdx){ + if(isCommonDefined(node)){ + if (mCommonNodes.at(node) != commonIdx){ + throw std::runtime_error("conflict idx in the Common node"); + } + }else{ + mCommonNodes[node] = commonIdx; + } +} + +void FsmRunTimeContext::setValid(NodePtr node,std::shared_ptr<ConditionalInterpreter> tag){ + //we already find a node of this type + if(mValidNodes.find(tag) != mValidNodes.end()){ + if(isAlreadyValid(node) && !isCommonDefined(node) ){ + throw std::runtime_error("setValid you valid tow time"); + } + mValidNodes[tag].insert(node); + }else{ + mValidNodes[tag] = {node}; + } + +} + +std::size_t FsmRunTimeContext::getSubStmId(void){ + return mActState->getOrigine(); +} + +NodePtr FsmRunTimeContext::getCommonNodeFromIdx(std::size_t commonIdx){ + for (const auto& pair : mCommonNodes) { + if (pair.second == commonIdx) { + return pair.first; // Return the key when the value is found + } + } + throw std::runtime_error("getCommonNodeFromIdx Value not found in the map"); +} + +std::size_t FsmRunTimeContext::getCommonNodeIdx(NodePtr node){ + if(isCommonDefined(node)){ + return mCommonNodes.at(node); + } + throw std::runtime_error("getCommonNodeIdx node not found"); +} + +std::set<NodePtr> FsmRunTimeContext::getCommonNodes(void){ + std::set<NodePtr> nodes; + // Iterate over the map and insert values into the set + for (const auto& pair : mCommonNodes) { + nodes.insert(pair.first); + } + return nodes; +} + +std::map<NodePtr,std::size_t> FsmRunTimeContext::getCommon(void){ + return mCommonNodes; +} + +std::set<NodePtr> FsmRunTimeContext::getValidNodes(void){ + + auto sharedSet = std::make_shared<std::set<NodePtr>>(); + // Create a set to store the values from the map + std::set<NodePtr> nodes; + // Iterate over the map and insert values into the set + for (const auto& pair : mValidNodes) { + nodes.insert(pair.second.begin(),pair.second.end()); + } + return nodes; +} + +std::set<NodePtr> FsmRunTimeContext::getValidNodesNoCommon(void){ + std::set<NodePtr> differenceSet; + std::set<NodePtr> valide = getValidNodes(); + std::set<NodePtr> common = getCommonNodes(); + std::set_difference(valide.begin(), valide.end(), common.begin(), common.end(),std::inserter(differenceSet, differenceSet.end())); + return differenceSet; +} + +std::map<std::shared_ptr<ConditionalInterpreter>,std::set<NodePtr>> FsmRunTimeContext::getValid(void){ + return mValidNodes; +} + +NodePtr FsmRunTimeContext::getActNode(void){ + return mActOpNode; +} + +std::shared_ptr<FsmNode> FsmRunTimeContext::getActState(){ + return mActState; +} + + +void FsmRunTimeContext::rst(void){ + mRejectedNodes.clear(); +} + diff --git a/src/graphRegex/matchFsm/MatchResult.cpp b/src/graphRegex/matchFsm/MatchResult.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c35f1a7348e365baa8a27854ee6b0a833e342ee7 --- /dev/null +++ b/src/graphRegex/matchFsm/MatchResult.cpp @@ -0,0 +1,93 @@ +#include "aidge/graphRegex/matchFsm/MatchResult.hpp" + +using namespace Aidge; + + +MatchResult::MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm):mIdToRunTime(nbSubStm){ + mAllValid = allValid; + mNbSubStm = nbSubStm; + + //mIdToRunTimm + for (const auto& contextPtr : allValid) { + mIdToRunTime[contextPtr->getSubStmId()].push_back(contextPtr); + } + + std::vector<std::shared_ptr<FsmRunTimeContext>> precedence; + //make all solution posible + _generateCombinationd(0,precedence); + //sort by solution number of elements + std::sort(mSolve.begin(), mSolve.end(), [](const std::set<NodePtr>& set1, const std::set<NodePtr>& set2) { + return set1.size() < set2.size(); + }); + + +} + +void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence){ + + //it's end , we are below the number of stm + if (idxSubStm == mNbSubStm) + { + //precedence containe a liste of FSM compatible, we just need to + //check if all the node have been valide by at least one contetext + + //1) make the set of all node for the comput graph that are valide in all the FsmRunTimeContext + std::set<NodePtr> validNode; + std::set<NodePtr> rejectNode; + for (const auto& contextPtr : precedence) { + std::set<NodePtr> tmpV = contextPtr->getValidNodes(); + validNode.insert(tmpV.begin(), tmpV.end()); + std::set<NodePtr> tmpR = contextPtr->getRejectedNodes(); + rejectNode.insert(tmpR.begin(),tmpR.end()); + } + // 2) all RejectedNodes need to be valide by an others stm + // if it's not the case the match is not valid + if(std::includes(validNode.begin(), validNode.end(), rejectNode.begin(), rejectNode.end())){ + //we can save the solution + mSolve.push_back(validNode); + } + precedence.pop_back(); + return; + } + + + for (const auto& contextPtrOneFsm : mIdToRunTime[idxSubStm]) + { + if(idxSubStm == 0){ + precedence.push_back(contextPtrOneFsm); + _generateCombinationd(idxSubStm+1,precedence); + + }else{ + //test if the new context is compatible whith all the context in the precedence + // + bool compatibleSolutionFsm = true; + for (const auto& contextPtrOfOtherFsm : precedence) { + if(!(contextPtrOneFsm->areCompatible(contextPtrOfOtherFsm))){ + compatibleSolutionFsm = false; + break; + } + } + + if(compatibleSolutionFsm){ + precedence.push_back(contextPtrOneFsm); + _generateCombinationd(idxSubStm+1,precedence); + } + + } + } + + if(idxSubStm != 0){ + precedence.pop_back(); + } + return; + +} + +std::set<NodePtr> MatchResult::getBiggerSolution(void){ + if(mSolve.empty()){ + return std::set<NodePtr>(); + }else{ + return mSolve[0]; + } + +} \ No newline at end of file diff --git a/src/nodeTester/ConditionalInterpreter.cpp b/src/nodeTester/ConditionalInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e01bdd76a28576451a1a09202d5fd1e87a4856e5 --- /dev/null +++ b/src/nodeTester/ConditionalInterpreter.cpp @@ -0,0 +1,344 @@ + +#include "aidge/nodeTester/ConditionalInterpreter.hpp" + +using namespace Aidge; + + +/////////////////////////////// +//ConditionalRegisterFunction +/////////////////////////////// + + ConditionalData* ConditionalRegisterFunction::run(const std::string key,std::vector<ConditionalData*> & datas){ + + auto lambdaIt = mWlambda.find(key); + if (lambdaIt != mWlambda.end()) { + return lambdaIt->second(datas); + }else { + throw std::runtime_error("can not run Lambda due to invalid key: " + key); + } + } + +////////////////////// +//ConditionalInterpreter +/////////////////////// + ConditionalInterpreter::ConditionalInterpreter(const std::string ConditionalExpressions) + :mLambdaRegiter() + { + + ConditionalParser conditionalParser = ConditionalParser(ConditionalExpressions); + mTree = conditionalParser.parse(); + ///lambda by default + mLambdaRegiter.insert("getType",+[](NodePtr NodeOp){return NodeOp->type();}); + + } + + + bool ConditionalInterpreter::test( const NodePtr nodeOp) + { + + clearRes(); + try{ + std::vector<ConditionalData*> r = visit({mTree},nodeOp); + + if (mResolution.size() != 1){ + throw std::runtime_error("Multi-output interpretation output"); + }else{ + if (!mResolution[0]->isTypeEqualTo<bool>()){ + throw std::runtime_error("TEST OUT MUST BE A BOOL "); + }else{ + return mResolution[0]->getValue<bool>(); + } + } + + }catch(const std::exception& e){ + std::ostringstream errorMessage; + errorMessage << "Error in test " << "\n\t" << e.what() << "\n"; + throw std::runtime_error(errorMessage.str()); + } + } + + void ConditionalInterpreter::insertLambda(const std::string key,std::function<bool(Aidge::NodePtr)> f){ + mLambdaRegiter.insert<std::function<bool(Aidge::NodePtr)> >(key, f); + } + + ///// + std::vector<ConditionalData*> ConditionalInterpreter::visit(const ASTNodeCh& nodes, const NodePtr nodeOp ){ + std::vector<ConditionalData*> dataVector; + + for ( std::shared_ptr<AstNode<ConditionalTokenTypes>> node : nodes) { + try{ + switch (node->getType()){ + /////////////////////////////////// + //OPERATOR + /////////////////////////////////// + case ConditionalTokenTypes::NOT: + { + visit(node->getChilds(),nodeOp); + fNot(); + } + break; + case ConditionalTokenTypes::AND: + { + visit(node->getChilds(),nodeOp); + fAnd(); + } + break; + case ConditionalTokenTypes::OR: + { + visit(node->getChilds(),nodeOp); + fOr(); + } + break; + case ConditionalTokenTypes::EQ: + { + visit(node->getChilds(),nodeOp); + fEq(); + //dataVector.insert(dataVector.end(), tmp.begin(), tmp.end()); + } + break; + case ConditionalTokenTypes::NEQ: + { + visit(node->getChilds(),nodeOp); + fNeq(); + } + break; + + /////////////////////////////////// + //VALUE + /////////////////////////////////// + + case ConditionalTokenTypes::KEY: + + break; + case ConditionalTokenTypes::INTEGER: + { + fStrToInteger(node); + } + break; + case ConditionalTokenTypes::FLOAT: + { + fStrToFloat(node); + + } + break; + case ConditionalTokenTypes::STRING: + { + fStrToStr(node); + } + break; + + case ConditionalTokenTypes::NODE: //TODO + { + + ConditionalData* data = new ConditionalData; + data->setValue<NodePtr>(nodeOp); + mResolution.push_back(data); + + } + break; + + case ConditionalTokenTypes::LAMBDA: + { + visit(node->getChilds(),nodeOp); + fLambda(node); + + } + break; + + case ConditionalTokenTypes::BOOL: //TODO + { + ConditionalData* data = new ConditionalData; + + if(node->getValue() == "true"){ + data->setValue<bool>(true); + }else{ + data->setValue<bool>(false); + } + + mResolution.push_back(data); + + } + break; + + case ConditionalTokenTypes::ARGSEP: + case ConditionalTokenTypes::LPAREN: + case ConditionalTokenTypes::RPAREN: + case ConditionalTokenTypes::STOP: + default: + throw std::runtime_error("NODE TYPE NOT SUPORTED IN ConditionalInterpreter"); + } + }catch(const std::exception& e){ + std::ostringstream errorMessage; + errorMessage << "Error in visiting AST for node"<< nodeOp->name() << "\n\t" << e.what() << "\n"; + throw std::runtime_error(errorMessage.str()); + } + } + + return dataVector; + } + + + ////////////////////// + //value convertor + ///////////////////// + + + void ConditionalInterpreter::fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) + { + ConditionalData* data = new ConditionalData; + data->setValue<int>(std::stoi(node->getValue())); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) + { + + ConditionalData* data = new ConditionalData; + data->setValue<float>(std::stof(node->getValue())); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) + { + ConditionalData* data = new ConditionalData; + data->setValue<std::string>(node->getValue()); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) + { + //if the lambda have input + ConditionalData* data; + try { + data = mLambdaRegiter.run(node->getValue(),mResolution); + } catch (const std::exception& e) { + std::ostringstream errorMessage; + errorMessage << "Error in conditional interpretation when run the "<< node->getValue() <<" Lambda\n\t" << e.what() << "\n"; + throw std::runtime_error(errorMessage.str()); + } + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fEq(void) + { + if (mResolution.size() != 2){ + throw std::runtime_error("EQ need 2 arg and get :" + std::to_string(mResolution.size())); + } + auto a = mResolution[0]; + auto b = mResolution[1]; + + if (a->getType() != b->getType()){ + throw std::runtime_error("EQ Unsuported between type :" + a->getType() +" "+ b->getType()); + } + + + + ConditionalData* data = new ConditionalData; + + if (a->isTypeEqualTo<int>()) { + data->setValue<bool>( a->getValue<int>() == b->getValue<int>()); + }else if (a->isTypeEqualTo<float>()){ + data->setValue<bool>( a->getValue<float>() == b->getValue<float>()); + }else if (a->isTypeEqualTo<std::string>()){ + data->setValue<bool>( a->getValue<std::string>() == b->getValue<std::string>()); + }else if (a->isTypeEqualTo<bool>()){ + data->setValue<bool>( a->getValue<bool>() == b->getValue<bool>()); + }else{ + throw std::runtime_error("EQ Unknown type encountered :" + a->getType() ); + } + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fNeq(void) + { + if (mResolution.size() != 2){ + throw std::runtime_error("NEQ need 2 arg and get :" + std::to_string(mResolution.size())); + } + auto a = mResolution[0]; + auto b = mResolution[1]; + + if (a->getType() != b->getType()){ + throw std::runtime_error("NEQ Unsuported between type :" + a->getType() +" "+ b->getType()); + } + + ConditionalData* data = new ConditionalData; + + if (a->isTypeEqualTo<int>()) { + data->setValue<bool>( a->getValue<int>() != b->getValue<int>()); + }else if (a->isTypeEqualTo<float>()){ + data->setValue<bool>( a->getValue<float>() != b->getValue<float>()); + }else if (a->isTypeEqualTo<std::string>()){ + data->setValue<bool>( a->getValue<std::string>() != b->getValue<std::string>()); + }else + { + throw std::runtime_error("NEQ Unknown type encountered :" + a->getType() ); + } + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fAnd(void) + { + if (mResolution.size() != 2){ + throw std::runtime_error("AND need 2 arg and get :" + std::to_string(mResolution.size())); + } + auto a = mResolution[0]; + auto b = mResolution[1]; + + + if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){ + throw std::runtime_error("AND Unknown type encountered need bool get :" + a->getType() ); + } + + ConditionalData* data = new ConditionalData; + data->setValue<bool>( a->getValue<bool>() && b->getValue<bool>()); + + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fOr(void) + { + if (mResolution.size() != 2){ + throw std::runtime_error("OR need 2 arg and get :" + std::to_string(mResolution.size())); + } + auto a = mResolution[0]; + auto b = mResolution[1]; + + + if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){ + throw std::runtime_error("OR Unknown type encountered need bool get :" + a->getType() ); + } + + ConditionalData* data = new ConditionalData; + data->setValue<bool>( a->getValue<bool>() || b->getValue<bool>()); + + + clearRes(); + mResolution.push_back(data); + } + + void ConditionalInterpreter::fNot() + { + if (mResolution.size() != 1){ + throw std::runtime_error("NOT need 1 arg and get :" + std::to_string(mResolution.size())); + } + auto a = mResolution[0]; + + if (a->getType() != typeid(bool).name()){ + throw std::runtime_error("NOT Unknown type encountered need bool get :" + a->getType() ); + } + + ConditionalData* data = new ConditionalData; + data->setValue<bool>( !a->getValue<bool>() ); + + clearRes(); + mResolution.push_back(data); + + } diff --git a/src/nodeTester/ConditionalLexer.cpp b/src/nodeTester/ConditionalLexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9379bd8409f8f7ec4bae3e0122f88de79718e9dd --- /dev/null +++ b/src/nodeTester/ConditionalLexer.cpp @@ -0,0 +1,242 @@ +#include "aidge/nodeTester/ConditionalLexer.hpp" + +using namespace Aidge; + +////////////////// +//ConditionalLexer +////////////////// + + +ConditionalLexer::ConditionalLexer( const std::string ConditionalExpressions): +mConditionalExpressions(ConditionalExpressions) +{ + mPosition = 0; +} + +std::shared_ptr<ParsingToken<ConditionalTokenTypes>> ConditionalLexer::getNextToken(void){ + std::string currentChars = ""; + + while (mPosition < mConditionalExpressions.length()) + { + //erase all space + if (mConditionalExpressions[mPosition] != ' ') + { + currentChars += mConditionalExpressions[mPosition]; + } + else + { + mPosition++; + continue; + } + //performe tokenisation, find a regex and make a new token + + if (std::regex_match(currentChars,std::regex("\\&\\&")))// the AND TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::AND,""); + } + else if (std::regex_match(currentChars,std::regex("\\|\\|")))// the OR TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::OR,""); + } + else if (std::regex_match(currentChars,std::regex("\\!")))// the Not and not equ + { + mPosition++; + if ( mPosition < mConditionalExpressions.length()){ + currentChars += mConditionalExpressions[mPosition]; + if(std::regex_match(currentChars,std::regex("!="))){ + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::NEQ,""); + }else{ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::NOT,""); + } + } + //a not at the end not ok but it's the parseur work + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::NOT,""); + } + else if (std::regex_match(currentChars,std::regex("==")))// the EQ TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::EQ,""); + } + else if (std::regex_match(currentChars,std::regex("\\(")))// the LPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::LPAREN,""); + } + else if (std::regex_match(currentChars,std::regex("\\)")))// the RPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::RPAREN,""); + } + else if (std::regex_match(currentChars,std::regex(",")))// the RPAREN TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::ARGSEP,""); + } + else if (std::regex_match(currentChars,std::regex("\\$")))// the ACTNode TOKEN + { + mPosition++; + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::NODE,""); + } + + + ///// + //non const lent token + ///// + + //LAMBDA, KEY , bool //the fuction TAG + else if (std::regex_match(currentChars,std::regex("[A-Za-z_]")))// the KEY TOKEN (a char next ) + { + //read all the key + bool isLambda = false; + std::regex keyRegex("[A-Za-z_0-9]+"); + std::regex LambdaRegex("[A-Za-z_0-9]+\\("); + + while ( mPosition < mConditionalExpressions.length()) { + if(!std::regex_match(currentChars,keyRegex) && !std::regex_match(currentChars,LambdaRegex)) + { + currentChars.pop_back(); //the last char is the problemes + break; + } + else if (std::regex_match(currentChars,LambdaRegex)){ + isLambda = true; + } + mPosition++; + if (mPosition < mConditionalExpressions.length()) currentChars += mConditionalExpressions[mPosition]; + //currentChars += mConditionalExpressions[mPosition]; + } + //we end the match 2 posibility + //we are at the end of the mConditionalExpressions and we need to ensure the match + //we are not we can continu + if (mPosition == mConditionalExpressions.length()-1) + { + if (!std::regex_match(currentChars,keyRegex) && !std::regex_match(currentChars,LambdaRegex)) + { + throw badTokenError(currentChars,mPosition); + } + //mPosition++; // we stop all by going pos > lengt + } + + + if (std::regex_match(currentChars,std::regex("(true|false)"))){ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::BOOL,currentChars); + + } else if (isLambda){ + currentChars.pop_back();//pop the ( of the lambda + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::LAMBDA,currentChars); + } else{ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::KEY,currentChars); + } + + } + //numeric value + else if (std::regex_match(currentChars,std::regex("[0-9]")))// the KEY TOKEN (a char next ) + { + //read all the key + bool isFloat = false; + std::regex integerRegex("[0-9]+$"); + std::regex floatRegex("[0-9]+\\.[0-9]*$"); + + while ( mPosition < mConditionalExpressions.length()) { + + if(!std::regex_match(currentChars,integerRegex) && !std::regex_match(currentChars,floatRegex)) + { + currentChars.pop_back(); // the last char match is not a good one + break; + } + else if (std::regex_match(currentChars,floatRegex)){ + isFloat = true; + } + mPosition++; + if (mPosition < mConditionalExpressions.length()) currentChars += mConditionalExpressions[mPosition]; + //currentChars += mConditionalExpressions[mPosition]; + } + //we end the match 2 posibility + //we are at the end of the mConditionalExpressions and we need to ensure the match + //we are not we can continu + if (mPosition == mConditionalExpressions.length()-1) + { + if (!std::regex_match(currentChars,integerRegex) && !std::regex_match(currentChars,floatRegex)) + { + throw badTokenError(currentChars,mPosition); + } + } + + if(isFloat){ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::FLOAT,currentChars); + }else{ + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::INTEGER,currentChars); + } + + } + //string TODO + else if (std::regex_match(currentChars,std::regex("\'"))) // TODO ' or \' + { + std::regex strRegex("\'[A-Za-z_0-9\\s]*\'$"); + while ( mPosition < mConditionalExpressions.length()) { + if(std::regex_match(currentChars,strRegex)){ + break; + } + mPosition++; + if (mPosition < mConditionalExpressions.length()) currentChars += mConditionalExpressions[mPosition]; + //currentChars += mConditionalExpressions[mPosition]; + } + + //test the end condition + if (mPosition == mConditionalExpressions.length()-1 ){ + if (!std::regex_match(currentChars,strRegex)){ + throw badTokenError(currentChars,mPosition); + } + //mPosition++; // we stop all by going pos > lengt + } + + mPosition++; // go after the last " + //erase the " char + currentChars.pop_back(); + currentChars.erase(0,1); + + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::STRING,currentChars); + + } + + //Array TODO + + mPosition++; + } + + //no more to find no one match the currentChars + if (currentChars.empty()) { + return std::make_shared<ParsingToken<ConditionalTokenTypes>>(ConditionalTokenTypes::STOP,""); // Null shared pointer ; + }else{ + //std::ostringstream errorMessage; + //errorMessage << "\nBad syntax " << currentChars << " :\n" << mConditionalExpressions; + throw badTokenError(currentChars,mPosition); + } + +} + +void ConditionalLexer::rstPosition(void){ + if (isEnd()){ + mPosition = 0; + }else{ + throw badTokenError("end rst",mPosition); + } + +} + +bool ConditionalLexer::isEnd(void){ + return mPosition >= mConditionalExpressions.length(); +} + +std::runtime_error ConditionalLexer::badTokenError(const std::string& currentChars,std::size_t position){ + std::ostringstream errorMessage; + errorMessage << "\nBad syntax " << currentChars << " :\n" << mConditionalExpressions << "\n"; + for (std::size_t i = 0; i < position; i++) { + errorMessage << ' '; + } + errorMessage << "^\n"; + + return std::runtime_error(errorMessage.str()); +} \ No newline at end of file diff --git a/src/nodeTester/ConditionalParser.cpp b/src/nodeTester/ConditionalParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3ca2843aabefe9f98bc8ad46a36fe03883d0baef --- /dev/null +++ b/src/nodeTester/ConditionalParser.cpp @@ -0,0 +1,188 @@ + +#include "aidge/nodeTester/ConditionalParser.hpp" + +using namespace Aidge; + + +////////////////////////////// +//ConditionalParser +////////////////////////////// + +ConditionalParser::ConditionalParser(const std::string ConditionalExpressions):mLexer(ConditionalExpressions){ + mCurrentToken = mLexer.getNextToken(); +} + +void ConditionalParser::rstParser(void){ + mLexer.rstPosition(); + mCurrentToken = mLexer.getNextToken(); +} + +void ConditionalParser::ackToken(ConditionalTokenTypes tokenType){ + if(mCurrentToken->getType() == tokenType ){ + + try { + mCurrentToken = mLexer.getNextToken(); + } catch (const std::runtime_error& e) { + std::ostringstream errorMessage; + errorMessage << "Conditional Lexer error in Parser :\n"<< e.what() << std::endl; + throw std::runtime_error(errorMessage.str()); + } + }else{ + + std::ostringstream errorMessage; + errorMessage << "Bad syntax ConditionalParser " << static_cast<int>(mCurrentToken->getType()) <<"!="<< static_cast<int>(tokenType) << "\n"; + errorMessage << mLexer.rep(); + throw std::runtime_error(errorMessage.str()); + } +} + + + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstVal(void){ + /* + val : (KEY|INTEGER|FOAT|STRING|LAMBDA) + */ + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = mCurrentToken->copy(); + + if (token->getType() == ConditionalTokenTypes::KEY){ + ackToken(ConditionalTokenTypes::KEY); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + } + else if(token->getType() == ConditionalTokenTypes::INTEGER){ + ackToken(ConditionalTokenTypes::INTEGER); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + } + else if(token->getType() == ConditionalTokenTypes::FLOAT){ + ackToken(ConditionalTokenTypes::FLOAT); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + } + else if(token->getType() == ConditionalTokenTypes::BOOL){ + ackToken(ConditionalTokenTypes::BOOL); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + } + else if(token->getType() == ConditionalTokenTypes::STRING){ + ackToken(ConditionalTokenTypes::STRING); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + + }else if(token->getType() == ConditionalTokenTypes::NODE){ + ackToken(ConditionalTokenTypes::NODE); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token); + + }else if(token->getType() == ConditionalTokenTypes::LAMBDA){ + return constructAstLambda(); + } + + throw std::runtime_error("ConditionalParser unknow val type "+ token->rep().str() + "\n" + mLexer.rep()); + +} + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstLambda(void){ + /* + AstLambda : LAMBDA val (ARGSEP val)* RPAREN + */ + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> tokenLdb = mCurrentToken->copy(); + ackToken(ConditionalTokenTypes::LAMBDA); + ASTNodeCh paramLambda; + //AT LEAST ONE VALUE AS INPUT OF A LAMBDA + paramLambda.push_back(constructAstVal()); + while (mCurrentToken->getType() != ConditionalTokenTypes::RPAREN) + { + ackToken(ConditionalTokenTypes::ARGSEP); + paramLambda.push_back(constructAstVal()); + } + ackToken(ConditionalTokenTypes::RPAREN); + return std::make_shared<AstNode<ConditionalTokenTypes>>(tokenLdb,paramLambda); +} + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstCmpr(void){ + /* + cmpr : val (EQ|NEQ) val | LPAREN expr RPAREN + NOT ir ? + */ + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = mCurrentToken->copy(); + //we can check the type relation ir key (EQ|NEQ) val | val (EQ|NEQ) key , but val (EQ|NEQ) val is valid ? + if (token->getType() == ConditionalTokenTypes::LPAREN) + { + ackToken(ConditionalTokenTypes::LPAREN); + std::shared_ptr<AstNode<ConditionalTokenTypes>> node = constructAstExpr(); + ackToken(ConditionalTokenTypes::RPAREN); + return node; + }else{ + + std::shared_ptr<AstNode<ConditionalTokenTypes>> node = constructAstVal(); + token = mCurrentToken->copy(); + if (token->getType() == ConditionalTokenTypes::EQ){ + ackToken(ConditionalTokenTypes::EQ); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{node,constructAstVal()}); + }else if(token->getType() == ConditionalTokenTypes::NEQ){ + ackToken(ConditionalTokenTypes::NEQ); + return std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{node,constructAstVal()}); + }else{ + + throw std::runtime_error("constructAstCmpr "+ token->rep().str() + "\n" + mLexer.rep()); + } + + } +} + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::constructAstExpr(std::size_t precLimit /*= 0*/){ + /* + expr : cmpr ((AND | OR) cmpr)* + the NOT is not binary OP can be use in pratt + precedence H to L: TODO + AND + OR + */ + + //the not + std::shared_ptr<AstNode<ConditionalTokenTypes>> left; + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = mCurrentToken->copy(); + + if (mCurrentToken->getType() == ConditionalTokenTypes::NOT ){ + ackToken(ConditionalTokenTypes::NOT ); + left= std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{constructAstCmpr()}); + }else{ + left= constructAstCmpr(); + } + + //pratt + while (mCurrentToken->getType() != ConditionalTokenTypes::STOP ) //security + { + token = mCurrentToken->copy(); + //if the token is not in the map is not a operator so we consider a prec of 0 + if (ConditionalPrec.find(token->getType()) ==ConditionalPrec.end() ){ + return left; + } + + //if my actual operator have a prec <= of the last operator + std::size_t prec = ConditionalPrec.at(token->getType()); + if (prec <= precLimit){ + return left; + } + + //Act all AND and OR + ackToken(token->getType()); + + std::shared_ptr<AstNode<ConditionalTokenTypes>> right = constructAstExpr(prec); + + //i'm not sur what append to newNode + //std::shared_ptr<AstNode<ConditionalTokenTypes>> newNode = std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{left,constructAstCmpr()}); + std::shared_ptr<AstNode<ConditionalTokenTypes>> newNode = std::make_shared<AstNode<ConditionalTokenTypes>>(token,ASTNodeCh{left,right}); + left = newNode; + } + return left; +} + + +std::shared_ptr<AstNode<ConditionalTokenTypes>> ConditionalParser::parse(void){ + /* + expr : cmpr ((AND | OR) cmpr)* + cmpr : val (EQ|NEQ) val | LPAREN expr RPAREN | BOOL | LAMBDA + val : (KEY|INTEGER|FOAT|STRING|LAMBDA) + lambda : LAMBDA val (ARGSEP val)* RPAREN + */ + std::shared_ptr<AstNode<ConditionalTokenTypes>> astTree = constructAstExpr(); + + rstParser(); + return astTree; +} \ No newline at end of file diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c1f58c68686d9359fa3b8ea4b5eb54244e988895 --- /dev/null +++ b/src/operator/MetaOperator.cpp @@ -0,0 +1,141 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/utils/ErrorHandling.hpp" + +Aidge::MetaOperator_Op::MetaOperator_Op(const char *type, const std::shared_ptr<GraphView>& graph, + std::vector<NodePtr> inputNodes, + std::vector<NodePtr> outputNodes) + : Operator(type), + mGraph(graph) +{ + mInputs = std::vector<std::shared_ptr<Tensor>>(mGraph->inputs().size()); + for (std::size_t i = 0; i < mInputs.size(); ++i) { + mInputs[i] = std::make_shared<Tensor>(); + } + mOutputs = std::vector<std::shared_ptr<Tensor>>(mGraph->outputs().size()); + for (std::size_t i = 0; i < mOutputs.size(); ++i) { + mOutputs[i] = std::make_shared<Tensor>(); + } + + // Fill inputsNodes and outputsNodes when there is no ambiguity + if (inputNodes.empty()) { + AIDGE_ASSERT(mGraph->inputNodes().size() == 1, "need to specify internal nodes input mapping"); + inputNodes.push_back(*mGraph->inputNodes().begin()); + } + + if (outputNodes.empty()) { + AIDGE_ASSERT(mGraph->outputNodes().size() == 1, "need to specify internal nodes output mapping"); + outputNodes.push_back(*mGraph->outputNodes().begin()); + } + + AIDGE_ASSERT(mGraph->inputNodes().size() == inputNodes.size(), "wrong number of specified input nodes"); + AIDGE_ASSERT(mGraph->outputNodes().size() == outputNodes.size(), "wrong number of specified output nodes"); + + // Identify inputs that are outside the micro-graph + for (const auto& inputNode : inputNodes) { + AIDGE_ASSERT(mGraph->inView(inputNode), "input node must be in the graph"); + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> inputNodeinputs = + inputNode->inputs(); + + int inputIdx = 0; // input idx relative to the current node + for (const auto& in : inputNodeinputs) { + if (in.first == nullptr || !mGraph->inView(in.first)) { + // The input is not connected inside the micro-graph + // (no connection to this input or connection outside the micro-graph) + // => it is therefore an input for the meta-operator + mInputOps.push_back(std::make_pair(inputNode->getOperator(), inputIdx)); + } + + ++inputIdx; + } + } + + // The outputs of the output nodes are also the outputs of the meta-operator + for (const auto& outputNode : outputNodes) { + AIDGE_ASSERT(mGraph->inView(outputNode), "output node must be in the graph"); + const std::vector<std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>>> outputNodeoutputs = + outputNode->outputs(); + + for (size_t outputIdx = 0; outputIdx < outputNodeoutputs.size(); ++outputIdx) { + mOutputOps.push_back(std::make_pair(outputNode->getOperator(), outputIdx)); + } + } + + AIDGE_INTERNAL_ASSERT(mInputOps.size() == mGraph->inputs().size()); + AIDGE_INTERNAL_ASSERT(mOutputOps.size() == mGraph->outputs().size()); +} + +Aidge::NbElts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { + if (mImpl) { + return mImpl->getNbRequiredData(inputIdx); + } + else { + const auto& inputOp = mInputOps[inputIdx]; + return inputOp.first->getNbRequiredData(inputOp.second); + } +} + +Aidge::NbElts_t Aidge::MetaOperator_Op::getNbConsumedData(IOIndex_t inputIdx) const { + if (mImpl) { + return mImpl->getNbConsumedData(inputIdx); + } + else { + const auto& inputOp = mInputOps[inputIdx]; + return inputOp.first->getNbConsumedData(inputOp.second); + } +} + +Aidge::NbElts_t Aidge::MetaOperator_Op::getNbProducedData(IOIndex_t outputIdx) const { + if (mImpl) { + return mImpl->getNbProducedData(outputIdx); + } + else { + const auto& outputOp = mOutputOps[outputIdx]; + return outputOp.first->getNbProducedData(outputOp.second); + } +} + +void Aidge::MetaOperator_Op::updateConsummerProducer() { + if (mImpl) { + mImpl->updateConsummerProducer(); + } + else { + if (!mScheduler) { + // Lazy initialization + mScheduler = std::make_shared<SequentialScheduler>(mGraph); + } + + // TODO: check that generateScheduling() can be called multiple time to iteratively update the schedule. + // It could be a good idea to unify updateConsummerProducer() and generateScheduling() into a "updateScheduling()" + mScheduler->generateScheduling(); + } +} + +void Aidge::MetaOperator_Op::forward() { + if (mImpl) { + // A custom implementation exists for this meta operator + mImpl->forward(); + } + else { + // No custom implementation, use the individual operators implementations + if (!mScheduler) { + // Lazy initialization + // TODO: should we assert that a scheduler already exists at this point? + // => should be created in updateConsummerProducer() + mScheduler = std::make_shared<SequentialScheduler>(mGraph); + mScheduler->generateScheduling(); + } + + mScheduler->forward(false); + } +} diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp index 3a50ec3e7f83517267ef4ad04cb2c855f8f9df7e..e5e59582af68f66e6c54d09fac4cb1cc028493dd 100644 --- a/src/recipies/FuseBatchNorm.cpp +++ b/src/recipies/FuseBatchNorm.cpp @@ -79,10 +79,10 @@ void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){ printf("variance < 1e-12 for all outputs! Is the network correctly trained?\n"); } - const DimSize_t channelsSize = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<DimSize_t>("InChannels"); + const DimSize_t channelsSize = std::dynamic_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<DimSize_t>("InChannels"); // TODO : suppose we have Conv2D ... - const std::array<DimSize_t, 2> kernelDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<std::array<DimSize_t, 2>>("KernelDims"); + const std::array<DimSize_t, 2> kernelDims = std::dynamic_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<std::array<DimSize_t, 2>>("KernelDims"); std::shared_ptr<Tensor> weight = conv->input(1).first->getOperator()->getOutput(conv->input(1).second); std::shared_ptr<Tensor> bias = conv->input(2).first->getOperator()->getOutput(conv->input(2).second); diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 4dc8eb5c84ddb25546a32a672bdc84685a6f79f0..1f34091e54c0f83dae6b60589c20fb8fdf1d5064 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -40,13 +40,10 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // TODO: optimize memory usage // setup initial producers list - mComputationNumber = 0; std::set<std::shared_ptr<Node>> producers; for (const std::shared_ptr<Node>& nodePtr : mGraphView->getNodes()) { if (nodePtr->type() == "Producer") { producers.insert(nodePtr); - } else { - ++mComputationNumber; } } // add Data Input @@ -112,6 +109,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // Push consumers in the list of nodes to run and update the consumer producer system for (const auto& runnable : runnableConsumers) { + if (verbose) printf("Runnable: %s\n", (runnable->type() + "_" + std::to_string(reinterpret_cast<uintptr_t>(runnable.get()))).c_str()); runnable->getOperator()->updateConsummerProducer(); mStaticSchedule.push_back(runnable); } @@ -177,14 +175,19 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { // TODO: handle multiple inputs/outputs void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { + // Forward dims (if allowed) if (forwardDims) {mGraphView->forwardDims(); } - // add each Producer Node. - std::set<std::shared_ptr<Node>> computationOver; + // Generate scheduling *only if empty* + // If scheduling was already generated (in one or several steps, i.e. one or + // several successive call to generateScheduling()), do not generate it twice + if (mStaticSchedule.empty()) { + this->generateScheduling(); + } + // Clear previous scheduling results mScheduling.clear(); - this->generateScheduling(); int cpt = 0; for (const auto& runnable : mStaticSchedule) { if (verbose) @@ -202,12 +205,11 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { } if (!verbose) drawProgressBar(1.0, 50, " "); printf("\n"); - } void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const { FILE* fp = std::fopen((fileName + ".mmd").c_str(), "w"); - std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%s ms\n\n"); + std::fprintf(fp, "gantt\ndateFormat x\naxisFormat %%Q ms\n\n"); if (!mScheduling.empty()) { const auto globalStart = mScheduling[0].start; diff --git a/unit_tests/graph/Test_GraphView.cpp b/unit_tests/graph/Test_GraphView.cpp index 4b929286ba494a452c7f9cb71ce944c7d576c03a..9f014364636c70031b522b09c893e1144af3f133 100644 --- a/unit_tests/graph/Test_GraphView.cpp +++ b/unit_tests/graph/Test_GraphView.cpp @@ -161,7 +161,7 @@ TEST_CASE("[core/graph] GraphView(addChild)") { TEST_CASE("[core/graph] GraphView(inputs)") { auto g1 = std::make_shared<GraphView>("TestGraph"); std::shared_ptr<Node> conv = Conv(3, 32, {3, 3}); - g1->add(conv); + g1->add(conv, false); REQUIRE(g1->inputs() == conv->inputs()); } diff --git a/unit_tests/graph/Test_get.cpp b/unit_tests/graph/Test_get.cpp new file mode 100644 index 0000000000000000000000000000000000000000..afd1f42ee9f5d6cd668dd5cab82172cdc298e149 --- /dev/null +++ b/unit_tests/graph/Test_get.cpp @@ -0,0 +1,55 @@ + + +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" + + +using namespace Aidge; +TEST_CASE("get Delta") { + + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Conv", 1, 1, 1, "c2"); + std::shared_ptr<Node> conv3 = GenericOperator("Conv", 1, 1, 1, "c3"); + std::shared_ptr<Node> conv3_5 = GenericOperator("Conv", 1, 1, 1, "c3.5"); + std::shared_ptr<Node> conv4 = GenericOperator("Conv", 1, 1, 1, "c4"); + std::shared_ptr<Node> conv5 = GenericOperator("Conv", 1, 1, 1, "c5"); + + g1->add(conv); + g1->addChild(conv1, "c"); + + + std::set<Aidge::NodePtr> see; +conv->getNodeDelta(1,see); + + SECTION("Self return") { + see.clear(); + REQUIRE(conv->getNodeDelta(0,see) == std::set<std::shared_ptr<Node>>{conv}); + } + + + SECTION("child") { + see.clear(); + REQUIRE(conv->getNodeDelta(1,see) == std::set<std::shared_ptr<Node>>{conv1}); + } + + +} \ No newline at end of file diff --git a/unit_tests/graphMatching/Test_GRegex.cpp b/unit_tests/graphMatching/Test_GRegex.cpp index 7184fad76a921239753d4752ae1a4a61bf3aec16..2c5907d82e7c5b1d32f1fb38493c7333b68f8731 100644 --- a/unit_tests/graphMatching/Test_GRegex.cpp +++ b/unit_tests/graphMatching/Test_GRegex.cpp @@ -53,6 +53,10 @@ TEST_CASE("Create good init GRegex", "[GRegex]") { // Perform tests REQUIRE(GReg.getStmInit().size() == 1); REQUIRE(GReg.getStmFab().getNumberOfStm() == 1); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } @@ -101,6 +105,10 @@ TEST_CASE("Function matchFromStartNodes | One Match of Nodes sequence", "[GRegex // Perform tests REQUIRE(result == true_result); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GRegex]") { @@ -166,6 +174,10 @@ TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GR // Perform tests REQUIRE(result == true_result); REQUIRE(wrong_start_result == empty_result); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } /* diff --git a/unit_tests/graphMatching/Test_SeqStm.cpp b/unit_tests/graphMatching/Test_SeqStm.cpp index baabbbc3c10ec751c64a65ab01c2c4d502f58cb5..db8662e3329abe153d4a0fb2b3c46b950208d6bc 100644 --- a/unit_tests/graphMatching/Test_SeqStm.cpp +++ b/unit_tests/graphMatching/Test_SeqStm.cpp @@ -79,6 +79,10 @@ TEST_CASE("Create good init SeqStm", "[SeqStm]") { REQUIRE(stm.getAllCommonNode().size() == 0); REQUIRE(stm.getAllNodeTested().size() == 0); REQUIRE(stm.getAllNodeValidated().size() == 0); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } TEST_CASE("Test testNode function", "[SeqStm]") { @@ -156,4 +160,8 @@ TEST_CASE("Test testNode function", "[SeqStm]") { REQUIRE(stm.isStmBlocked() == true); REQUIRE(stm.getAllNodeTested() == testAllNodeTested); REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } \ No newline at end of file diff --git a/unit_tests/graphMatching/Test_StmFactory.cpp b/unit_tests/graphMatching/Test_StmFactory.cpp index b595372fd97a56f2ecf2575429c63db92484bbc0..3c66d0fa817cea674de5ab849091290c976e5735 100644 --- a/unit_tests/graphMatching/Test_StmFactory.cpp +++ b/unit_tests/graphMatching/Test_StmFactory.cpp @@ -36,6 +36,10 @@ TEST_CASE("Create good init StmFactory", "[StmFactory]") { } StmFactory stmF(nodesRegex); REQUIRE(stmF.getNumberOfStm() == 0); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } TEST_CASE("Test in makeNewStm the getStmIdx StmFactory", "[SeqStm]") { @@ -66,6 +70,10 @@ TEST_CASE("Test in makeNewStm the getStmIdx StmFactory", "[SeqStm]") { //test the number of stm REQUIRE(stmF.getNumberOfStm() == 2); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } TEST_CASE("Test in makeNewStm the stm StmFactory", "[SeqStm]") { @@ -123,6 +131,9 @@ TEST_CASE("Test in makeNewStm the stm StmFactory", "[SeqStm]") { REQUIRE(stm->getAllNodeTested() == testAllNodeTested); REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } @@ -185,5 +196,9 @@ TEST_CASE("Test in duplicateStm StmFactory", "[SeqStm]") { REQUIRE(stmD->isStmBlocked() == false); REQUIRE(stmD->getAllNodeTested().size() == 0); REQUIRE(stmD->getAllNodeValidated().size() == 0); + + for (const std::string& key : nodeTypeKey) { + delete nodesRegex[key]; + } } diff --git a/unit_tests/graphRegex/Test_Fsm.cpp b/unit_tests/graphRegex/Test_Fsm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..e5950f21b323f07b380ae95f70637ca48a173481 --- /dev/null +++ b/unit_tests/graphRegex/Test_Fsm.cpp @@ -0,0 +1,195 @@ +#include <memory> + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/nodeTester/ConditionalInterpreter.hpp" + +#include "aidge/graphRegex/matchFsm/FsmNode.hpp" +#include "aidge/graphRegex/matchFsm/FsmEdge.hpp" +#include "aidge/graphRegex/matchFsm/FsmGraph.hpp" +#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp" + +using namespace Aidge; + +TEST_CASE("matchFSM", "FsmEdge") { + + + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + FsmEdgeUnique EdgeToTest(nodeA,nodeB,toTest); + + SECTION("FsmEdgeUnique constructor") { + REQUIRE(EdgeToTest.getSourceNode() == nodeA); + REQUIRE(EdgeToTest.getDestNode() == nodeB); + REQUIRE(EdgeToTest.isCommon() == false); + } + + SECTION("FsmEdgeCommon constructor") { + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + + FsmEdgeCommon EdgeToTest(nodeA,nodeB,toTest,"A"); + + REQUIRE(EdgeToTest.getSourceNode() == nodeA); + REQUIRE(EdgeToTest.getDestNode() == nodeB); + REQUIRE(EdgeToTest.isCommon() == true); + } + + SECTION("FsmEdgeRef constructor") { + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + + FsmEdgeRef EdgeToTest(nodeA,nodeB,0,-1); + + REQUIRE(EdgeToTest.getSourceNode() == nodeA); + REQUIRE(EdgeToTest.getDestNode() == nodeB); + REQUIRE(EdgeToTest.isCommon() == false); + } + + SECTION("FsmEdgeEmpty constructor") { + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + + FsmEdgeEmpty EdgeToTest(nodeA,nodeB); + + REQUIRE(EdgeToTest.getSourceNode() == nodeA); + REQUIRE(EdgeToTest.getDestNode() == nodeB); + REQUIRE(EdgeToTest.isCommon() == false); + } + + + SECTION("FsmEdgeFactory"){ + + std::map<std::string, std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("true==true")}, + {"B",std::make_shared<ConditionalInterpreter>("true==true")}, + {"C",std::make_shared<ConditionalInterpreter>("true==true")} + }; + +// make(std::shared_ptr<FsmNode> source, std::shared_ptr<FsmNode> dest, +// FsmEdgeTypes type,std::map<std::string, const std::shared_ptr<ConditionalInterpreter>> allTest, +// const std::string& lexeme = ""); + + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(false,true); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(true,false); +// EMPTY = 0, +// REF, +// COMMON, +// UNIQUE + + std::shared_ptr<FsmEdge> edgeE = FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::EMPTY,allTest,""); + std::shared_ptr<FsmEdge> edgeU = FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::UNIQUE,allTest,"A"); + std::shared_ptr<FsmEdge> edgeC = FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::COMMON,allTest,"A#"); + std::shared_ptr<FsmEdge> edgeR = FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::REF,allTest,"(0,1)"); + + //test detection of bad syntax lexem + REQUIRE_THROWS(FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::EMPTY,allTest,"A")); + REQUIRE_THROWS(FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::UNIQUE,allTest,"A#")); + REQUIRE_THROWS(FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::COMMON,allTest,"A")); + REQUIRE_THROWS(FsmEdgeFactory::make(nodeA,nodeB,FsmEdgeTypes::REF,allTest,"A")); + + REQUIRE(edgeE->getSourceNode() == nodeA); + REQUIRE(edgeE->getDestNode() == nodeB); + } + + SECTION("graph constructor") { + //make the nodes + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,false); + std::shared_ptr<FsmNode> nodeC = std::make_shared<FsmNode>(false,true); + + //make the edges + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); + std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); + + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + + graph->addEdge(edgeAB); + graph->addEdge(edgeBC); + + + REQUIRE(graph->getValidNodes() == std::set<std::shared_ptr<FsmNode>>{nodeA}); + REQUIRE(graph->getStartNodes() == std::vector<std::shared_ptr<FsmNode>>{nodeC}); + } + + + SECTION("graph merge") { + + std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); + + //make the nodes + std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(false,true); + std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,false); + std::shared_ptr<FsmNode> nodeC = std::make_shared<FsmNode>(true,false); + + //make the edges + + std::shared_ptr<FsmEdge> edgeAB = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); + std::shared_ptr<FsmEdge> edgeBC = std::make_shared<FsmEdgeUnique>(nodeB,nodeC,toTest); + + std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + graph->addEdge(edgeAB); + graph->addEdge(edgeBC); + + REQUIRE(graph->getValidNodes() == std::set<std::shared_ptr<FsmNode>>{nodeC}); + REQUIRE(graph->getStartNodes() == std::vector<std::shared_ptr<FsmNode>>{nodeA}); + REQUIRE(graph->getNodes() == std::set<std::shared_ptr<FsmNode>>{nodeA,nodeB,nodeC}); + + //make the nodes + std::shared_ptr<FsmNode> node2A = std::make_shared<FsmNode>(false,true); + std::shared_ptr<FsmNode> node2B = std::make_shared<FsmNode>(false,false); + std::shared_ptr<FsmNode> node2C = std::make_shared<FsmNode>(true,false); + + + std::shared_ptr<FsmEdge> edge2AB = std::make_shared<FsmEdgeUnique>(node2A,node2B,toTest); + std::shared_ptr<FsmEdge> edge2BC = std::make_shared<FsmEdgeUnique>(node2B,node2C,toTest); + + std::shared_ptr<FsmGraph> graph2 = std::make_shared<FsmGraph>(); + + + graph2->addEdge(edge2AB); + graph2->addEdge(edge2BC); + + + REQUIRE(graph2->getValidNodes() == std::set<std::shared_ptr<FsmNode>>{node2C}); + REQUIRE(graph2->getStartNodes() == std::vector<std::shared_ptr<FsmNode>>{node2A}); + REQUIRE(graph2->getNodes() == std::set<std::shared_ptr<FsmNode>>{node2A,node2B,node2C}); + + + graph->mergeOneStartOneValid(graph2); + + REQUIRE(graph->getValidNodes() == std::set<std::shared_ptr<FsmNode>>{node2C}); + REQUIRE(graph->getStartNodes() == std::vector<std::shared_ptr<FsmNode>>{nodeA}); + REQUIRE(graph->getNodes() == std::set<std::shared_ptr<FsmNode>>{nodeA,nodeB,nodeC,node2B,node2C}); + } + + + + +} + +// TEST_CASE("matchFSM", "FsmGraph") { + +// SECTION("FsmEdgeUnique constructor") { +// //make the nodes +// std::shared_ptr<FsmNode> nodeA = std::make_shared<FsmNode>(true,false); +// std::shared_ptr<FsmNode> nodeB = std::make_shared<FsmNode>(false,true); + +// //make the edges +// std::shared_ptr<ConditionalInterpreter> toTest = std::make_shared<ConditionalInterpreter>("true==true"); +// std::shared_ptr<FsmEdgeUnique> edge = std::make_shared<FsmEdgeUnique>(nodeA,nodeB,toTest); + +// std::shared_ptr<FsmGraph> graph = std::make_shared<FsmGraph>(); + +// graph->addEdge(edge); + + + +// } + +// } \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_FsmMatch.cpp b/unit_tests/graphRegex/Test_FsmMatch.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1fe75be1a47033f75af7ccc4dc5202774444cd10 --- /dev/null +++ b/unit_tests/graphRegex/Test_FsmMatch.cpp @@ -0,0 +1,89 @@ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" + + +#include "aidge/graphRegex/GraphFsmInterpreter.hpp" + + +using namespace Aidge; +TEST_CASE("FsmMatch") { + + SECTION("Construction") { + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, + {"B",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, + {"C",std::make_shared<ConditionalInterpreter>("true==true")} + }; + + allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest["B"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + + std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A->A",allTest); + std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); + + + + //REQUIRE(fsm->getNodes().size() == 3); + //REQUIRE(fsm->getStartNodes().size() == 1); + + + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + + g1->add(conv); + g1->addChild(conv1, "c"); + + + REQUIRE(allTest["A"]->test(conv) == true); + REQUIRE(allTest["B"]->test(conv) == true); + + std::vector<std::shared_ptr<Node>> startNodes = {conv}; + + auto result = fsm->test(startNodes); + + REQUIRE( result->getBiggerSolution() == std::set<NodePtr>{conv,conv1}); + } + + + SECTION("2 branche graph"){ + + std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); + std::shared_ptr<Node> conv = GenericOperator("Conv", 1, 1, 1, "c"); + std::shared_ptr<Node> conv1 = GenericOperator("Conv", 1, 1, 1, "c1"); + std::shared_ptr<Node> conv2 = GenericOperator("Fc", 1, 1, 1, "c2"); + + g1->add(conv); + g1->addChild(conv1,conv); + g1->addChild(conv2,conv); + + REQUIRE(g1->getNodes() == std::set<std::shared_ptr<Node>>({conv,conv1,conv2})); + REQUIRE(g1->inputNodes() == std::set<std::shared_ptr<Node>>({conv})); + REQUIRE(g1->outputNodes() == std::set<std::shared_ptr<Node>>({conv1,conv2})); + + + ///////////// + + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("isConv($)==true")}, + {"B",std::make_shared<ConditionalInterpreter>("isFc($)==true")} + }; + allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + allTest["B"]->insertLambda("isFc",+[](NodePtr NodeOp){return NodeOp->type() == "Fc";}); + + std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A#->A; A#->B",allTest); + std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); + + std::vector<std::shared_ptr<Node>> startNodes = {conv,conv}; + auto result = fsm->test(startNodes); + REQUIRE( result->getBiggerSolution() == std::set<NodePtr>{conv,conv1,conv2}); + + } + +} diff --git a/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp b/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..9ce090506c9a61abd928b3ae590ee838afb05999 --- /dev/null +++ b/unit_tests/graphRegex/Test_GraphFsmInterpreter.cpp @@ -0,0 +1,42 @@ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/graphRegex/GraphFsmInterpreter.hpp" + + +using namespace Aidge; +TEST_CASE("GraphFsmInterpreter", "GraphFsmInterpreter") { + + SECTION("Construction") { + std::map<std::string,std::shared_ptr<ConditionalInterpreter>> allTest = { + {"A",std::make_shared<ConditionalInterpreter>("true==true")}, + {"B",std::make_shared<ConditionalInterpreter>("true==true")}, + {"C",std::make_shared<ConditionalInterpreter>("true==true")} + }; + + //GraphFsmInterpreter("A->B",allTest); + std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A->B",allTest); + std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret(); + + + REQUIRE(fsm->getNodes().size() == 3); + REQUIRE(fsm->getStartNodes().size() == 1); + REQUIRE(fsm->getEdge().size() == 2); + + for(auto node : fsm->getNodes()){ + if(node->isValid()){ + REQUIRE(node->getEdges().size() == 0); + }else{ + REQUIRE(node->getEdges().size() == 1); + } + + } + + + } + + + + + +} \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_GraphLexer.cpp b/unit_tests/graphRegex/Test_GraphLexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1b8cc8e018546ebfe3f84202d9404db27b17449b --- /dev/null +++ b/unit_tests/graphRegex/Test_GraphLexer.cpp @@ -0,0 +1,118 @@ +#include <catch2/catch_test_macros.hpp> +#include "aidge/graphRegex/GraphLexer.hpp" +#include "aidge/graphRegex/GraphRegexTypes.hpp" + +#include "aidge/utilsParsing/ParsingToken.hpp" + + +#include <iostream> +#include <map> +#include <functional> + +using namespace Aidge; + +// NEXT +// QOM +// QZM +// KEY +// CKEY +// SEP +// LPAREN +// RPAREN + +TEST_CASE("GraphRegex", "Lexer") { + SECTION("RandomGenerateTest") { + + std::map<gRegexTokenTypes, std::function<std::pair<std::string, std::string>()>> LexerTestMap{ + {gRegexTokenTypes::NEXT, +[](){return std::pair<std::string, std::string>("-> ","");}}, + {gRegexTokenTypes::QOM, +[](){return std::pair<std::string, std::string>("+ ","");}}, + {gRegexTokenTypes::QZM, +[](){return std::pair<std::string, std::string>("* ","");}}, + {gRegexTokenTypes::SEP, +[](){return std::pair<std::string, std::string>("; ","");}}, + + + + {gRegexTokenTypes::KEY, +[](){ + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_"; + std::size_t randomIndex = std::rand() % characters.size(); + std::string key; + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + + return std::pair<std::string, std::string>(key+" ",key);} + }, + + {gRegexTokenTypes::CKEY, +[](){ + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890_"; + const std::string num = "1234567890"; + + std::size_t randomIndex = std::rand() % characters.size(); + std::size_t randomNum = std::rand() % num.size(); + std::string key; + std::string idx; + + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + idx += num[randomNum]; + randomIndex = std::rand() % characters.size(); + randomNum = std::rand() % num.size(); + } + + return std::pair<std::string, std::string>(key+"#"+idx+" ",key+"#"+idx);} + }, + + {gRegexTokenTypes::LPAREN, +[](){return std::pair<std::string, std::string>("( ","");}}, + {gRegexTokenTypes::RPAREN, +[](){return std::pair<std::string, std::string>(") ","");}} + //{gRegexTokenTypes::STOP, +[](){return std::pair<std::string, std::string>("","");}} + }; + + + ////////////////// + //TEST GENERATOR + ////////////////// + const std::size_t numRandomElements = 10000; + std::vector<std::tuple<gRegexTokenTypes, std::string>> testVector; + + std::string testString; + + for (std::size_t i = 0; i < numRandomElements; ++i) { + + int randomIndex = std::rand() % LexerTestMap.size(); + // Get an iterator to the random element in the map + auto it = std::next(LexerTestMap.begin(), randomIndex); + // Access the random key and lambda value separately using structured binding + gRegexTokenTypes randomKey = it->first; + + std::function<std::pair<std::string, std::string>()> randomValue = it->second; + std::pair<std::string, std::string> result = randomValue(); + + testString += result.first; + testVector.emplace_back(randomKey, result.second); + + + } + + GraphLexer graphLexer = GraphLexer(testString); + + for (std::tuple<gRegexTokenTypes, std::string> testToken : testVector) { + gRegexTokenTypes tokenToFind = std::get<0>(testToken); + std::string lexemToFind = std::get<1>(testToken); + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = graphLexer.getNextToken(); + + + std::ostringstream errorMessage; + errorMessage << "\n we whant :"<< lexemToFind << "\n we get : "<< token->getLexeme() <<"\n"<< "on \n" << testString << " :\n " ; + + CAPTURE(errorMessage.str()); + REQUIRE(token->getLexeme() == lexemToFind); + REQUIRE(token->getType() == tokenToFind); + } + std::shared_ptr<ParsingToken<gRegexTokenTypes>> token = graphLexer.getNextToken(); + REQUIRE(token->getType() == gRegexTokenTypes::STOP); + } + + +} \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_GraphParser.cpp b/unit_tests/graphRegex/Test_GraphParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..857caa06f4e5fa383e79ea22bfe1ca28ac0973c8 --- /dev/null +++ b/unit_tests/graphRegex/Test_GraphParser.cpp @@ -0,0 +1,82 @@ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/graphRegex/GraphParser.hpp" +#include "aidge/utilsParsing/AstNode.hpp" +#include <iostream> + + +using namespace Aidge; + + //generative function , + std::string domain(); + std::string exp() { + int randomValue = std::rand() % 3; + switch (randomValue) { + case 0: + return "A"; + case 1 : + return "A#"; + default: + return domain(); + + } + } + + std::string seq() { + int randomValue = std::rand() % 2; + switch (randomValue) { + case 0: + return exp(); + default: + return exp()+"->"+seq(); + } + } + + std::string domain() { + int randomValue = std::rand() % 2; + + switch (randomValue) { + // case 0: + // return seq(); + // case 1: + // return seq() + "->" +domain(); + + case 0: + return "("+ seq() +")*"; + default: + return "("+ seq() +")+"; + + // case 4: + // return "("+ domain() +")*" + "->" +domain(); + // default: + // return "("+ domain() +")+" + "->" +domain(); + } + } + + std::string allExpr() { + int randomValue = std::rand() % 2; + switch (randomValue) { + case 0: + return seq(); + default : + return seq()+ ";" +allExpr(); + } + } + +/* +exp : KEY(QOM | QZM)? | CKEY | domain +seq :exp (NEXT seq)* +domain : LPAREN seq RPAREN (QOM | QZM) +allExpr: seq (SEP allExpr)* +*/ +TEST_CASE("GraphParser", "Test_GraphParser") { + + SECTION("Empty") { + for (int i = 0; i < 100; ++i) { + const std::string test = allExpr(); + std::cout << test <<"\n"; + GraphParser graphParser = GraphParser(test); + std::shared_ptr<AstNode<gRegexTokenTypes>> tree = graphParser.parse(); + } + } +} \ No newline at end of file diff --git a/unit_tests/graphRegex/Test_graphRegexAST.cpp b/unit_tests/graphRegex/Test_graphRegexAST.cpp new file mode 100644 index 0000000000000000000000000000000000000000..1cdb0bc1934983a26ab742bfe8879455077219cc --- /dev/null +++ b/unit_tests/graphRegex/Test_graphRegexAST.cpp @@ -0,0 +1,71 @@ +#include <catch2/catch_test_macros.hpp> +#include "aidge/graphRegex/GraphStrInterpreter.hpp" + + +using namespace Aidge; +TEST_CASE("GraphStrInterpreter") { + + + + std::vector<std::string> tests = { + + + //sequ + "A;", + "A->B", + "A->B->C", + //seq and common + "A#", + "A#->B", + "A#->B#", + "A#->B#->C", + "A#->B#->C#", + "A->B#->C", + //sequ quantif + + "A+", + "A+->B+", + "A->B+->C", + //sequ quantif * + "A*", + "A*->B*", + "A->B*->C", + + //sequ quantif + "A*", + "A*->B+", + "A+->B*->C", + //others + + "(A#->B->C#)+", + "(A#->B)+;A#->B->C", + "B+->B->B", + "B#->R*", + "(B#->R)*", + "A->C->B#->B;B#->R", + "B#->R", + "A->C#;A->C#;A->C#;A->C#;A->C#;A->C#", + "B#->R;B#->R", + "A# -> C -> B#; B#->A#", + + // Add more test cases here + }; + + SECTION("AST Regex bijection") { + + for (const std::string& test : tests) { + std::shared_ptr<GraphStrInterpreter> strGenerator = std::make_shared<GraphStrInterpreter>(test); + std::string astString = strGenerator->interpret(); + //supress space in the test becase erase in the AST + std::string testNoS = test; + testNoS.erase(std::remove_if(testNoS.begin(), testNoS.end(), ::isspace), testNoS.end()); + //if the last char is ; (SEP) it will not in the AST and it's not a bug erase it + if (!testNoS.empty() && testNoS.back() == ';') { + // Remove the last character + testNoS.pop_back(); + } + //test + REQUIRE(astString == testNoS); + } + + } +} \ No newline at end of file diff --git a/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp b/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp new file mode 100644 index 0000000000000000000000000000000000000000..8b502fb546e2f1396b629ebc78bc1bd4d67842e2 --- /dev/null +++ b/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp @@ -0,0 +1,66 @@ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/nodeTester/ConditionalInterpreter.hpp" +#include "aidge/operator/GenericOperator.hpp" + + +using namespace Aidge; + + + +TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { + + SECTION("custom Lambda") { + + const std::string test = " !toto($) == true " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + conditionalParser.insertLambda("toto",+[](NodePtr NodeOp){return false;}); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == true); + } + + SECTION("syntax error") { + + const std::string test = "'A' == 'A' ,&& "; + REQUIRE_THROWS_AS( ConditionalInterpreter(test), std::runtime_error); + + } + + + SECTION("test false int ") { + + const std::string test = " 10 == 11 " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == false); + } + + SECTION("test true int ") { + const std::string test = " 42 == 42 " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == true); + } + + SECTION("test false str ") { + const std::string test = " 'toto' == 'Corgi' " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == false); + } + + SECTION("test true str ") { + + const std::string test = " 'Corgi' == 'Corgi' " ; + ConditionalInterpreter conditionalParser = ConditionalInterpreter(test); + std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); + bool result = conditionalParser.test(nodeOp); + REQUIRE(result == true); + } + +} \ No newline at end of file diff --git a/unit_tests/nodeTester/Test_ConditionalLexer.cpp b/unit_tests/nodeTester/Test_ConditionalLexer.cpp new file mode 100644 index 0000000000000000000000000000000000000000..a937c27227dde4fa03ed7733df9e9552c3c1ac7b --- /dev/null +++ b/unit_tests/nodeTester/Test_ConditionalLexer.cpp @@ -0,0 +1,144 @@ +#include <catch2/catch_test_macros.hpp> +#include "aidge/nodeTester/ConditionalLexer.hpp" +#include "aidge/utilsParsing/ParsingToken.hpp" + +#include <iostream> +#include <map> +#include <functional> + +using namespace Aidge; + +TEST_CASE("nodeTester", "Lexer") { + SECTION("RandomGenerateTest") { + + std::map<ConditionalTokenTypes, std::function<std::pair<std::string, std::string>()>> LexerTestMap{ + {ConditionalTokenTypes::AND, +[](){return std::pair<std::string, std::string>("&& ","");}}, + {ConditionalTokenTypes::OR, +[](){return std::pair<std::string, std::string>("|| ","");}}, + {ConditionalTokenTypes::EQ, +[](){return std::pair<std::string, std::string>("== ","");}}, + {ConditionalTokenTypes::NEQ, +[](){return std::pair<std::string, std::string>("!= ","");}}, + + {ConditionalTokenTypes::KEY, +[](){return std::pair<std::string, std::string>("A ","A");}}, + + {ConditionalTokenTypes::BOOL, +[](){ + std::size_t keyLen = (std::rand() % 2); + const std::vector<std::string> characters = {"true","false"}; + + return std::pair<std::string, std::string>(characters[keyLen]+" ",characters[keyLen]);} + }, + + {ConditionalTokenTypes::INTEGER, +[](){ + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "1234567890"; + std::size_t randomIndex = std::rand() % characters.size(); + std::string key; + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + return std::pair<std::string, std::string>(key+" ",key);} + }, + + {ConditionalTokenTypes::FLOAT, +[](){ + std::size_t keyLen = (std::rand() % 20)+2; + const std::string characters = "1234567890"; + std::size_t randomIndex = std::rand() % characters.size(); + std::string key; + for (std::size_t i = 0; i < keyLen/2; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + key += "."; + for (std::size_t i = 0; i < keyLen/2; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + return std::pair<std::string, std::string>(key+" ",key);} + }, + + {ConditionalTokenTypes::STRING, +[](){ + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890 "; + std::size_t randomIndex = std::rand() % characters.size(); + std::string key; + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + + return std::pair<std::string, std::string>("'"+key+"' ",key);} + }, + + {ConditionalTokenTypes::LAMBDA, +[](){ + + std::size_t keyLen = (std::rand() % 20)+1; + const std::string characters = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ1234567890"; + const std::string Startchar = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"; + + std::size_t randomIndex = std::rand() % characters.size(); + std::size_t randomStartIndex = std::rand() % Startchar.size(); + + std::string key; + key += Startchar[randomStartIndex]; + + for (std::size_t i = 0; i < keyLen; ++i) { + key += characters[randomIndex]; + randomIndex = std::rand() % characters.size(); + } + + return std::pair<std::string, std::string>(key+"( ",key);} + }, + + {ConditionalTokenTypes::ARGSEP, +[](){return std::pair<std::string, std::string>(", ","");}}, + {ConditionalTokenTypes::NODE, +[](){return std::pair<std::string, std::string>("$ ","");}}, + {ConditionalTokenTypes::LPAREN, +[](){return std::pair<std::string, std::string>("( ","");}}, + {ConditionalTokenTypes::RPAREN, +[](){return std::pair<std::string, std::string>(") ","");}} + //{ConditionalTokenTypes::STOP, +[](){return std::pair<std::string, std::string>("","");}} + }; + + + ////////////////// + //TEST GENERATOR + ////////////////// + const std::size_t numRandomElements = 100; + std::vector<std::tuple<ConditionalTokenTypes, std::string>> testVector; + + std::string testString; + + for (std::size_t i = 0; i < numRandomElements; ++i) { + + int randomIndex = std::rand() % LexerTestMap.size(); + // Get an iterator to the random element in the map + auto it = std::next(LexerTestMap.begin(), randomIndex); + // Access the random key and lambda value separately using structured binding + ConditionalTokenTypes randomKey = it->first; + + std::function<std::pair<std::string, std::string>()> randomValue = it->second; + std::pair<std::string, std::string> result = randomValue(); + + testString += result.first; + testVector.emplace_back(randomKey, result.second); + + + } + + ConditionalLexer conditionalLexer = ConditionalLexer(testString); + + for (std::tuple<ConditionalTokenTypes, std::string> testToken : testVector) { + ConditionalTokenTypes tokenToFind = std::get<0>(testToken); + std::string lexemToFind = std::get<1>(testToken); + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = conditionalLexer.getNextToken(); + + + std::ostringstream errorMessage; + errorMessage << "\n we whant :"<< lexemToFind << "\n we get : "<< token->getLexeme() <<"\n"<< "on \n" << testString << " :\n " ; + + CAPTURE(errorMessage.str()); + REQUIRE(token->getLexeme() == lexemToFind); + REQUIRE(token->getType() == tokenToFind); + } + std::shared_ptr<ParsingToken<ConditionalTokenTypes>> token = conditionalLexer.getNextToken(); + REQUIRE(token->getType() == ConditionalTokenTypes::STOP); + } + + +} \ No newline at end of file diff --git a/unit_tests/nodeTester/Test_ConditionalParser.cpp b/unit_tests/nodeTester/Test_ConditionalParser.cpp new file mode 100644 index 0000000000000000000000000000000000000000..56adb92b41745001e1790e087f07369918794c5d --- /dev/null +++ b/unit_tests/nodeTester/Test_ConditionalParser.cpp @@ -0,0 +1,75 @@ + +#include <catch2/catch_test_macros.hpp> +#include "aidge/nodeTester/ConditionalParser.hpp" +#include "aidge/utilsParsing/AstNode.hpp" + +using namespace Aidge; + + std::string gVal() { + int randomValue = std::rand() % 5; + switch (randomValue) { + case 0: + return std::to_string(std::rand() % 101); + + case 1: + return std::to_string(std::rand() % 101)+"."+std::to_string(std::rand() % 101); + + case 2: + return " 'toto' "; + case 3: + return " A "; + + case 4: + return " A(10) "; + + default: + return " true "; + + } + } + + std::string gExpr() ; + std::string gCmpr() { + int randomValue = std::rand() % 3; + switch (randomValue) { + case 0: + return gVal() + " == " +gVal(); + case 1: + return "("+ gExpr() +")"; + default: + return gVal() + " != " +gVal(); + + } + + + return gVal() + " == " +gVal(); + } + + std::string gExpr() { + std::string out = gCmpr(); + int iterations = std::rand() % 100; + for (int i = 0; i < iterations; ++i) { + int randomValue = std::rand() % 2; + switch (randomValue) { + case 0: + return out +" && " + gCmpr(); + break; + default: + return out +" || " + gCmpr(); + break; + } + } + return out; + } + + +TEST_CASE("ConditionalParser", "ConditionalParser") { + + SECTION("Empty") { + for (int i = 0; i < 100; ++i) { + const std::string test = gExpr(); + ConditionalParser conditionalParser = ConditionalParser(test); + std::shared_ptr<AstNode<ConditionalTokenTypes>> tree = conditionalParser.parse(); + } + } +} \ No newline at end of file diff --git a/unit_tests/operator/Test_MetaOperator.cpp b/unit_tests/operator/Test_MetaOperator.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c090427914390369452ce3259f47830f01ab1754 --- /dev/null +++ b/unit_tests/operator/Test_MetaOperator.cpp @@ -0,0 +1,53 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/operator/MetaOperator.hpp" +#include "aidge/operator/MetaOperatorDefs.hpp" +#include "aidge/graph/GraphView.hpp" +#include <cstddef> + +using namespace Aidge; + +TEST_CASE("[core/operators] MetaOperator", "[Operator]") { + SECTION("PaddedConv") { + auto op = PaddedConv(1, 3, {3, 3}, "padded_conv", {1, 1}, {1, 1, 1, 1}); + + auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraph(); + + REQUIRE(microGraph->getNodes().size() == 2); + REQUIRE(microGraph->inputNodes().size() == 2); // 2 because Conv has inputs outside the meta-op (Producers for weight and bias) + // Order not garanteed by the GraphView + //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->type() == "Pad"); + REQUIRE(microGraph->outputNodes().size() == 1); + REQUIRE((*microGraph->outputNodes().begin())->getOperator()->type() == "Conv"); + REQUIRE(op->nbInputs() == 3); + REQUIRE(op->nbDataInputs() == 1); + REQUIRE(op->nbOutputs() == 1); + + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(); + myInput->resize({2,3,5,5}); + op->getOperator()->associateInput(0,myInput); + op->getOperator()->computeOutputDims(); + + REQUIRE(op->getOperator()->outputDimsForwarded()); + REQUIRE(op->getOperator()->getOutput(0)->dims() == std::vector<size_t>({2,3,5,5})); + REQUIRE(op->getOperator()->getInput(0) == myInput); + // Order not garanteed by the GraphView + //REQUIRE((*microGraph->inputNodes().begin())->getOperator()->getInput(0) == myInput); + REQUIRE(op->getOperator()->getOutput(0) == (*microGraph->outputNodes().begin())->getOperator()->getOutput(0)); + + //op->getOperator()->updateConsummerProducer(); // require implementation + //auto microGraphScheduler = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator())->getMicroGraphScheduler(); + //REQUIRE(microGraphScheduler->getStaticScheduling().size() == 2); + } +}