diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index cc8763580076957d550c7c0702468a593e218569..65dc7f70cde21e522fee29ced4552e6801f0b923 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -20,7 +20,6 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" #include "aidge/graph/OpArgs.hpp" -#include "aidge/graphmatching/GRegex.hpp" #include "aidge/graphmatching/Match.hpp" #include "aidge/graphmatching/NodeRegex.hpp" #include "aidge/graphmatching/SeqStm.hpp" diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 481099726843146173a37fcddc3bf69723b1a70e..4dea1ed974650ba9ae10c60c720733aa1581b055 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -428,4 +428,4 @@ private: }; } // namespace Aidge -#endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */ \ No newline at end of file +#endif /* AIDGE_CORE_GRAPH_GRAPHVIEW_H_ */ diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index f1d0a39d4bd7dba6990a46d61f7456c03244e44e..de9f178347a228796d56d1653adddfed76ea7c5b 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -187,7 +187,7 @@ public: IOIndex_t getNbFreeDataInputs() const; /** - * @brief List input ids of children liked to outputs of the node + * @brief List input ids of children linked to outputs of the node * @return std::vector<std::vector<std::pair<std::shared_ptr<Node>, * IOIndex_t>>> */ diff --git a/include/aidge/graphmatching/GRegex.hpp b/include/aidge/graphmatching/GRegex.hpp deleted file mode 100644 index fd2d0c52ab47e0f03b3307bdbcfcb5a7b81d78d9..0000000000000000000000000000000000000000 --- a/include/aidge/graphmatching/GRegex.hpp +++ /dev/null @@ -1,63 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - - -#ifndef AIDGE_GREGEX_H_ -#define AIDGE_GREGEX_H_ - -#include <stdexcept> // for exception, runtime_error, out_of_range -#include <regex> -#include <memory> // for shared_ptr -#include <algorithm> // for next_permutation - -#include "aidge/graphmatching/Utile.hpp" -#include "aidge/graphmatching/StmFactory.hpp" -#include "aidge/graphmatching/SeqStm.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/graphmatching/Match.hpp" - - -namespace Aidge{ - -class GRegex { -// __init__(self,nodes_regex:dict,seq_regexps:list) - - StmFactory mStmFab; - std::vector<SeqStm*> mStmInit; - -public: - GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps ); - - std::set<NodeTmp> matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch); - - bool walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm); - - bool walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm); - - bool walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm); - - std::set<NodeTmp> get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm); - - std::vector<SeqStm*> getStmInit() const { - return mStmInit; - } - - StmFactory getStmFab() const { - return mStmFab; - } - - //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> match(const std::shared_ptr<GraphView> graphToMatch); - Match match(const std::shared_ptr<GraphView> graphToMatch); - -}; - -} -#endif //AIDGE_GREGEX_H_ \ No newline at end of file diff --git a/include/aidge/graphmatching/Match.hpp b/include/aidge/graphmatching/Match.hpp deleted file mode 100644 index fc617a22869fde6531fba67c8641581572cbffc4..0000000000000000000000000000000000000000 --- a/include/aidge/graphmatching/Match.hpp +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_MATCH_H_ -#define AIDGE_MATCH_H_ - -#include <vector> -#include <set> -#include <iostream> -#include <cassert> -#include "aidge/graphmatching/Utile.hpp" - - -namespace Aidge{ - -class Match { - -public: - Match(); - - size_t getNbMatch(); - - void insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes); - - std::vector<std::vector<NodeTmp>> getStartNodes(); - - std::vector<std::set<NodeTmp>> getMatchNodes(); - -protected: - std::vector<std::vector<NodeTmp>> mStartNodes; - std::vector<std::set<NodeTmp>> mMatchNodes; - -}; - -} -#endif //AIDGE_MATCH_H_ \ No newline at end of file diff --git a/include/aidge/graphmatching/NodeRegex.hpp b/include/aidge/graphmatching/NodeRegex.hpp deleted file mode 100644 index 10ba7225834e4abfb7f0f5cd45ffa91b22f2f87d..0000000000000000000000000000000000000000 --- a/include/aidge/graphmatching/NodeRegex.hpp +++ /dev/null @@ -1,41 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_NODEREGEX_H_ -#define AIDGE_NODEREGEX_H_ -#include <cstdlib> -#include <iostream> -#include <cstring> -#include "aidge/graph/Node.hpp" - - -namespace Aidge { - -class NodeRegex -{ - public: - std::string mCondition; - - NodeRegex(const std::string c){ - mCondition = c; - }; - - // Version 1 - Only test the type of the node (no need for a lexer) - // Input : Node_op - // Output : bool - // return mCondition == Node_op.type - bool _is(std::shared_ptr<Node> &Node_op); - bool isA(std::string NodeType); -}; - -} - -#endif /* _AIDGE_NODEREGEX_H__ */ \ No newline at end of file diff --git a/include/aidge/graphmatching/SeqStm.hpp b/include/aidge/graphmatching/SeqStm.hpp deleted file mode 100755 index 0823b5fc0f292d8cf28f7ead53d01bd8dd8adbfe..0000000000000000000000000000000000000000 --- a/include/aidge/graphmatching/SeqStm.hpp +++ /dev/null @@ -1,127 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_SEQSTM_H_ -#define AIDGE_SEQSTM_H_ - -#include <iostream> -#include <map> -#include <regex> -#include <set> -#include <stdexcept> // for exception, runtime_error, out_of_range -#include <string> -#include <utility> -#include <vector> - - -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/graphmatching/Utile.hpp" - - -namespace Aidge { - -class SeqStm { - -private: - const int mStmIdx; - const std::vector<std::vector<int>> mTransitionMatrix; - // str key of type like 'A' that ce use in the A->B .. extpr - const std::map<std::string, NodeRegex *> mNodesRegex; - // mTypeToIdxTransition.first = std::pair node_type , common_tag - // mTypeToIdxTransition.segond = idx in trans matrix - const std::map<NodeTypeKey, int> mTypeToIdxTransition; - - int mActSt; - std::set<NodeTmp> mAllNodeValidated; - std::set<NodeTmp> mAllNodeTested; - std::set<std::pair<NodeTmp, std::string>> mAllCommonNode; - bool mStmIsValid; - - std::pair<NodeRegex *, std::string> getNodeRegexAndCommonAt(int idxType); - - /** - * @brief test the stm on a type - * @return the common tag - */ - std::string transitionOnNodeType(NodeType nodeType); - -public: - SeqStm(const int mStmIdx, - const std::vector<std::vector<int>> &mTransitionMatrix, - const std::map<std::string, NodeRegex *> &mNodesRegex, - const std::map<NodeTypeKey, int> &mTypeToIdxTransition, int mActSt, - std::set<NodeTmp> mAllNodeValidated, std::set<NodeTmp> mAllNodeTested, - std::set<std::pair<NodeTmp, std::string>> mAllCommonNode, - bool mStmIsValid); - - ////////////////////////////////////// - // STM test - ///////////////////////////////////// - - /** - * @brief get if a st is a valide one - * @return bool - */ - bool isAValidSt(int st) { - std::size_t size = mTransitionMatrix.size(); - return st == static_cast<int>(size - 1) ? true : false; - } - - /** - * @brief true if the stm is blocked into st - * @return bool - */ - bool isStmBlocked() { return mActSt == -1 ? true : false; } - - /** - * @brief true if the stm into valide st - * @return bool - */ - bool isValid() { return mStmIsValid; } - - ///////////////////////////////////// - // utile - ///////////////////////////////////// - /** - * @brief extract from a node is type - * @return bool - */ - NodeType getTheNodeType(NodeTmp node); - - void drawStm(); - ///////////////////////////////////// - // geter - ///////////////////////////////////// - - std::set<std::pair<NodeTmp, std::string>> getAllCommonNode() { - return mAllCommonNode; - } - std::set<NodeTmp> getAllNodeTested() { return mAllNodeTested; } - - std::set<NodeTmp> getAllNodeValidated() { return mAllNodeValidated; } - - SeqStm *duplicateStm(); - - int getStmIdx() { return mStmIdx; } - - int getState() { return mActSt; } - ////////////////////////////////////////// - // USE - ////////////////////////////////////////// - /** - * @brief test the stm on a node - * @return pair new stm state, the common tag - */ - std::pair<int, std::string> testNode(const NodeTmp node); -}; -} // namespace Aidge - -#endif /* AIDGE_SEQSTM_H_ */ \ No newline at end of file diff --git a/include/aidge/graphmatching/StmFactory.hpp b/include/aidge/graphmatching/StmFactory.hpp deleted file mode 100644 index b5850e4a00691ef6c808554a86a6ceec8c38ad19..0000000000000000000000000000000000000000 --- a/include/aidge/graphmatching/StmFactory.hpp +++ /dev/null @@ -1,55 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#ifndef AIDGE_STMFACTORY_H_ -#define AIDGE_STMFACTORY_H_ - -#include <map> -#include <utility> -#include <set> -#include <string> -#include <vector> -#include <iostream> -#include <stdexcept> // for exception, runtime_error, out_of_range -#include <regex> - -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/graphmatching/SeqStm.hpp" -#include "aidge/graphmatching/Utile.hpp" - -namespace Aidge{ - - - -class StmFactory { - - const std::map<std::string,NodeRegex*>& mNodesRegex; - std::size_t mCmptStm = 0; -public: - StmFactory(const std::map<std::string,NodeRegex*>& nodesRegex); - //StmFactory(){}; - - SeqStm* makeNewStm(const std::string& sequRegex); - SeqStm* duplicateStm(SeqStm* stm); - - std::size_t getNumberOfStm(){ - return mCmptStm; - } -private: - - ParsingReturn initParsingSequRegex(const std::string& sequRegex); - - std::vector<std::vector<int>> initTransitionMatrix(ParsingReturn& parsing); - -}; -} - -#endif //AIDGE_STMFACTORY_H_ \ No newline at end of file diff --git a/include/aidge/graphmatching/Utile.hpp b/include/aidge/graphmatching/Utile.hpp deleted file mode 100644 index acda78cd181519c86ab0b14d5b01bf91223cec9d..0000000000000000000000000000000000000000 --- a/include/aidge/graphmatching/Utile.hpp +++ /dev/null @@ -1,50 +0,0 @@ - -/** - * @file - * @brief - * @version file 1.0.0 - * @author vl241552 - * @copyright - * Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory. - * All rights reserved. - */ - -#ifndef _utile_H_ -#define _utile_H_ - -#include <map> - -#include "aidge/graph/Node.hpp" -#include <map> - -namespace Aidge { - -using NodeTmp = std::shared_ptr<Node>; -using NodeType = std::string; -using CommonTag = std::string; -using NodeTypeKey = std::pair<NodeType, CommonTag>; - -// type def -// struct NodeTypeKey { -// NodeType nodeType; -// std::string commonTag; - -// // for map find -// bool operator<(const NodeTypeKey& other) const { -// if (nodeType != other.nodeType or commonTag != other.commonTag) { -// return false; -// } else { -// return true; -// } -// } - -// }; - -struct ParsingReturn { - std::map<NodeTypeKey, int> typeToIdxTransition; - std::vector<std::pair<NodeTypeKey, std::string>> transition; -}; - -} // namespace Aidge - -#endif //_utile_H_ \ No newline at end of file diff --git a/include/aidge/hook/ExecTime.hpp b/include/aidge/hook/ExecTime.hpp index 212fef58696be702e89c8ad973dcc0dd0fc389ae..0964d9575b7ad345d5e07c9f19c7e56a3b69c813 100644 --- a/include/aidge/hook/ExecTime.hpp +++ b/include/aidge/hook/ExecTime.hpp @@ -18,7 +18,7 @@ #define execTime_H_ #include "aidge/operator/Operator.hpp" -#include "aidge/hook/hook.hpp" +#include "aidge/hook/Hook.hpp" #include <memory> #include <chrono> #include <vector> diff --git a/include/aidge/hook/OutputRange.hpp b/include/aidge/hook/OutputRange.hpp index a2da2a997d594c0ef78fb7c31f33b32c3495c4eb..355f4aaa15a6bcd77d99ec2dad344a45f8f9edc0 100644 --- a/include/aidge/hook/OutputRange.hpp +++ b/include/aidge/hook/OutputRange.hpp @@ -18,7 +18,7 @@ #define AIDGE_CORE_HOOK_OUTPUTRANGE_H_ #include "aidge/operator/Operator.hpp" -#include "aidge/hook/hook.hpp" +#include "aidge/hook/Hook.hpp" #include <memory> #include <chrono> #include <vector> diff --git a/include/aidge/nodeTester/ConditionalInterpreter.hpp b/include/aidge/nodeTester/ConditionalInterpreter.hpp index 674e942c7b77a6e572b0ffbaa90a2571f7a8118a..af6a3b920bb9ca389724860d55250d7ef4540677 100644 --- a/include/aidge/nodeTester/ConditionalInterpreter.hpp +++ b/include/aidge/nodeTester/ConditionalInterpreter.hpp @@ -22,7 +22,7 @@ namespace Aidge{ ///////////////////////////// /** * @brief class used to register any lambda function without context, - * it encapsulates the source lambda in a lambda which takes as argument ConditionalData* which are any type. + * it encapsulates the source lambda in a lambda which takes as argument std::shared_ptr<ConditionalData> which are any type. * @see ConditionalData */ class ConditionalRegisterFunction { @@ -31,12 +31,12 @@ class ConditionalRegisterFunction { ////////////////////////// /** - * @brief recast the ConditionalData* to the argument type of the lambda + * @brief recast the std::shared_ptr<ConditionalData> to the argument type of the lambda * @tparam T type of the lambda argument * @see ConditionalData */ template <typename T> - T safeCastInput(ConditionalData* data) { + T safeCastInput( std::shared_ptr<ConditionalData> data) { //cnvertion and type cheking if (data->isTypeEqualTo<T>()){ return data->getValue<T>(); @@ -48,14 +48,14 @@ class ConditionalRegisterFunction { /** - * @brief recaste the output of the lambda to a ConditionalData* + * @brief recaste the output of the lambda to a std::shared_ptr<ConditionalData> * @tparam T type of the lambda return * @see ConditionalData */ template <typename T> - ConditionalData* safeCastOutput(T data) { + std::shared_ptr<ConditionalData> safeCastOutput(T data) { - ConditionalData* out = new ConditionalData; + std::shared_ptr<ConditionalData> out = std::make_shared<ConditionalData>(); out->setValue<T>(data); return out; @@ -111,11 +111,11 @@ class ConditionalRegisterFunction { }; ///////////////////// - //change the function to ConditionalData*(std::vector<ConditionalData*>) + //change the function to std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>) ///////////////////// /** - * @brief Converts a function to a ConditionalData*(std::vector<ConditionalData*>). + * @brief Converts a function to a std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>). * @tparam F The type of the function to convert. * @tparam ParamsIdx The indices of the function parameters. * @param f The function to convert. @@ -124,25 +124,31 @@ class ConditionalRegisterFunction { template <class F, std::size_t... ParamsIdx> auto funcPointer(F f, std::index_sequence<ParamsIdx...>) { //wrapp the lambda in a new one that as ConditionalData as inputs and output - return [this,f](std::vector<ConditionalData*> &args) { - if (args.size() != sizeof...(ParamsIdx)){ + return [this,f](std::vector< std::shared_ptr<ConditionalData>> &args) { + if (args.size() < sizeof...(ParamsIdx)){ std::ostringstream errorMessage; errorMessage << "bad Number of argument: get " << args.size() << " need " << sizeof...(ParamsIdx) << "\n"; throw std::runtime_error(errorMessage.str()); } - //assert(args.size() == sizeof...(ParamsIdx));//the size of the vector valide + //we used std::vector< std::shared_ptr<ConditionalData>> as a fifo + std::size_t offset = args.size()-sizeof...(ParamsIdx); using FuncTraits = function_traits<decltype(f)>; using outType = typename FuncTraits::return_type; - outType result = f(safeCastInput<typename FuncTraits::template argument<ParamsIdx>::type>(args[ParamsIdx])...); + outType result = f(safeCastInput<typename FuncTraits::template argument<ParamsIdx>::type>(args[offset+ParamsIdx])...); + + //suppress what we used + for (size_t i = 0; i < sizeof...(ParamsIdx); ++i) { + args.pop_back(); + } //typename return safeCastOutput<outType>(result); }; } /** - * @brief Converts a function pointer to a ConditionalData*(std::vector<ConditionalData*>). + * @brief Converts a function pointer to a std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>). * @tparam R The return type of the function. * @tparam Params The parameter types of the function. * @param f The function pointer to convert. @@ -154,7 +160,7 @@ class ConditionalRegisterFunction { } /** - * @brief Converts a std::function to a ConditionalData*(std::vector<ConditionalData*>). + * @brief Converts a std::function to a std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>). * @tparam R The return type of the function. * @tparam Params The parameter types of the function. * @param f The function pointer to convert. @@ -196,7 +202,7 @@ class ConditionalRegisterFunction { * @param datas The vector of input data. * @return A pointer to the output ConditionalData object. */ - ConditionalData* run(const std::string key,std::vector<ConditionalData*> & datas); + std::shared_ptr<ConditionalData> run(const std::string key,std::vector< std::shared_ptr<ConditionalData>> & datas); bool isLambdaRegister(const std::string &key) { if(mWlambda.find(key) != mWlambda.end()){ @@ -207,7 +213,7 @@ class ConditionalRegisterFunction { private: /// @brief map of name and the converted function. - std::map<const std::string, std::function<ConditionalData*(std::vector<ConditionalData*> &)>> mWlambda; + std::map<const std::string, std::function< std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>> &)>> mWlambda; }; /////////////////// @@ -237,15 +243,15 @@ class ConditionalInterpreter ConditionalRegisterFunction mLambdaRegister; - std::vector<ConditionalData*> mResolution ; + std::vector< std::shared_ptr<ConditionalData>> mResolution ; - void clearRes(){ + // void clearRes(){ - for (std::size_t i = 0; i < mResolution.size(); ++i) { - delete mResolution[i]; - } - mResolution.clear(); - } + // for (std::size_t i = 0; i < mResolution.size(); ++i) { + // delete mResolution[i]; + // } + // mResolution.clear(); + // } public: @@ -258,7 +264,7 @@ class ConditionalInterpreter ConditionalInterpreter(const std::string key,const std::string ConditionalExpressions); - ~ConditionalInterpreter(){clearRes();} + ~ConditionalInterpreter(){} /** * @brief get the condition key @@ -293,12 +299,12 @@ class ConditionalInterpreter * @param NodeOp The node currently being tested * @param nodes The AST given by the parsing process */ - std::vector<ConditionalData*> visit(const ASTNodeCh& nodes, const NodePtr NodeOp ); + std::vector< std::shared_ptr<ConditionalData>> visit(const ASTNodeCh& nodes, const NodePtr NodeOp ); /** * @defgroup ASTnodeInterpreterF Functions for interpreting AST nodes * @brief For each node type in the AST, function defines the processing to be performed - * they return a std::vector<ConditionalData*> which corresponds to the value(s) obtained + * they return a std::vector< std::shared_ptr<ConditionalData>> which corresponds to the value(s) obtained */ /** @@ -308,38 +314,38 @@ class ConditionalInterpreter void fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); /** * @ingroup ASTnodeInterpreterF - * @brief Converted the lexeme to a int and to ConditionalData* + * @brief Converted the lexeme to a int and to std::shared_ptr<ConditionalData> */ void fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); /** * @ingroup ASTnodeInterpreterF - * @brief Converted the lexeme to a float and to ConditionalData* + * @brief Converted the lexeme to a float and to std::shared_ptr<ConditionalData> */ void fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); /** * @ingroup ASTnodeInterpreterF - * @brief Converted the lexeme to a str and to ConditionalData* + * @brief Converted the lexeme to a str and to std::shared_ptr<ConditionalData> */ void fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node); /** * @ingroup ASTnodeInterpreterF - * @brief makes the == operation between two previously converted ConditionalData* + * @brief makes the == operation between two previously converted std::shared_ptr<ConditionalData> */ void fEq(void); /** * @ingroup ASTnodeInterpreterF - * @brief makes the != operation between two previously converted ConditionalData* + * @brief makes the != operation between two previously converted std::shared_ptr<ConditionalData> */ void fNeq(void); /** * @ingroup ASTnodeInterpreterF - * @brief makes the && operation between two previously converted ConditionalData* in bool + * @brief makes the && operation between two previously converted std::shared_ptr<ConditionalData> in bool */ void fAnd(void); /** * @ingroup ASTnodeInterpreterF - * @brief makes the || operation between two previously converted ConditionalData* in bool + * @brief makes the || operation between two previously converted std::shared_ptr<ConditionalData> in bool */ void fOr(void); diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index 353666fb3950d034a7dbe8ec1d3ebdb312679f95..43dd7beb10b49c3695e6c55fac0449a34565dd7f 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -28,12 +28,12 @@ namespace Aidge { enum class ScalingAttr { - scalingFactor + scalingFactor, quantizedNbBits, isOutputUnsigned }; class Scaling_Op : public Operator, public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>, - public StaticAttributes<ScalingAttr, float> { + public StaticAttributes<ScalingAttr, float, size_t, bool> { public: // FIXME: change accessibility std::shared_ptr<Tensor> mInput = std::make_shared<Tensor>(); @@ -44,16 +44,18 @@ public: Scaling_Op() = delete; - using Attributes_ = StaticAttributes<ScalingAttr, float>; + using Attributes_ = StaticAttributes<ScalingAttr, float, std::size_t, bool>; template <ScalingAttr e> using attr = typename Attributes_::template attr<e>; - Scaling_Op(float scalingFactor) + Scaling_Op(float scalingFactor, std::size_t nbBits, bool isOutputUnsigned) : Operator(Type), Attributes_( - attr<ScalingAttr::scalingFactor>(scalingFactor)) - { - setDatatype(DataType::Float32); - } + attr<ScalingAttr::scalingFactor>(scalingFactor), + attr<ScalingAttr::quantizedNbBits>(nbBits), + attr<ScalingAttr::isOutputUnsigned>(isOutputUnsigned)) { + + setDatatype(DataType::Float32); + } /** * @brief Copy-constructor. Copy the operator attributes and its output tensor(s), but not its input tensors (the new operator has no input associated). @@ -154,15 +156,21 @@ public: } }; +/* inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::string& name = "") { return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name); } +*/ +inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, std::size_t quantizedNbBits=8, bool isOutputUnsigned=true, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor,quantizedNbBits, isOutputUnsigned), name); +} + } namespace { template <> const char* const EnumStrings<Aidge::ScalingAttr>::data[] - = {"scalingFactor"}; + = {"scalingFactor", "quantizedNbBits", "isOutputUnsigned"}; } #endif /* __AIDGE_CORE_OPERATOR_RELU_H__ */ diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 1896894ee8690cedaef696394da0829604e36211..faf6c49bdbe28e7214f06a4d116cf23a1739154f 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -64,6 +64,9 @@ public: std::vector<std::shared_ptr<Node>> getStaticScheduling(){ return mStaticSchedule; } + std::shared_ptr<GraphView> getGraphView(){ + return mGraphView; + } private: /** diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp index c110c9cf8e2ccc84112f7ac48b438f470ee21465..bf6683d342f2b140aef41459bd6633340de3e93d 100644 --- a/include/aidge/utils/Recipies.hpp +++ b/include/aidge/utils/Recipies.hpp @@ -17,6 +17,8 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graphRegex/matchFsm/MatchResult.hpp" + namespace Aidge{ @@ -27,7 +29,12 @@ namespace Aidge{ * * @param nodes Strict set of Node to merge. */ -void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); +//void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); + +void fuseMulAdd(std::shared_ptr<MatchSolution> solution); + +void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add); + /** * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. * @@ -43,7 +50,11 @@ void fuseMulAdd(std::shared_ptr<GraphView> graphView); * * @param nodes Strict set of Node to merge. */ -void removeFlatten(std::set<std::shared_ptr<Node>> nodes); +void removeFlatten(std::shared_ptr<Node> flatten); + + +void removeFlatten(std::shared_ptr<MatchSolution> solution); + /** * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * @@ -59,7 +70,12 @@ void removeFlatten(std::shared_ptr<GraphView> graphView); * * @param nodes Strict set of Node to merge. */ -void fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes); +void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm); + + + +void fuseBatchNorm(std::shared_ptr<MatchSolution> solution); + /** * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ diff --git a/python_binding/graphRegex/pybind_GraphRegex.cpp b/python_binding/graphRegex/pybind_GraphRegex.cpp new file mode 100644 index 0000000000000000000000000000000000000000..be3cd9e9124ba1306226dcbdc13ee39748cf0606 --- /dev/null +++ b/python_binding/graphRegex/pybind_GraphRegex.cpp @@ -0,0 +1,69 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include "aidge/graphRegex/GraphRegex.hpp" + +namespace py = pybind11; +namespace Aidge { +void init_GraphRegex(py::module& m){ + + + py::class_<GraphRegex, std::shared_ptr<GraphRegex>>(m, "GraphRegex", "GraphRegex class describes a regex to test a graph.") + .def(py::init<>()) + + .def("add_query", &GraphRegex::addQuery, R"mydelimiter( + :rtype: str + )mydelimiter") + + .def("set_key_from_graph", &GraphRegex::setKeyFromGraph, R"mydelimiter( + :param ref: The graph use to define type of Node. + :type ref: :py:class:`aidge_core.GraphView` + )mydelimiter") + +// void setNodeKey(const std::string key, const std::string conditionalExpressions ); +// void setNodeKey(const std::string key,std::function<bool(NodePtr)> f); + + .def("match", &GraphRegex::match, R"mydelimiter( + :param graphToMatch: The graph to perform the matching algorithm on. + :type graphToMatch: :py:class:`aidge_core.GraphView` + )mydelimiter") + + + + .def("set_node_key", + (void (GraphRegex::*)(const std::string, const std::string )) & + GraphRegex::setNodeKey, + py::arg("key"), py::arg("conditionalExpressions"), + R"mydelimiter( + Add a node test + :param key: the key of the node test to use in the query. + :param conditionalExpressions: the test to do . + + )mydelimiter") + + + .def("set_node_key", + (void (GraphRegex::*)(const std::string, std::function<bool(NodePtr)>)) & + GraphRegex::setNodeKey, + py::arg("key"), py::arg("f"), + R"mydelimiter( + Add a node test + :param key: the key of the lambda test to use in the conditional expressions. + :param f: bool lambda (nodePtr) . + + )mydelimiter") + + + + ; +} +} diff --git a/python_binding/graphmatching/pybind_GRegex.cpp b/python_binding/graphmatching/pybind_GRegex.cpp deleted file mode 100644 index 48d0e19ff22c1480636b67b5bde70bf1caa1f1b5..0000000000000000000000000000000000000000 --- a/python_binding/graphmatching/pybind_GRegex.cpp +++ /dev/null @@ -1,42 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <pybind11/pybind11.h> -#include <pybind11/stl.h> -#include "aidge/graph/GraphView.hpp" -#include "aidge/graphmatching/GRegex.hpp" - -namespace py = pybind11; -namespace Aidge { -void init_GRegex(py::module& m){ - py::class_<GRegex, std::shared_ptr<GRegex>>(m, "GRegex", "GRegex class combines a Node Regex and a list of Graph Regex that together describes a graph pattern as a graph regular expression. GRegex find patterns in a given graph that matches the graph regular expression.") - .def(py::init<const std::map<std::string,NodeRegex*>&, std::vector<std::string>&>(), py::arg("nodesRegex"), py::arg("seqRegexps"), R"mydelimiter( - Constructor of GRegex - - :param nodesRegex: Describe the conditions an operator has to fulfill. - :type nodesRegex: Dict[str,:py:class:`aidge_core.NodeRegex`] - :param seqRegexps: Describe the graph topological pattern. List of Graph Regex as strings. - :type seqRegexps: List[str] - - )mydelimiter") - .def("match", &GRegex::match, py::arg("graphToMatch"), R"mydelimiter( - Launch the graph matching algorithm on a given graph. - - :param graphToMatch: The graph to perform the matching algorithm on. - :type graphToMatch: :py:class:`aidge_core.GraphView` - - :returns: Matched graph patterns. - :rtype: :py:class:`aidge_core.Match` - - )mydelimiter") - ; -} -} diff --git a/python_binding/graphmatching/pybind_Match.cpp b/python_binding/graphmatching/pybind_Match.cpp deleted file mode 100644 index a2d2654f40ed50e20e8761be57e2c8bb98ce4e3b..0000000000000000000000000000000000000000 --- a/python_binding/graphmatching/pybind_Match.cpp +++ /dev/null @@ -1,34 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <pybind11/pybind11.h> -#include <pybind11/stl.h> -#include "aidge/graphmatching/Match.hpp" - -namespace py = pybind11; -namespace Aidge { -void init_Match(py::module& m){ - py::class_<Match, std::shared_ptr<Match>>(m, "Match", "Match class stores the matched patterns resulting from a graph matching query. A matched pattern is the combinaison of the graph pattern start nodes and the set of all the nodes in the matched pattern (including the start nodes)") - .def(py::init<>()) - .def("get_nb_match", &Match::getNbMatch, R"mydelimiter( - :returns: The number of graph patterns matched - :rtype: int - )mydelimiter") - .def("get_start_nodes", &Match::getStartNodes, R"mydelimiter( - :returns: All matched graph patterns start nodes - :rtype: List[List[:py:class:`aidge_core.Nodes`]] - )mydelimiter") - .def("get_match_nodes", &Match::getMatchNodes, R"mydelimiter( - :returns: All matched graph patterns sets of matched nodes - :rtype: List[Set[:py:class:`aidge_core.Nodes`]] - )mydelimiter"); -} -} diff --git a/python_binding/graphmatching/pybind_NodeRegex.cpp b/python_binding/graphmatching/pybind_NodeRegex.cpp deleted file mode 100644 index 034987f9ccae200a1b8877ecd8b3e878c84e8fc3..0000000000000000000000000000000000000000 --- a/python_binding/graphmatching/pybind_NodeRegex.cpp +++ /dev/null @@ -1,28 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <pybind11/pybind11.h> -#include "aidge/graphmatching/NodeRegex.hpp" - -namespace py = pybind11; -namespace Aidge { -void init_NodeRegex(py::module& m){ - py::class_<NodeRegex, std::shared_ptr<NodeRegex>>(m, "NodeRegex", "NodeRegex class describes a condition to test on any operator. Current version only supports testing the type of the operator.") - .def(py::init<const std::string>(), py::arg("condition"), R"mydelimiter( - Constructor of NodeRegex - - :param condition: Condition to be fulfilled by an operator. - :type condition: str - - )mydelimiter") - ; -} -} diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index 6b535e8cf3293b26aaa64f95ca2f9a394768935f..ef02b8aaef9f4ea3bd97559ad9e94c38c5b1d29e 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -20,13 +20,17 @@ void init_Operator(py::module& m){ py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") .def("output", &Operator::output, py::arg("outputIdx")) .def("input", &Operator::input, py::arg("inputIdx")) + .def("nb_inputs", &Operator::nbInputs) .def("nb_data_inputs", &Operator::nbDataInputs) + .def("nb_outputs", &Operator::nbOutputs) .def("associate_input", &Operator::associateInput, py::arg("inputIdx"), py::arg("data")) .def("set_datatype", &Operator::setDatatype, py::arg("datatype")) .def("set_backend", &Operator::setBackend, py::arg("name")) .def("forward", &Operator::forward) // py::keep_alive forbide Python to garbage collect implementation will the Operator is not garbade collected ! .def("set_impl", &Operator::setImpl, py::arg("implementation"), py::keep_alive<1, 2>()) + .def("get_hook", &Operator::getHook) + .def("add_hook", &Operator::addHook) ; } } diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index a482191c78ff56b000e043cd7350ca1c150d1d6e..6cc597b5ee934e4a3b849d45e92e5cb62be1b312 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -45,9 +45,7 @@ void init_GraphView(py::module&); void init_OpArgs(py::module&); void init_Connector(py::module&); -void init_Match(py::module&); -void init_NodeRegex(py::module&); -void init_GRegex(py::module&); +void init_GraphRegex(py::module&); void init_Recipies(py::module&); @@ -87,9 +85,8 @@ void init_Aidge(py::module& m){ init_Sub(m); init_Producer(m); - init_Match(m); - init_NodeRegex(m); - init_GRegex(m); + init_GraphRegex(m); + init_Recipies(m); init_Scheduler(m); init_TensorUtils(m); diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index 93c131ef7417135bfdbc657c5c809339430616ed..87abf32073734b37803e4330d56888388c63b9af 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -28,12 +28,13 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( - Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + + // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( + // Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. - :param nodes: The MatMul and Add nodes to fuse. - :type nodes: list of :py:class:`aidge_core.Node` - )mydelimiter"); + // :param nodes: The MatMul and Add nodes to fuse. + // :type nodes: list of :py:class:`aidge_core.Node` + // )mydelimiter"); m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( Recipie to remove a flatten operator. @@ -41,18 +42,20 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter( - Recipie to remove a flatten operator. - :param nodes: The flatten operator to remove. - :type nodes: list of :py:class:`aidge_core.Node` - )mydelimiter"); - m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( - Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + // m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter( + // Recipie to remove a flatten operator. - :param nodes: The MatMul and Add nodes to fuse. - :type nodes: list of :py:class:`aidge_core.Node` - )mydelimiter"); + // :param nodes: The flatten operator to remove. + // :type nodes: list of :py:class:`aidge_core.Node` + // )mydelimiter"); + + // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( + // Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. + + // :param nodes: The MatMul and Add nodes to fuse. + // :type nodes: list of :py:class:`aidge_core.Node` + // )mydelimiter"); m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter( Recipie to remove a flatten operator. @@ -60,11 +63,12 @@ void init_Recipies(py::module &m) { :param graph_view: Graph view on which we want to apply the recipie :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter( - Recipie to remove a flatten operator. + + // m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter( + // Recipie to remove a flatten operator. - :param nodes: The flatten operator to remove. - :type nodes: list of :py:class:`aidge_core.Node` - )mydelimiter"); + // :param nodes: The flatten operator to remove. + // :type nodes: list of :py:class:`aidge_core.Node` + // )mydelimiter"); } } // namespace Aidge diff --git a/src/graphmatching/GRegex.cpp b/src/graphmatching/GRegex.cpp deleted file mode 100644 index 6b54c5a476e0319c3fab0751c0528a2084ebc0a7..0000000000000000000000000000000000000000 --- a/src/graphmatching/GRegex.cpp +++ /dev/null @@ -1,301 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graph/GraphView.hpp" - -using namespace Aidge; - -GRegex::GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps ):mStmFab(nodesRegex){ - - - //setup all the STM - for (const std::string& sequRegex : seqRegexps) { - mStmInit.push_back(mStmFab.makeNewStm(sequRegex)); - } - -} - -bool GRegex::walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm){ - //test if all stm type are in a valid state - std::vector<int> number_of_valid; - number_of_valid.resize(all_stm.size()); - - for (std::size_t i = 0; i < all_stm.size(); ++i) { - number_of_valid[i] = 0; - for (auto it = all_stm[i].begin(); it != all_stm[i].end(); ++it) { - SeqStm* stm = *it; - if (stm->isValid()){ - number_of_valid[i] +=1; - } - } - } - - for (std::size_t i = 0; i < number_of_valid.size(); ++i) { - if (number_of_valid[i] == 0) { - //std::cout << "NO MATCH at least one stm are not valid" << std::endl; - return false; - } - if (number_of_valid[i] > 1) { - //std::cout << "NO MATCH multiple brach match of stm (// quantification)" << std::endl; - return false; - } - } - return true; -} - -bool GRegex::walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm){ - std::set<NodeTmp> all_stm_node_tested; - std::set<NodeTmp> all_stm_node_validated; - - for (std::size_t i = 0; i < all_stm.size(); ++i) { - //std::cout << "all stm index " << i << " on dimension 1 of size " << all_stm.size() <<std::endl; - for (std::size_t j = 0; j < all_stm[i].size(); ++j) { - //std::cout << "all stm index " << j << " on dimension 2 of size " << all_stm[i].size() <<std::endl; - - std::set<NodeTmp> stm_node_tested = all_stm[i][j]->getAllNodeTested(); - std::set<NodeTmp> stm_node_validated = all_stm[i][j]->getAllNodeValidated(); - - all_stm_node_tested.insert(stm_node_tested.begin(), stm_node_tested.end()); - all_stm_node_validated.insert(stm_node_validated.begin(), stm_node_validated.end()); - } - } - - - std::set<NodeTmp> test_but_not_valid; - for (const auto& x : all_stm_node_tested) { - if (all_stm_node_validated.find(x) == all_stm_node_validated.end()) { - test_but_not_valid.insert(x); - } - } - - - if (!test_but_not_valid.empty()) { - std::cout << "NO MATCH. The node(s) "; - for (const auto& x : test_but_not_valid) { - std::cout << x.get() << ", "; - } - std::cout << " have been tested but not validated." << std::endl; - return false; - } - return true; - -} - -bool GRegex::walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm){ - std::map<NodeTmp, std::pair<std::string,int>> node_to_common_tag; - for (std::size_t i = 0; i < all_stm.size(); ++i) { - for (auto it = all_stm[i].begin(); it != all_stm[i].end(); ++it) { - SeqStm* stm = *it; - - if (!stm->isValid()){ - continue; - } - - for (const auto& pair : stm->getAllCommonNode()) { - const NodeTmp node = pair.first; - const std::string common_tag = pair.second; - - if (node_to_common_tag.find(node) != node_to_common_tag.end()) { - std::string tag = node_to_common_tag[node].first; - int& occurence = node_to_common_tag[node].second; - if (tag!=common_tag){ - std::cout << "NO MATCH. The node " << node << " have two different tags "<< tag << " and " << common_tag << std::endl; - return false; - } else { - occurence += 1; - } - } else { - node_to_common_tag.insert(std::make_pair(node, std::make_pair(common_tag, 1))); - } - } - } - } - /*std::cout << "Node to common tag "; - for (const auto& x : node_to_common_tag) { - std::cout << "(" << x.first << ", " << "[" << x.second.first << ", " << x.second.second << "]" << ") ; "; - } - std::cout << std::endl;*/ - - - for (const auto& pair : node_to_common_tag) { - const std::pair<std::string, int> tag_occurence_pair = pair.second; - if (tag_occurence_pair.second < 1){ - //std::cout << "NO MATCH. The common tag " << tag_occurence_pair.first << " did not match " << std::endl; - return false; - } - } - - return true; -} - -std::set<NodeTmp> GRegex::get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm){ - std::set<NodeTmp> all_stm_node_validated; - - for (std::size_t i = 0; i < all_stm.size(); ++i) { - for (std::size_t j = 0; j < all_stm[i].size(); ++j) { - std::set<NodeTmp> stm_node_validated = all_stm[i][j]->getAllNodeValidated(); - all_stm_node_validated.insert(stm_node_validated.begin(), stm_node_validated.end()); - } - } - return all_stm_node_validated; -} - - -std::set<NodeTmp> GRegex::matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch){ - std::set<NodeTmp> empty_set_return; - //ASSERT - if(startNodes.size() != mStmInit.size()){ - throw std::runtime_error ("bad GRegex start nodes"); - } - - //init the walk - std::vector<std::vector<SeqStm*>> allStm; - std::vector<std::pair<NodeTmp,SeqStm*>> currentWalk; - - for (SeqStm* seqStmPtr : mStmInit) { - SeqStm* newStm = mStmFab.duplicateStm(seqStmPtr); - std::size_t idxStart = newStm->getStmIdx(); - currentWalk.push_back(std::make_pair(startNodes[idxStart],newStm)); - allStm.push_back(std::vector<SeqStm*>()); - } - - //walk - while (currentWalk.size()!=0) - { - std::vector<std::pair<NodeTmp,SeqStm*>> newWalk; - for (const auto& pair : currentWalk) { - const NodeTmp node = pair.first; - SeqStm* stmPtr = pair.second; - - std::pair<int,std::string> test = stmPtr->testNode(node); - int res = test.first; - std::string commonTag = test.second; - - std::set<NodeTmp> next_nodes = graphToMatch->getChildren(node); - - /*std::cout << "Next nodes : " ; - for (const auto& x : next_nodes) { - std::cout << x->name() << ", "; - } - std::cout << std::endl;*/ - - // Test Match - if (commonTag == "" && next_nodes.size() > 1) { - std::cout << "NO MATCH. The node " << node.get() << " is not common and has more than one child" << std::endl; - return empty_set_return; - } - - // If there is no more nodes --> Archive the branch - if (res == -1 || next_nodes.empty()) { - int indexToInsert = stmPtr->getStmIdx(); - allStm[indexToInsert].push_back(stmPtr); - //std::cout << "No more nodes --> STM archived : " << indexToInsert << std::endl; - continue; // TODEV : replace this with 'else' that encapsulate the rest of the function ? - } - - bool first = true; - - // Use an iterator to read through the next_nodes - std::set<NodeTmp>::iterator it; - for (it = next_nodes.begin(); it != next_nodes.end(); ++it) { - // Access the current element using the iterator - std::shared_ptr<Aidge::Node> next_node = *it; - if (first){ - newWalk.push_back(std::make_pair(next_node, stmPtr)); - first = false; - } else { - SeqStm* new_stmPtr = mStmFab.duplicateStm(stmPtr); - newWalk.push_back(std::make_pair(next_node, new_stmPtr)); - } - } - } - currentWalk = newWalk; - } - - //std::cout << "Walk finished" << std::endl; - - if (!walk_validation_all_stm_are_valid(allStm)){ - return empty_set_return; - } - //std::cout << "walk_validation_all_stm_are_valid finished" << std::endl; - - - if (!walk_validation_all_node_read_validate_by_one_stm(allStm)){ - return empty_set_return; - } - //std::cout << "walk_validation_all_node_read_validate_by_one_stm finished" << std::endl; - - - if (!walk_validation_common_nodes_same_tag_for_all_stm(allStm)){ - return empty_set_return; - } - //std::cout << "walk_validation_common_nodes_same_tag_for_all_stm finished" << std::endl; - - //std::cout << "MATCH" << std::endl; - - return get_all_validate_nodes(allStm); - -} - - - -Match GRegex::match(const std::shared_ptr<GraphView> graphToMatch){ - - //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matches; - //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matches; - Match matches; - std::size_t nbStartNodes = mStmInit.size(); - std::set<NodeTmp> allNodes = graphToMatch->getNodes(); - std::size_t nbAllNodes = allNodes.size(); - - std::vector<std::size_t> indices(nbStartNodes, 0); - - while (true) { - // Generate all permutations of the current combination - do { - std::vector<NodeTmp> startNodes; - //std::cout <<"start nodes :"; - for (std::size_t i = 0; i < nbStartNodes; ++i) { - auto it = std::begin(allNodes); - std::advance(it, indices[i]); - //std::cout << (*it).get() << " "; - startNodes.push_back(*it); - } - //std::cout <<"\n"; - - std::set<NodeTmp> match = matchFromStartNodes(startNodes, graphToMatch); - //std::cout << "match size : " << match.size() << " "; - if(match.size() != 0){ - //matches.push_back(std::make_pair(startNodes,match)); - //matches.insert(std::make_pair(startNodes,match)); - matches.insert(startNodes,match); - } - - } while (std::next_permutation(indices.begin(), indices.end())); - - // Generate the next combination with replacement - std::size_t i = nbStartNodes - 1; - while (true) { - if (indices[i] < nbAllNodes - 1) { - ++indices[i]; - break; - } - if (i == 0) { - return matches; - } - --i; - } - std::fill(indices.begin() + i + 1, indices.end(), indices[i]); - } - - return matches; -} \ No newline at end of file diff --git a/src/graphmatching/Match.cpp b/src/graphmatching/Match.cpp deleted file mode 100644 index 6c08b30b11ab220310b476bab2c6d17ed86e4fd1..0000000000000000000000000000000000000000 --- a/src/graphmatching/Match.cpp +++ /dev/null @@ -1,37 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/Match.hpp" - -using namespace Aidge; - -Match::Match(){ - //ctr -} - -size_t Match::getNbMatch(){ - assert(mStartNodes.size() == mMatchNodes.size() && "Match corrupted"); - return mStartNodes.size(); -} - -void Match::insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes){ - assert(mStartNodes.size() == mMatchNodes.size() && "Match corrupted"); - mStartNodes.push_back(startnodes); - mMatchNodes.push_back(matchnodes); -} - -std::vector<std::vector<NodeTmp>> Match::getStartNodes(){ - return mStartNodes; -} - -std::vector<std::set<NodeTmp>> Match::getMatchNodes(){ - return mMatchNodes; -} \ No newline at end of file diff --git a/src/graphmatching/NodeRegex.cpp b/src/graphmatching/NodeRegex.cpp deleted file mode 100644 index 9bf164f60255c17492e528b0f27dec8c53f74979..0000000000000000000000000000000000000000 --- a/src/graphmatching/NodeRegex.cpp +++ /dev/null @@ -1,46 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/NodeRegex.hpp" - - -// Verification done by the Attribute system - - -// Version 1 - Only test the type of the node (no need for a lexer) -// Input : Node_op -// Output : bool -// return mCondition == Node_op.type -bool Aidge::NodeRegex::_is(std::shared_ptr<Node> &Node_op){ - - std::string NodeType = Node_op->type(); - - return strcmp(NodeType.c_str(), mCondition.c_str()) == 0; -} - - -bool Aidge::NodeRegex::isA(std::string NodeType){ - - return strcmp(NodeType.c_str(), mCondition.c_str()) == 0; -} - -// Version 2 - Test the node to an advanced condition -// Input : Node_op -// Output : bool -// return mCondition applied on Node -/**bool NodeRegex::_is(string &Node_op){ - // Parsing the condition is done in the initialization of the NodeRegex - - // assert attributes exist in the node with the attribute function hasAttr() - - // get the attributes - -}*/ diff --git a/src/graphmatching/SeqStm.cpp b/src/graphmatching/SeqStm.cpp deleted file mode 100755 index 84553cb44cb898535943b31b8c955378e73ccbd5..0000000000000000000000000000000000000000 --- a/src/graphmatching/SeqStm.cpp +++ /dev/null @@ -1,247 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/SeqStm.hpp" - -using namespace Aidge; - - - - - /////////////////////////////////////////////////////// - - SeqStm::SeqStm( - const int stmIdx, - const std::vector<std::vector<int>>& transitionMatrix, - const std::map<std::string,NodeRegex*>& nodesRegex, - const std::map<NodeTypeKey,int>& typeToIdxTransition, - int actSt, - std::set<NodeTmp> allNodeValidated, - std::set<NodeTmp> allNodeTested, - std::set<std::pair<NodeTmp,std::string>> allCommonNode, - bool stmIsValid):mStmIdx(stmIdx), - mTransitionMatrix(transitionMatrix), - mNodesRegex(nodesRegex), - mTypeToIdxTransition(typeToIdxTransition) - { - - //assert - if (transitionMatrix.size() == 0){ - throw std::runtime_error ("no transitionMatrix"); - } - if(transitionMatrix[0].size() == 0 || transitionMatrix[0].size() != typeToIdxTransition.size()){ - throw std::runtime_error ("bad transitionMatrix"); - } - int size = static_cast<int>(transitionMatrix.size()); - if (actSt >= size){ - throw std::runtime_error ("bad actSt"); - } - - - mActSt = actSt; - mAllNodeValidated = allNodeValidated; - mAllNodeTested = allNodeTested; - mAllCommonNode = allCommonNode; - mStmIsValid = stmIsValid; - - } - - SeqStm* SeqStm::duplicateStm(){ - - //deep copy of the set - // std::set<Node> cAllNodeValidated(mAllNodeValidated.begin(), mAllNodeValidated.end()); - // std::set<Node> cAllNodeTested(mAllNodeTested.begin(), mAllNodeTested.end()); - - // std::set<std::pair<Node,std::string>> cAllCommonNode; - // for (const auto& p : mAllCommonNode) { - // cAllCommonNode.insert(p); - // } - - auto newStm = new SeqStm( - mStmIdx, - mTransitionMatrix, - mNodesRegex, - mTypeToIdxTransition, - mActSt, - mAllNodeValidated, - mAllNodeTested, - mAllCommonNode, - mStmIsValid - ); - - return newStm; - } - - - std::pair<NodeRegex*,std::string> SeqStm::getNodeRegexAndCommonAt(int idxType) - { - //std::cout << "!" << idxType << "\n"; - for (auto const& x : mTypeToIdxTransition) - { - //x.second is the value : idx in mTransitionMatrix for the type - //x.first pair of the node regex class and a string that is the common tag '',#,#n - if (x.second == idxType ){ - - if (mNodesRegex.find(x.first.first) != mNodesRegex.end()){ - return std::make_pair(mNodesRegex.find(x.first.first)->second, x.first.second); - }else{ - throw std::runtime_error ("a type is not define in NodesRegex"); - } - } - } - throw std::runtime_error ("bad idx in mNodesRegex"); - return std::make_pair(nullptr,nullptr); - } - - - NodeType SeqStm::getTheNodeType(NodeTmp node) - { - //the node is a str of '{type}{idx}' and we juste want type - // // std::regex re("([a-zA-Z]+)[0-9]+"); - // // std::smatch match; - // // if (std::regex_search(node, match, re) == true) { - // // return match.str(1); - // // } - // // throw std::runtime_error ("Type node not found"); - // // return ""; - - //return node->name(); - return node->type(); - } - - - std::string SeqStm::transitionOnNodeType(NodeType nodeType){ - - if (!isStmBlocked()){ - int idxType = 0; - for (auto & nextSt : mTransitionMatrix[mActSt]) { - // There are a next step for this type - //std::cout << "transition matrix next state -> "<< nextSt<<"\n" ; - if (nextSt != -1){ - //std::cout << "next -> "<< nextSt<< " "<< isAValidSt(nextSt) <<"\n" ; - auto nodeRegex = getNodeRegexAndCommonAt(idxType); - //std::cout << "-> "<< nodeRegex.second<<"\n" ; - if (nodeRegex.first->isA(nodeType)){ - //std::cout << "nodetype tested !"<<"\n" ; - if(isAValidSt(nextSt)){ - //std::cout << "Valid state !"<<"\n" ; - mStmIsValid = true; - } - mActSt = nextSt; - return nodeRegex.second; - } - - } - idxType += 1; - } - - mActSt =-1; - } - - return ""; - } - - - std::pair<int,std::string> SeqStm::testNode(const NodeTmp node){ - - std::string commonTag = ""; - //std::cout << "0\n" ; - if (!isStmBlocked()){ - bool isNextStEnd = std::all_of(mTransitionMatrix[mActSt].begin(), mTransitionMatrix[mActSt].end(), [&](int x){ return x == -1; }); - //std::cout << "1:"<< isNextStEnd <<"\n" ; - //if the next state if full of -1 can we relay add the node test to all node tested - // oker y test it but it sure that not be valid - if(!isNextStEnd){ - mAllNodeTested.insert(node); - } - //std::cout << "2\n" ; - //recurtion avoidance - if(mAllNodeValidated.find(node) == mAllNodeValidated.end()){ - - NodeType nodeType = getTheNodeType(node); - //std::cout << "3 " << nodeType << "\n" ; - commonTag = transitionOnNodeType(nodeType); - //after the transition test, if the node is != -1 the node is valid for the stm - //std::cout << " mActSt = " << mActSt << "\n" ; - if( mActSt != -1 ){ - mAllNodeValidated.insert(node); - } - }else{ - mActSt = -1; - } - } - - if(commonTag != ""){ - mAllCommonNode.insert(std::make_pair(node,commonTag)); - } - return std::make_pair(mActSt,commonTag); - } - - -void SeqStm::drawStm(){ - - //mTransitionMatrix - // Find the maximum width of each column - std::vector<std::size_t> max_widths(mTransitionMatrix[0].size(), 0); - for (std::size_t i = 0; i < mTransitionMatrix.size(); ++i) - { - for (std::size_t j = 0; j < mTransitionMatrix[i].size(); ++j) - { - std::size_t width = std::to_string(mTransitionMatrix[i][j]).length(); - if (width > max_widths[j]) - { - max_widths[j] = width; - } - } - } - - // Print the vector with aligned columns - for (std::size_t i = 0; i < mTransitionMatrix.size(); ++i) - { - for (std::size_t j = 0; j < mTransitionMatrix[i].size(); ++j) - { - int i_int = static_cast<int>(i); - if (mActSt == -1 ){ - if(mStmIsValid){ - std::cout << "\033[48;5;40m"; - }else{ - std::cout << "\033[48;5;9m"; - } - } - else if (mActSt == i_int){ - std::cout << "\033[48;5;30m"; - }else{ - std::cout << "\033[48;5;27m"; - } - - // Pad the value with spaces to align it with the maximum width - std::size_t width = std::to_string(mTransitionMatrix[i][j]).length(); - std::string padding(max_widths[j] - width, ' '); - std::cout << padding << mTransitionMatrix[i][j] << " "; - std::cout << "\033[0m"; - } - std::cout << "\n"; - } - - std::cout << "mAllNodeTested : "; - for (const auto& x : mAllNodeTested) { - std::cout << x << ", "; - } - std::cout << "\n"; - - - std::cout << "mAllNodeValidated : "; - for (const auto& x : mAllNodeValidated) { - std::cout << x << ", "; - } - std::cout << "\n"; -} - diff --git a/src/graphmatching/StmFactory.cpp b/src/graphmatching/StmFactory.cpp deleted file mode 100644 index 30b1fad81fc9e7f97dab03f7e6d091a27eeec32b..0000000000000000000000000000000000000000 --- a/src/graphmatching/StmFactory.cpp +++ /dev/null @@ -1,150 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include "aidge/graphmatching/StmFactory.hpp" - -using namespace Aidge; - -StmFactory::StmFactory(const std::map<std::string, NodeRegex *> &nodesRegex) - : mNodesRegex(nodesRegex) {} - -SeqStm *StmFactory::duplicateStm(SeqStm *stm) { return stm->duplicateStm(); } - -SeqStm *StmFactory::makeNewStm(const std::string &sequRegex) { - - ParsingReturn parsing = initParsingSequRegex(sequRegex); - std::vector<std::vector<int>> transitionMatrix = - initTransitionMatrix(parsing); - - std::set<NodeTmp> allNodeValidated; - std::set<NodeTmp> allNodeTested; - std::set<std::pair<NodeTmp, std::string>> allCommonNode; - - SeqStm *newStm = new SeqStm(static_cast<int>(mCmptStm), transitionMatrix, mNodesRegex, - parsing.typeToIdxTransition, 0, allNodeValidated, - allNodeTested, allCommonNode, false); - mCmptStm += 1; - - return newStm; -} - -ParsingReturn StmFactory::initParsingSequRegex(const std::string &sequRegex) { - - std::string toMatch; - std::regex re("\\s*([A-Za-z]+)(#\\d*)?([+*])?\\s*(->|;)"); - std::smatch matches; - - int idxType = 0; - // return - ParsingReturn parsing; - // std::map<std::pair<NodeType,std::string>,int> typeToIdxTransition; - // std::vector<std::pair<std::pair<NodeType,std::string>,std::string>> - // transition; - // assert - std::map<NodeType, std::string> assertCommonNodeTypes; - - for (std::size_t i = 0; i < sequRegex.length(); i++) { - toMatch += sequRegex[i]; - if (std::regex_match(toMatch, matches, re)) { - - std::string type = matches.str(1); - std::string commonTag = matches.str(2); - std::string quantification = matches.str(3); - - if ((commonTag != "") && (quantification != "")) { - throw std::runtime_error("bad commonTag and quantification"); - } - - // make the typeToIdxTransition - NodeTypeKey typeTag = std::make_pair(type, commonTag); - /*std::cout << " typeTag: " << type << " " << commonTag - << parsing.typeToIdxTransition.size() << std::endl;*/ - if (parsing.typeToIdxTransition.find(typeTag) == - parsing.typeToIdxTransition.end()) { - parsing.typeToIdxTransition[typeTag] = idxType; - idxType += 1; - } - //////////////////////////////////////////////////////////// - // ASSERT - // SAME Common node in the sequ - if (commonTag != "") { - if (assertCommonNodeTypes.find(type) != assertCommonNodeTypes.end()) { - if (assertCommonNodeTypes[type] == commonTag) { - throw std::runtime_error("same common node in the sequ regex"); - } - } else { - assertCommonNodeTypes[type] = commonTag; - } - } - - // save all transition - parsing.transition.push_back(std::make_pair(typeTag, quantification)); - - /*std::cout << "Match found: " << matches.str() << std::endl; - std::cout << "Type: " << matches.str(1) << std::endl; - std::cout << "Common tag: " << matches.str(2) << std::endl; - std::cout << "Quantification: " << matches.str(3) << std::endl;*/ - - toMatch = ""; - } - } - if (parsing.transition.size() == 0) { - throw std::runtime_error("Bad Parsing SequRegex "); - } - - return parsing; -} - -std::vector<std::vector<int>> -StmFactory::initTransitionMatrix(ParsingReturn &parsing) { - - // std::pair<NodeTypeKey,std::string> - std::vector<std::vector<int>> transitionMatrix; - std::size_t numberOfType = parsing.typeToIdxTransition.size(); - - if (numberOfType == 0) { - throw std::runtime_error("Bad number Of Type "); - } - // init start st - transitionMatrix.push_back(std::vector<int>(numberOfType, -1)); - - std::size_t idxTransition = 0; - int idxState = 0; - for (const auto &pair : parsing.transition) { - const NodeTypeKey &nodeTypeKey = pair.first; - const std::string &quant = pair.second; - - /*std::cout << "Key: {" << nodeTypeKey.first << ", " << nodeTypeKey.second - << "}, Value: " << quant << std::endl; - std::cout << "idxState " << idxState << " TM: " << transitionMatrix.size() - << std::endl;*/ - std::size_t idxType = parsing.typeToIdxTransition[nodeTypeKey]; - /*std::cout << "idxType " << idxType << " TM: " << transitionMatrix[0].size() - << "type" << numberOfType << std::endl;*/ - - if (quant == "*") { - transitionMatrix[idxTransition][idxType] = idxState; - } else if (quant == "+") { - idxState += 1; - transitionMatrix[idxTransition][idxType] = idxState; - transitionMatrix.push_back(std::vector<int>(numberOfType, -1)); - idxTransition += 1; - transitionMatrix[idxTransition][idxType] = idxState; - } else { - - idxState += 1; - transitionMatrix[idxTransition][idxType] = idxState; - transitionMatrix.push_back(std::vector<int>(numberOfType, -1)); - idxTransition += 1; - } - } - return transitionMatrix; -} \ No newline at end of file diff --git a/src/nodeTester/ConditionalInterpreter.cpp b/src/nodeTester/ConditionalInterpreter.cpp index 59515d0acd77a6202e698ca1e8f1bb28b105266c..f40e62305334f740057f88ef21cdab749d64bd99 100644 --- a/src/nodeTester/ConditionalInterpreter.cpp +++ b/src/nodeTester/ConditionalInterpreter.cpp @@ -8,7 +8,7 @@ using namespace Aidge; //ConditionalRegisterFunction /////////////////////////////// - ConditionalData* ConditionalRegisterFunction::run(const std::string key,std::vector<ConditionalData*> & datas){ + std::shared_ptr<ConditionalData> ConditionalRegisterFunction::run(const std::string key,std::vector< std::shared_ptr<ConditionalData>> & datas){ auto lambdaIt = mWlambda.find(key); if (lambdaIt != mWlambda.end()) { @@ -45,10 +45,9 @@ using namespace Aidge; bool ConditionalInterpreter::test( const NodePtr nodeOp) { - - clearRes(); + mResolution.clear(); try{ - std::vector<ConditionalData*> r = visit({mTree},nodeOp); + std::vector< std::shared_ptr<ConditionalData>> r = visit({mTree},nodeOp); if (mResolution.size() != 1){ throw std::runtime_error("Multi output interpretation output"); @@ -72,8 +71,8 @@ using namespace Aidge; } ///// - std::vector<ConditionalData*> ConditionalInterpreter::visit(const ASTNodeCh& nodes, const NodePtr nodeOp ){ - std::vector<ConditionalData*> dataVector; + std::vector< std::shared_ptr<ConditionalData>> ConditionalInterpreter::visit(const ASTNodeCh& nodes, const NodePtr nodeOp ){ + std::vector< std::shared_ptr<ConditionalData>> dataVector; for ( std::shared_ptr<AstNode<ConditionalTokenTypes>> node : nodes) { try{ @@ -140,7 +139,7 @@ using namespace Aidge; case ConditionalTokenTypes::NODE: //TODO { - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); data->setValue<NodePtr>(nodeOp); mResolution.push_back(data); @@ -157,7 +156,7 @@ using namespace Aidge; case ConditionalTokenTypes::BOOL: //TODO { - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); if(node->getValue() == "true"){ data->setValue<bool>(true); @@ -195,7 +194,8 @@ using namespace Aidge; void ConditionalInterpreter::fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) { - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); + data->setValue<int>(std::stoi(node->getValue())); mResolution.push_back(data); } @@ -203,14 +203,14 @@ using namespace Aidge; void ConditionalInterpreter::fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) { - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); data->setValue<float>(std::stof(node->getValue())); mResolution.push_back(data); } void ConditionalInterpreter::fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) { - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); data->setValue<std::string>(node->getValue()); mResolution.push_back(data); } @@ -218,7 +218,7 @@ using namespace Aidge; void ConditionalInterpreter::fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node) { //if the lambda have input - ConditionalData* data; + std::shared_ptr<ConditionalData> data; try { data = mLambdaRegister.run(node->getValue(),mResolution); } catch (const std::exception& e) { @@ -227,17 +227,20 @@ using namespace Aidge; throw std::runtime_error(errorMessage.str()); } - clearRes(); + //clearRes(); mResolution.push_back(data); } void ConditionalInterpreter::fEq(void) { - if (mResolution.size() != 2){ + if (mResolution.size() < 2){ throw std::runtime_error("EQ need 2 arg and get :" + std::to_string(mResolution.size())); } - auto a = mResolution[0]; - auto b = mResolution[1]; + auto a = mResolution.back(); + mResolution.pop_back(); + auto b = mResolution.back(); + mResolution.pop_back(); + if (a->getType() != b->getType()){ throw std::runtime_error("EQ Unsupported between type :" + a->getType() +" "+ b->getType()); @@ -245,7 +248,7 @@ using namespace Aidge; - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); if (a->isTypeEqualTo<int>()) { data->setValue<bool>( a->getValue<int>() == b->getValue<int>()); @@ -259,23 +262,25 @@ using namespace Aidge; throw std::runtime_error("EQ Unknown type encountered :" + a->getType() ); } - clearRes(); + mResolution.push_back(data); } void ConditionalInterpreter::fNeq(void) { - if (mResolution.size() != 2){ + if (mResolution.size() < 2){ throw std::runtime_error("NEQ need 2 arg and get :" + std::to_string(mResolution.size())); } - auto a = mResolution[0]; - auto b = mResolution[1]; + auto a = mResolution.back(); + mResolution.pop_back(); + auto b = mResolution.back(); + mResolution.pop_back(); if (a->getType() != b->getType()){ throw std::runtime_error("NEQ Unsupported between type :" + a->getType() +" "+ b->getType()); } - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); if (a->isTypeEqualTo<int>()) { data->setValue<bool>( a->getValue<int>() != b->getValue<int>()); @@ -288,67 +293,72 @@ using namespace Aidge; throw std::runtime_error("NEQ Unknown type encountered :" + a->getType() ); } - clearRes(); + mResolution.push_back(data); } void ConditionalInterpreter::fAnd(void) { - if (mResolution.size() != 2){ + if (mResolution.size() < 2){ throw std::runtime_error("AND need 2 arg and get :" + std::to_string(mResolution.size())); } - auto a = mResolution[0]; - auto b = mResolution[1]; + auto a = mResolution.back(); + mResolution.pop_back(); + auto b = mResolution.back(); + mResolution.pop_back(); if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){ throw std::runtime_error("AND Unknown type encountered need bool get :" + a->getType() ); } - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); data->setValue<bool>( a->getValue<bool>() && b->getValue<bool>()); - clearRes(); + mResolution.push_back(data); } void ConditionalInterpreter::fOr(void) { - if (mResolution.size() != 2){ + if (mResolution.size() < 2){ throw std::runtime_error("OR need 2 arg and get :" + std::to_string(mResolution.size())); } - auto a = mResolution[0]; - auto b = mResolution[1]; + auto a = mResolution.back(); + mResolution.pop_back(); + auto b = mResolution.back(); + mResolution.pop_back(); if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){ throw std::runtime_error("OR Unknown type encountered need bool get :" + a->getType() ); } - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); data->setValue<bool>( a->getValue<bool>() || b->getValue<bool>()); - clearRes(); + mResolution.push_back(data); } void ConditionalInterpreter::fNot() { - if (mResolution.size() != 1){ + if (mResolution.size() < 1){ throw std::runtime_error("NOT need 1 arg and get :" + std::to_string(mResolution.size())); } - auto a = mResolution[0]; + auto a = mResolution.back(); + mResolution.pop_back(); if (a->getType() != typeid(bool).name()){ throw std::runtime_error("NOT Unknown type encountered need bool get :" + a->getType() ); } - ConditionalData* data = new ConditionalData; + std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>(); data->setValue<bool>( !a->getValue<bool>() ); - clearRes(); + mResolution.push_back(data); } diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp index 4b2f7a811c022ee80eec98548049853d56951edb..6e345a6474821230f95900cc20cba501feabd1d9 100644 --- a/src/recipies/FuseBatchNorm.cpp +++ b/src/recipies/FuseBatchNorm.cpp @@ -21,28 +21,17 @@ #include "aidge/graph/Node.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/GenericOperator.hpp" -// Graph Regex -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" + + +//Graph Regex +#include "aidge/graphRegex/GraphRegex.hpp" + using namespace Aidge; -void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){ +void Aidge::fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm){ + - assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); - // Assert the nodes types are correct to be fused - std::shared_ptr<Node> conv; - std::shared_ptr<Node> batchnorm; - for (const auto& element : nodes) { - assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace"); - if (element->type() == "Conv"){ - conv = element; - } - else if (element->type() == "BatchNorm") { - batchnorm = element; - } - } - // TODO : check if batchnorm is the only child of the Conv or FC std::shared_ptr<Tensor> scale = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second); std::shared_ptr<Tensor> shift = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second); std::shared_ptr<Tensor> b_mean = batchnorm->input(3).first->getOperator()->getOutput(batchnorm->input(3).second); @@ -127,19 +116,32 @@ void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes){ } +void Aidge::fuseBatchNorm(std::shared_ptr<MatchSolution> solution){ + + assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n"); + assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n"); + + for (const auto& op : solution->at("OP")) { + for (const auto& batchNorm : solution->at("BatchNorm")) { + fuseBatchNorm(op,batchNorm); + } + } + +} + void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView){ - std::map<std::string,NodeRegex*> nodesRegex ; - nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm"); - nodesRegex["Conv"] = new NodeRegex("Conv"); - nodesRegex["FC"] = new NodeRegex("FC"); - - - std::vector<std::string> seqRegex; - seqRegex.push_back("Conv -> BatchNorm;"); // TODO: Add (Conv | FC) - GRegex GReg(nodesRegex, seqRegex); - Match matches = GReg.match(graphView); - std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); - for (size_t i = 0; i < matches.getNbMatch(); ++i) { - fuseBatchNorm(matchNodes[i]); + + + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); + regex->setNodeKey("BatchNorm","getType($) =='BatchNorm'"); + regex->setNodeKey("OP","getType($) =='Conv'");// || getType($) =='FC' "); + + regex->addQuery("OP -> BatchNorm"); + + for (const auto& solution : regex->match(graphView)) { + + fuseBatchNorm(solution); + } + } diff --git a/src/recipies/FuseMulAdd.cpp b/src/recipies/FuseMulAdd.cpp index 528d57e31a5ecf3f5a633a20205e79f7926a1f61..df0fb5eff2febc93edee1719939dfcfde1bc210a 100644 --- a/src/recipies/FuseMulAdd.cpp +++ b/src/recipies/FuseMulAdd.cpp @@ -22,30 +22,17 @@ #include "aidge/operator/GenericOperator.hpp" #include "aidge/utils/ErrorHandling.hpp" -// Graph Regex -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" +//Graph Regex +#include "aidge/graphRegex/GraphRegex.hpp" + using namespace Aidge; -void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ +void Aidge::fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add){//std::set<std::shared_ptr<Node>> nodes){ // Fuse Mulmat & Add into FC // Inputs : old nodes (pointers on mul & add) - assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); - // Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ? + assert((matmul->type() == "MatMul" && add->type() == "Add") && "Wrong type for the nodes to replace"); - // Step 0 : Assert the nodes types are correct to be fused - std::shared_ptr<Node> add; - std::shared_ptr<Node> matmul; - for (const auto& element : nodes) { - assert((element->type() == "MatMul" || element->type() == "Add") && "Wrong type for the nodes to replace"); - if (element->type() == "MatMul"){ - matmul = element; - } - else if (element->type() == "Add") { - add = element; - } - } // Step 1 : Create FC // Fetch the output dimension throught the bias size @@ -78,17 +65,35 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ } + +void Aidge::fuseMulAdd(std::shared_ptr<MatchSolution> solution){ + + assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n"); + assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n"); + + for (const auto& matmul : solution->at("MatMul")) { + for (const auto& add : solution->at("Add")) { + fuseMulAdd(matmul,add); + } + } +} + + void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){ - std::map<std::string,NodeRegex*> nodesRegex ; - nodesRegex["MatMul"] = new NodeRegex("MatMul"); - nodesRegex["Add"] = new NodeRegex("Add"); - std::vector<std::string> seqRegex; - seqRegex.push_back("MatMul -> Add;"); - GRegex GReg(nodesRegex, seqRegex); - Match matches = GReg.match(graphView); - std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); - for (size_t i = 0; i < matches.getNbMatch(); ++i) { - fuseMulAdd(matchNodes[i]); + + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); + regex->setNodeKey("Add","getType($) =='Add'"); + regex->setNodeKey("MatMul","getType($) =='MatMul'"); + regex->addQuery("MatMul -> Add ;"); + + for (const auto& solution : regex->match(graphView)) { + + fuseMulAdd(solution); + + + } + + } diff --git a/src/recipies/RemoveFlatten.cpp b/src/recipies/RemoveFlatten.cpp index fdfdbfd4aea7543dde31d5f5d4845e54e930feac..0dc8d856f88f1fbf7d530338072aa5b34007caaf 100644 --- a/src/recipies/RemoveFlatten.cpp +++ b/src/recipies/RemoveFlatten.cpp @@ -15,36 +15,41 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/utils/Recipies.hpp" -// Graph Regex -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" +//Graph Regex +#include "aidge/graphRegex/GraphRegex.hpp" -namespace Aidge { - void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { - assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); - std::shared_ptr<Node> flatten; - for (const auto& element : nodes) { - assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace"); - if (element->type() == "Flatten"){ - flatten = element; - } - } +namespace Aidge { + void removeFlatten(std::shared_ptr<Node> flatten) { + GraphView::replace({flatten}, {}); } + void removeFlatten(std::shared_ptr<MatchSolution> solution){ + + assert(solution->at("FC").size() == 1 && "Wrong number of nodes FC to replace\n"); + assert(solution->at("Flatten").size() == 1 && "Wrong number of nodes Flatten to replace\n"); + + for (const auto& flatten : solution->at("Flatten")) { + removeFlatten(flatten); + } + } + + + void removeFlatten(std::shared_ptr<GraphView> graphView){ - std::map<std::string,NodeRegex*> nodesRegex ; - nodesRegex["Flatten"] = new NodeRegex("Flatten"); - nodesRegex["FC"] = new NodeRegex("FC"); - std::vector<std::string> seqRegex; - seqRegex.push_back("Flatten->FC;"); - GRegex GReg(nodesRegex, seqRegex); - Match matches = GReg.match(graphView); - std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); - for (size_t i = 0; i < matches.getNbMatch(); ++i) { - removeFlatten(matchNodes[i]); + + + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); + regex->setNodeKey("Flatten","getType($) =='Flatten'"); + regex->setNodeKey("FC","getType($) =='FC'"); + regex->addQuery("Flatten->FC"); + + for (const auto& solution : regex->match(graphView)) { + removeFlatten(solution); } + + } } diff --git a/unit_tests/graphMatching/Test_GRegex.cpp b/unit_tests/graphMatching/Test_GRegex.cpp deleted file mode 100644 index 2c5907d82e7c5b1d32f1fb38493c7333b68f8731..0000000000000000000000000000000000000000 --- a/unit_tests/graphMatching/Test_GRegex.cpp +++ /dev/null @@ -1,318 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <iostream> -#include <map> -#include <memory> -#include <vector> -#include <utility> -#include <cassert> - -#include <catch2/catch_test_macros.hpp> -//test -#include "aidge/graphmatching/GRegex.hpp" -#include "aidge/graphmatching/StmFactory.hpp" -#include "aidge/graphmatching/SeqStm.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/graphmatching/Match.hpp" -//use -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" -#include "aidge/graph/GraphView.hpp" - -using namespace Aidge; - -TEST_CASE("Create good init GRegex", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("A->B;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - // Perform tests - REQUIRE(GReg.getStmInit().size() == 1); - REQUIRE(GReg.getStmFab().getNumberOfStm() == 1); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - - -TEST_CASE("Function matchFromStartNodes | One Match of Nodes sequence", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"Conv","BN","ReLU"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("Conv->BN->ReLU;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1); - std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1); - std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); - std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); - std::shared_ptr<Node> Random2 = GenericOperator("Random2", 1, 1, 1); - - - g1->add(Conv1); - g1->addChild(BN1, Conv1); - g1->addChild(ReLU1, BN1); - g1->addChild(Random, ReLU1); - //g1->addChild(BN1, Random2); - - std::vector<std::shared_ptr<Node>> startNodes1; - std::set<std::shared_ptr<Node>> result; - - startNodes1.push_back(Conv1); - result = GReg.matchFromStartNodes(startNodes1, g1); - - std::set<std::shared_ptr<Node>> true_result; - true_result.insert(Conv1); - true_result.insert(BN1); - true_result.insert(ReLU1); - - // Perform tests - REQUIRE(result == true_result); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -TEST_CASE("Function matchFromStartNodes | One Match of parallel branches ", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"Add","FC","Conv"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("Add#->Conv;"); - seqRegex.push_back("Add#->FC;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - // Instanciate a graphView - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); - std::shared_ptr<Node> Add1 = GenericOperator("Add", 1, 1, 1); - std::shared_ptr<Node> Conv1 = GenericOperator("Conv", 1, 1, 1); - std::shared_ptr<Node> BN1 = GenericOperator("BN", 1, 1, 1); - std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); - std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1); - std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); - - g1->add(Random0); - g1->addChild(Add1, Random0); - g1->addChild(Conv1, Add1); - g1->addChild(BN1, Conv1); - g1->addChild(ReLU1, BN1); - g1->addChild(FC1, Add1); - g1->addChild(Random, FC1); - - // Test 1 : Find the match - std::vector<std::shared_ptr<Node>> startNodes; - std::set<std::shared_ptr<Node>> result; - - startNodes.push_back(Add1); - startNodes.push_back(Add1); - result = GReg.matchFromStartNodes(startNodes, g1); - - std::set<std::shared_ptr<Node>> true_result; - true_result.insert(Add1); - true_result.insert(Conv1); - true_result.insert(FC1); - - // Test 2 : Return an empty set when the start nodes are wrong - std::vector<std::shared_ptr<Node>> wrong_startNodes; - std::set<std::shared_ptr<Node>> wrong_start_result; - std::set<std::shared_ptr<Node>> empty_result; - - wrong_startNodes.push_back(Random0); - wrong_startNodes.push_back(Random0); - wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1); - - // Perform tests - REQUIRE(result == true_result); - REQUIRE(wrong_start_result == empty_result); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -/* -TEST_CASE("Function matchFromStartNodes | Match a sequence with quantifier ", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"FC"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("FC+;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - - // Instanciate a graphView - std::shared_ptr<GraphView> g1 = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); - std::shared_ptr<Node> FC1 = GenericOperator("FC", 1, 1, 1); - std::shared_ptr<Node> FC2 = GenericOperator("FC", 1, 1, 1); - std::shared_ptr<Node> FC3 = GenericOperator("FC", 1, 1, 1); - std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); - - g1->add(Random0); - g1->addChild(FC1, Random0); - g1->addChild(FC2, FC1); - g1->addChild(FC3, FC2); - g1->addChild(ReLU1, FC3); - - // Test 1 : Find the match - std::vector<std::shared_ptr<Node>> startNodes; - std::set<std::shared_ptr<Node>> result; - - startNodes.push_back(FC1); - result = GReg.matchFromStartNodes(startNodes, g1); - - std::set<std::shared_ptr<Node>> true_result; - true_result.insert(FC1); - true_result.insert(FC2); - true_result.insert(FC3); - - // Test 2 : Return an empty set when the start nodes are wrong - std::vector<std::shared_ptr<Node>> wrong_startNodes; - std::set<std::shared_ptr<Node>> wrong_start_result; - std::set<std::shared_ptr<Node>> empty_result; - - wrong_startNodes.push_back(Random0); - wrong_start_result = GReg.matchFromStartNodes(wrong_startNodes, g1); - - // Perform tests - REQUIRE(result == true_result); - REQUIRE(wrong_start_result == empty_result); -} -*/ - -TEST_CASE("Function match | ALL matches of Nodes sequence", "[GRegex]") { - // init all input for GRegex - // Nodes Regex map : std::map<std::string,NodeRegex*>& nodesRegex - // Sequential Regex vector : std::vector<std::string>& seqRegexps - - // init the Nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"GEMM"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - // init the Sequential Regex vector - std::vector<std::string> seqRegex; - seqRegex.push_back("GEMM;"); - - // Instanciate a GRegex - GRegex GReg(nodesRegex, seqRegex); - - //init the input graph - std::shared_ptr<GraphView> graphToMatch = std::make_shared<GraphView>("TestGraph"); - std::shared_ptr<Node> Random0 = GenericOperator("Random", 1, 1, 1); - std::shared_ptr<Node> GEMM1 = GenericOperator("GEMM", 1, 1, 1); - std::shared_ptr<Node> ReLU1 = GenericOperator("ReLU", 1, 1, 1); - std::shared_ptr<Node> GEMM2 = GenericOperator("GEMM", 1, 1, 1); - std::shared_ptr<Node> GEMM3 = GenericOperator("GEMM", 1, 1, 1); - std::shared_ptr<Node> ReLU2 = GenericOperator("ReLU", 1, 1, 1); - std::shared_ptr<Node> Random = GenericOperator("Random", 1, 1, 1); - - graphToMatch->add(Random0); - graphToMatch->addChild(GEMM1, Random0); - graphToMatch->addChild(ReLU1, GEMM1); - graphToMatch->addChild(GEMM2, ReLU1); - graphToMatch->addChild(GEMM3, GEMM2); - graphToMatch->addChild(ReLU2, GEMM3); - graphToMatch->addChild(Random, ReLU2); - - - //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch); - //std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs = GReg.match(graphToMatch); - Match matches = GReg.match(graphToMatch); - - size_t nb = matches.getNbMatch(); - std::vector<std::vector<NodeTmp>> gm_startnodes = matches.getStartNodes(); - std::vector<std::set<NodeTmp>> gm_matchnodes = matches.getMatchNodes(); - - std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matchs; - - for (size_t i = 0; i < nb; ++i) { - matchs.insert(std::make_pair(gm_startnodes[i], gm_matchnodes[i])); - } - - //std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ; - std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> toMatchs ; - // Carefull : as the assert is on a vector, the Order of match matters - std::vector<NodeTmp> startNode = {GEMM1}; - std::set<NodeTmp> matchNode = {GEMM1}; - //toMatchs.push_back(std::make_pair(startNode,matchNode)); - toMatchs.insert(std::make_pair(startNode,matchNode)); - - std::vector<NodeTmp> startNode2 = {GEMM2}; - std::set<NodeTmp> matchNode2 = {GEMM2}; - //toMatchs.push_back(std::make_pair(startNode2,matchNode2)); - toMatchs.insert(std::make_pair(startNode2,matchNode2)); - - std::vector<NodeTmp> startNode3 = {GEMM3}; - std::set<NodeTmp> matchNode3 = {GEMM3}; - //toMatchs.push_back(std::make_pair(startNode3,matchNode3)); - toMatchs.insert(std::make_pair(startNode3,matchNode3)); - - REQUIRE(matchs == toMatchs); - REQUIRE(nb == 3); -} - - diff --git a/unit_tests/graphMatching/Test_NodeRegex.cpp b/unit_tests/graphMatching/Test_NodeRegex.cpp deleted file mode 100644 index 2866642bf1355f49a451edffec9e1b62c802ae1f..0000000000000000000000000000000000000000 --- a/unit_tests/graphMatching/Test_NodeRegex.cpp +++ /dev/null @@ -1,44 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <iostream> -#include <map> -#include <memory> -#include <cassert> - -#include <catch2/catch_test_macros.hpp> - -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -#include "aidge/operator/GenericOperator.hpp" - - -using namespace Aidge; - -TEST_CASE("Create Noderegex", "[Noderegex]") { - std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("conv"); -} - -TEST_CASE("Test _is function", "[Noderegex]") { - // Create Noderegex with only condition on the name of the Node - // Create several operators to pass into Noderegex _is function - // Assert Noderegex._is(operators) are correct - std::shared_ptr<NodeRegex> nr = std::make_shared<NodeRegex>("Conv"); - - std::shared_ptr<Node> Conv = GenericOperator("Conv", 1, 1, 1); - std::shared_ptr<Node> FC = GenericOperator("FC", 1, 1, 1); - - REQUIRE(nr->_is(Conv) == true); - REQUIRE(nr->_is(FC) == false); - REQUIRE(nr->isA("Conv") == true); - REQUIRE(nr->isA("FC") == false); - -} \ No newline at end of file diff --git a/unit_tests/graphMatching/Test_SeqStm.cpp b/unit_tests/graphMatching/Test_SeqStm.cpp deleted file mode 100644 index db8662e3329abe153d4a0fb2b3c46b950208d6bc..0000000000000000000000000000000000000000 --- a/unit_tests/graphMatching/Test_SeqStm.cpp +++ /dev/null @@ -1,167 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <iostream> -#include <map> -#include <memory> -#include <vector> -#include <utility> -#include <cassert> - -#include <catch2/catch_test_macros.hpp> -//test -#include "aidge/graphmatching/SeqStm.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -//use -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" - -using namespace Aidge; - -TEST_CASE("Create good init SeqStm", "[SeqStm]") { - //init all iniput for SeqStm - - - int stmIdx = 0; - //matrix that in B->C - std::vector<std::vector<int>> transitionMatrix { - { -1, 1, -1 }, - { -1, -1, 2 }, - { -1, -1, -1 } }; - - //std::cout << transitionMatrix.size() << "\n"; - // init the nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - // - - std::map<NodeTypeKey,int> typeToIdxTransition; - std::vector<NodeTypeKey> nodeTypeCommonTag {{"A",""},{"B",""},{"C",""}}; - //init nodeTypeCommonTag - int idx = 0; - for (const NodeTypeKey& key : nodeTypeCommonTag) { - typeToIdxTransition[key] = idx; - idx += 1; - } - - int actSt = 0; - std::set<NodeTmp> allNodeValidated; - std::set<NodeTmp> allNodeTested; - std::set<std::pair<NodeTmp,std::string>> allCommonNode; - bool stmIsValid =false; - - - SeqStm stm( - stmIdx, - transitionMatrix, - nodesRegex, - typeToIdxTransition, - actSt, - allNodeValidated, - allNodeTested, - allCommonNode, - stmIsValid); - - REQUIRE(stm.getStmIdx() == 0); - REQUIRE(stm.isValid() == false); - REQUIRE(stm.getAllCommonNode().size() == 0); - REQUIRE(stm.getAllNodeTested().size() == 0); - REQUIRE(stm.getAllNodeValidated().size() == 0); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -TEST_CASE("Test testNode function", "[SeqStm]") { - - int stmIdx = 0; - std::map<NodeTypeKey,int> typeToIdxTransition; - std::vector<NodeTypeKey> nodeTypeCommonTag {{"A",""},{"B",""},{"C",""}}; - //init nodeTypeCommonTag - int idx = 0; - for (const NodeTypeKey& key : nodeTypeCommonTag) { - typeToIdxTransition[key] = idx; - idx += 1; - } - //matrix that in B->C - std::vector<std::vector<int>> transitionMatrix { - { -1, 1, -1 }, - { -1, -1, 2 }, - { -1, -1, -1 } }; - - //std::cout << transitionMatrix.size() << "\n"; - // init the nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - // - int actSt = 0; - std::set<NodeTmp> allNodeValidated; - std::set<NodeTmp> allNodeTested; - std::set<std::pair<NodeTmp,std::string>> allCommonNode; - bool stmIsValid =false; - - SeqStm stm( - stmIdx, - transitionMatrix, - nodesRegex, - typeToIdxTransition, - actSt, - allNodeValidated, - allNodeTested, - allCommonNode, - stmIsValid); - REQUIRE(stm.getStmIdx() == 0); - //test a node - std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); - std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); - - //set use to test the state of the smt - std::set<NodeTmp> testAllNodeTested; - std::set<NodeTmp> testAllNodeValidated; - - stm.testNode(nodeB); - REQUIRE(stm.isValid() == false); - REQUIRE(stm.getState() == 1); - REQUIRE(stm.isStmBlocked() == false); - testAllNodeTested.insert(nodeB); - testAllNodeValidated.insert(nodeB); - REQUIRE(stm.getAllNodeTested() == testAllNodeTested); - REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); - - - stm.testNode(nodeC); - REQUIRE(stm.isValid() == true); - REQUIRE(stm.getState() == 2); - REQUIRE(stm.isStmBlocked() == false); - testAllNodeTested.insert(nodeC); - testAllNodeValidated.insert(nodeC); - REQUIRE(stm.getAllNodeTested() == testAllNodeTested); - REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); - - stm.testNode(nodeC); - REQUIRE(stm.isValid() == true); - REQUIRE(stm.getState() == -1); - REQUIRE(stm.isStmBlocked() == true); - REQUIRE(stm.getAllNodeTested() == testAllNodeTested); - REQUIRE(stm.getAllNodeValidated() == testAllNodeValidated); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} \ No newline at end of file diff --git a/unit_tests/graphMatching/Test_StmFactory.cpp b/unit_tests/graphMatching/Test_StmFactory.cpp deleted file mode 100644 index 3c66d0fa817cea674de5ab849091290c976e5735..0000000000000000000000000000000000000000 --- a/unit_tests/graphMatching/Test_StmFactory.cpp +++ /dev/null @@ -1,204 +0,0 @@ -/******************************************************************************** - * Copyright (c) 2023 CEA-List - * - * This program and the accompanying materials are made available under the - * terms of the Eclipse Public License 2.0 which is available at - * http://www.eclipse.org/legal/epl-2.0. - * - * SPDX-License-Identifier: EPL-2.0 - * - ********************************************************************************/ - -#include <iostream> -#include <map> -#include <memory> -#include <vector> -#include <utility> -#include <cassert> - -#include <catch2/catch_test_macros.hpp> -//test -#include "aidge/graphmatching/StmFactory.hpp" -#include "aidge/graphmatching/NodeRegex.hpp" -//use -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/operator/GenericOperator.hpp" -#include "aidge/operator/Producer.hpp" - -using namespace Aidge; - -TEST_CASE("Create good init StmFactory", "[StmFactory]") { - // init the nodes Regex map - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - StmFactory stmF(nodesRegex); - REQUIRE(stmF.getNumberOfStm() == 0); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -TEST_CASE("Test in makeNewStm the getStmIdx StmFactory", "[SeqStm]") { - - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - StmFactory stmF(nodesRegex); - - std::string seq1 = "A->B+->A#;"; - SeqStm* stm = stmF.makeNewStm(seq1); - REQUIRE(stm->getStmIdx() == 0); - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getAllCommonNode().size() == 0); - REQUIRE(stm->getAllNodeTested().size() == 0); - REQUIRE(stm->getAllNodeValidated().size() == 0); - - std::string seq2 = "A->B;"; - SeqStm* stm2 = stmF.makeNewStm(seq2); - REQUIRE(stm2->getStmIdx() == 1); - REQUIRE(stm2->isValid() == false); - REQUIRE(stm2->getAllCommonNode().size() == 0); - REQUIRE(stm2->getAllNodeTested().size() == 0); - REQUIRE(stm2->getAllNodeValidated().size() == 0); - - //test the number of stm - REQUIRE(stmF.getNumberOfStm() == 2); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - -TEST_CASE("Test in makeNewStm the stm StmFactory", "[SeqStm]") { - - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - - StmFactory stmF(nodesRegex); - std::string seq1 = "B->C;"; - SeqStm* stm = stmF.makeNewStm(seq1); - //test the number of stm - REQUIRE(stmF.getNumberOfStm() == 1); - - //std::shared_ptr<Node> nodeB = GenericOperator("B",1,1,1); - //std::shared_ptr<Node> nodeC = GenericiOperator("C",1,1,1); - std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); - std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); - //set use to test the state of the smt - std::set<NodeTmp> testAllNodeTested; - std::set<NodeTmp> testAllNodeValidated; - - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getState() == 0); - REQUIRE(stm->isStmBlocked() == false); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - stm->testNode(nodeB); - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getState() == 1); - REQUIRE(stm->isStmBlocked() == false); - testAllNodeTested.insert(nodeB); - testAllNodeValidated.insert(nodeB); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - - stm->testNode(nodeC); - REQUIRE(stm->isValid() == true); - REQUIRE(stm->getState() == 2); - REQUIRE(stm->isStmBlocked() == false); - testAllNodeTested.insert(nodeC); - testAllNodeValidated.insert(nodeC); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - stm->testNode(nodeC); - REQUIRE(stm->isValid() == true); - REQUIRE(stm->getState() == -1); - REQUIRE(stm->isStmBlocked() == true); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } - -} - -TEST_CASE("Test in duplicateStm StmFactory", "[SeqStm]") { - - std::map<std::string,NodeRegex*> nodesRegex ; - std::vector<std::string> nodeTypeKey {"A","B","C"}; - for (const std::string& key : nodeTypeKey) { - nodesRegex[key] = new NodeRegex(key); - } - - - StmFactory stmF(nodesRegex); - std::string seq1 = "B->C;"; - SeqStm* stm = stmF.makeNewStm(seq1); - SeqStm* stmD = stmF.duplicateStm(stm); - - std::shared_ptr<Node> nodeB = GenericOperator("B", 1, 1, 1); - std::shared_ptr<Node> nodeC = GenericOperator("C", 1, 1, 1); - //set use to test the state of the smt - std::set<NodeTmp> testAllNodeTested; - std::set<NodeTmp> testAllNodeValidated; - - //run the stm - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getState() == 0); - REQUIRE(stm->isStmBlocked() == false); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - stm->testNode(nodeB); - REQUIRE(stm->isValid() == false); - REQUIRE(stm->getState() == 1); - REQUIRE(stm->isStmBlocked() == false); - testAllNodeTested.insert(nodeB); - testAllNodeValidated.insert(nodeB); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - - stm->testNode(nodeC); - REQUIRE(stm->isValid() == true); - REQUIRE(stm->getState() == 2); - REQUIRE(stm->isStmBlocked() == false); - testAllNodeTested.insert(nodeC); - testAllNodeValidated.insert(nodeC); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - stm->testNode(nodeC); - REQUIRE(stm->isValid() == true); - REQUIRE(stm->getState() == -1); - REQUIRE(stm->isStmBlocked() == true); - REQUIRE(stm->getAllNodeTested() == testAllNodeTested); - REQUIRE(stm->getAllNodeValidated() == testAllNodeValidated); - - //check if stmD not move - REQUIRE(stmD->isValid() == false); - REQUIRE(stmD->getState() == 0); - REQUIRE(stmD->isStmBlocked() == false); - REQUIRE(stmD->getAllNodeTested().size() == 0); - REQUIRE(stmD->getAllNodeValidated().size() == 0); - - for (const std::string& key : nodeTypeKey) { - delete nodesRegex[key]; - } -} - diff --git a/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp b/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp index 6143b7e3d8c4331c178afa6267de723cbea7dfdb..ec068358a34567e57c417a664284bd1db76d7a69 100644 --- a/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp +++ b/unit_tests/nodeTester/Test_ConditionalInterpreter.cpp @@ -12,13 +12,38 @@ TEST_CASE("ConditionalInterpreter", "ConditionalInterpreter") { SECTION("custom Lambda") { - const std::string test = " !toto($) == true " ; - ConditionalInterpreter conditionalParser = ConditionalInterpreter("A",test); - conditionalParser.insertLambda("toto",+[](NodePtr NodeOp){return false;}); + + ConditionalInterpreter conditionalParserB = ConditionalInterpreter("A"," bad($) == false "); + ConditionalInterpreter conditionalParserG = ConditionalInterpreter("A"," good($) == true "); + + + conditionalParserB.insertLambda("bad",+[](NodePtr NodeOp){return NodeOp->name() == "ZZ";}); + conditionalParserG.insertLambda("good",+[](NodePtr NodeOp){return NodeOp->name() == "Gop1";}); std::shared_ptr<Node> nodeOp = GenericOperator("conv", 0, 0, 0, "Gop1"); - bool result = conditionalParser.test(nodeOp); - REQUIRE(result == true); + REQUIRE(conditionalParserB.test(nodeOp) == true); + REQUIRE(conditionalParserG.test(nodeOp) == true); + } + + + ConditionalInterpreter conditionalParserT = ConditionalInterpreter("A","isConv($)==true"); + conditionalParserT.insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";}); + std::shared_ptr<Node> zz = GenericOperator("conv", 0, 0, 0, "Gop1"); + conditionalParserT.test(zz); + + SECTION("Lambdas") { + ConditionalInterpreter conditionalParser = ConditionalInterpreter("OP_test","getType($) =='Conv' || getType($) =='FC' "); + + std::shared_ptr<Node> A = GenericOperator("Conv", 0, 0, 0, "A"); + REQUIRE(conditionalParser.test(A) == true); + + std::shared_ptr<Node> B = GenericOperator("FC", 0, 0, 0, "B"); + REQUIRE(conditionalParser.test(B) == true); + + + std::shared_ptr<Node> C = GenericOperator("A", 0, 0, 0, "C"); + conditionalParser.test(C); + REQUIRE(conditionalParser.test(C) == false); } SECTION("syntax error") { diff --git a/unit_tests/recipies/Test_FuseBatchNorm.cpp b/unit_tests/recipies/Test_FuseBatchNorm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..13facefd2979a9b0ca4409ead6972013cb1bc0a8 --- /dev/null +++ b/unit_tests/recipies/Test_FuseBatchNorm.cpp @@ -0,0 +1,70 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +/* +#include <catch2/catch_test_macros.hpp> +#include <set> + + +//#include "aidge/backend/cpu/operator/BatchNormImpl.hpp" +//#include "aidge/backend/cpu/operator/ConvImpl.hpp" + + + +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/graph/OpArgs.hpp" +#include "aidge/operator/BatchNorm.hpp" +#include "aidge/utils/Recipies.hpp" + +//#include "aidge/backend/TensorImpl.hpp" +//#include "aidge/backend/cpu.hpp" +//#include "aidge/" + +#include <cstddef> + + +namespace Aidge { + + + TEST_CASE("[FuseBatchNorm] conv") { + auto g1 = Sequential({ + Producer({16, 3, 224, 224}, "dataProvider"), + Conv(3, 32, {3, 3}, "conv1"), + BatchNorm<2>() + }); + + g1->setDatatype(DataType::Float32); + g1->setBackend("cpu"); + g1->forwardDims(); + + // std::set<std::string> availableBackends = Tensor::getAvailableBackends(); + // if (availableBackends.find("cpu") != availableBackends.end()){ + // g1->setBackend("cpu"); + // newTensor->getImpl()->setRawPtr(static_cast<T*>(info.ptr)); + // }else{ + // printf("Warning : Could not use aidge_cpu backend, verify you have `import aidge_cpu`\n"); + // } + + fuseBatchNorm(g1); + + SECTION("Check resulting nodes") { + // REQUIRE(g1->getNodes().size() == 2); + // REQUIRE(g2->getNode("conv1")->getOperator()->type() == "MaxPooling"); + // REQUIRE(g2->getNode("conv1")->getOperator()->getOutput(0) == g2->getNode("conv2")->getOperator()->getInput(0)); + // REQUIRE(g2->getNode("conv2")->getOperator()->type() == "MaxPooling"); + // REQUIRE(g2->getNode("conv2")->getOperator()->getOutput(0) == g2->getNode("conv3")->getOperator()->getInput(0)); + // REQUIRE(g2->getNode("conv3")->getOperator()->type() == "MaxPooling"); + } + } + +} +*/ \ No newline at end of file diff --git a/unit_tests/recipies/Test_FuseMulAdd.cpp b/unit_tests/recipies/Test_FuseMulAdd.cpp index da53642055a3146c71a211ad7816f21c9b92d6cd..b99de66d3e23377c13ed86526f6c1a318a00e4e8 100644 --- a/unit_tests/recipies/Test_FuseMulAdd.cpp +++ b/unit_tests/recipies/Test_FuseMulAdd.cpp @@ -26,6 +26,7 @@ namespace Aidge { + TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { // generate the original GraphView auto matmul0 = MatMul(5, "matmul0"); @@ -74,4 +75,5 @@ TEST_CASE("[cpu/recipies] FuseMulAdd", "[FuseMulAdd][recipies]") { REQUIRE(((node->type() == "Producer") || (node->type() == "FC"))); } } + } // namespace Aidge \ No newline at end of file