Skip to content
Snippets Groups Projects
Commit 0686c4ed authored by vincent  lorrain's avatar vincent lorrain
Browse files

Merge branch 'refactor/recipies' into 'main'

Refactor/recipies

See merge request !48
parents 5686be10 f8884a1e
No related branches found
No related tags found
1 merge request!48Refactor/recipies
Pipeline #34241 passed
Showing
with 201 additions and 1323 deletions
......@@ -20,7 +20,6 @@
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/OpArgs.hpp"
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/Match.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/SeqStm.hpp"
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_GREGEX_H_
#define AIDGE_GREGEX_H_
#include <stdexcept> // for exception, runtime_error, out_of_range
#include <regex>
#include <memory> // for shared_ptr
#include <algorithm> // for next_permutation
#include "aidge/graphmatching/Utile.hpp"
#include "aidge/graphmatching/StmFactory.hpp"
#include "aidge/graphmatching/SeqStm.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/Match.hpp"
namespace Aidge{
class GRegex {
// __init__(self,nodes_regex:dict,seq_regexps:list)
StmFactory mStmFab;
std::vector<SeqStm*> mStmInit;
public:
GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps );
std::set<NodeTmp> matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch);
bool walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm);
bool walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm);
bool walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm);
std::set<NodeTmp> get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm);
std::vector<SeqStm*> getStmInit() const {
return mStmInit;
}
StmFactory getStmFab() const {
return mStmFab;
}
//std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> match(const std::shared_ptr<GraphView> graphToMatch);
Match match(const std::shared_ptr<GraphView> graphToMatch);
};
}
#endif //AIDGE_GREGEX_H_
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_MATCH_H_
#define AIDGE_MATCH_H_
#include <vector>
#include <set>
#include <iostream>
#include <cassert>
#include "aidge/graphmatching/Utile.hpp"
namespace Aidge{
class Match {
public:
Match();
size_t getNbMatch();
void insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes);
std::vector<std::vector<NodeTmp>> getStartNodes();
std::vector<std::set<NodeTmp>> getMatchNodes();
protected:
std::vector<std::vector<NodeTmp>> mStartNodes;
std::vector<std::set<NodeTmp>> mMatchNodes;
};
}
#endif //AIDGE_MATCH_H_
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_NODEREGEX_H_
#define AIDGE_NODEREGEX_H_
#include <cstdlib>
#include <iostream>
#include <cstring>
#include "aidge/graph/Node.hpp"
namespace Aidge {
class NodeRegex
{
public:
std::string mCondition;
NodeRegex(const std::string c){
mCondition = c;
};
// Version 1 - Only test the type of the node (no need for a lexer)
// Input : Node_op
// Output : bool
// return mCondition == Node_op.type
bool _is(std::shared_ptr<Node> &Node_op);
bool isA(std::string NodeType);
};
}
#endif /* _AIDGE_NODEREGEX_H__ */
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_SEQSTM_H_
#define AIDGE_SEQSTM_H_
#include <iostream>
#include <map>
#include <regex>
#include <set>
#include <stdexcept> // for exception, runtime_error, out_of_range
#include <string>
#include <utility>
#include <vector>
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/Utile.hpp"
namespace Aidge {
class SeqStm {
private:
const int mStmIdx;
const std::vector<std::vector<int>> mTransitionMatrix;
// str key of type like 'A' that ce use in the A->B .. extpr
const std::map<std::string, NodeRegex *> mNodesRegex;
// mTypeToIdxTransition.first = std::pair node_type , common_tag
// mTypeToIdxTransition.segond = idx in trans matrix
const std::map<NodeTypeKey, int> mTypeToIdxTransition;
int mActSt;
std::set<NodeTmp> mAllNodeValidated;
std::set<NodeTmp> mAllNodeTested;
std::set<std::pair<NodeTmp, std::string>> mAllCommonNode;
bool mStmIsValid;
std::pair<NodeRegex *, std::string> getNodeRegexAndCommonAt(int idxType);
/**
* @brief test the stm on a type
* @return the common tag
*/
std::string transitionOnNodeType(NodeType nodeType);
public:
SeqStm(const int mStmIdx,
const std::vector<std::vector<int>> &mTransitionMatrix,
const std::map<std::string, NodeRegex *> &mNodesRegex,
const std::map<NodeTypeKey, int> &mTypeToIdxTransition, int mActSt,
std::set<NodeTmp> mAllNodeValidated, std::set<NodeTmp> mAllNodeTested,
std::set<std::pair<NodeTmp, std::string>> mAllCommonNode,
bool mStmIsValid);
//////////////////////////////////////
// STM test
/////////////////////////////////////
/**
* @brief get if a st is a valide one
* @return bool
*/
bool isAValidSt(int st) {
std::size_t size = mTransitionMatrix.size();
return st == static_cast<int>(size - 1) ? true : false;
}
/**
* @brief true if the stm is blocked into st
* @return bool
*/
bool isStmBlocked() { return mActSt == -1 ? true : false; }
/**
* @brief true if the stm into valide st
* @return bool
*/
bool isValid() { return mStmIsValid; }
/////////////////////////////////////
// utile
/////////////////////////////////////
/**
* @brief extract from a node is type
* @return bool
*/
NodeType getTheNodeType(NodeTmp node);
void drawStm();
/////////////////////////////////////
// geter
/////////////////////////////////////
std::set<std::pair<NodeTmp, std::string>> getAllCommonNode() {
return mAllCommonNode;
}
std::set<NodeTmp> getAllNodeTested() { return mAllNodeTested; }
std::set<NodeTmp> getAllNodeValidated() { return mAllNodeValidated; }
SeqStm *duplicateStm();
int getStmIdx() { return mStmIdx; }
int getState() { return mActSt; }
//////////////////////////////////////////
// USE
//////////////////////////////////////////
/**
* @brief test the stm on a node
* @return pair new stm state, the common tag
*/
std::pair<int, std::string> testNode(const NodeTmp node);
};
} // namespace Aidge
#endif /* AIDGE_SEQSTM_H_ */
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_STMFACTORY_H_
#define AIDGE_STMFACTORY_H_
#include <map>
#include <utility>
#include <set>
#include <string>
#include <vector>
#include <iostream>
#include <stdexcept> // for exception, runtime_error, out_of_range
#include <regex>
#include "aidge/graphmatching/NodeRegex.hpp"
#include "aidge/graphmatching/SeqStm.hpp"
#include "aidge/graphmatching/Utile.hpp"
namespace Aidge{
class StmFactory {
const std::map<std::string,NodeRegex*>& mNodesRegex;
std::size_t mCmptStm = 0;
public:
StmFactory(const std::map<std::string,NodeRegex*>& nodesRegex);
//StmFactory(){};
SeqStm* makeNewStm(const std::string& sequRegex);
SeqStm* duplicateStm(SeqStm* stm);
std::size_t getNumberOfStm(){
return mCmptStm;
}
private:
ParsingReturn initParsingSequRegex(const std::string& sequRegex);
std::vector<std::vector<int>> initTransitionMatrix(ParsingReturn& parsing);
};
}
#endif //AIDGE_STMFACTORY_H_
\ No newline at end of file
/**
* @file
* @brief
* @version file 1.0.0
* @author vl241552
* @copyright
* Copyright (c) 2023 CEA, LIST, Embedded Artificial Intelligence Laboratory.
* All rights reserved.
*/
#ifndef _utile_H_
#define _utile_H_
#include <map>
#include "aidge/graph/Node.hpp"
#include <map>
namespace Aidge {
using NodeTmp = std::shared_ptr<Node>;
using NodeType = std::string;
using CommonTag = std::string;
using NodeTypeKey = std::pair<NodeType, CommonTag>;
// type def
// struct NodeTypeKey {
// NodeType nodeType;
// std::string commonTag;
// // for map find
// bool operator<(const NodeTypeKey& other) const {
// if (nodeType != other.nodeType or commonTag != other.commonTag) {
// return false;
// } else {
// return true;
// }
// }
// };
struct ParsingReturn {
std::map<NodeTypeKey, int> typeToIdxTransition;
std::vector<std::pair<NodeTypeKey, std::string>> transition;
};
} // namespace Aidge
#endif //_utile_H_
\ No newline at end of file
......@@ -22,7 +22,7 @@ namespace Aidge{
/////////////////////////////
/**
* @brief class used to register any lambda function without context,
* it encapsulates the source lambda in a lambda which takes as argument ConditionalData* which are any type.
* it encapsulates the source lambda in a lambda which takes as argument std::shared_ptr<ConditionalData> which are any type.
* @see ConditionalData
*/
class ConditionalRegisterFunction {
......@@ -31,12 +31,12 @@ class ConditionalRegisterFunction {
//////////////////////////
/**
* @brief recast the ConditionalData* to the argument type of the lambda
* @brief recast the std::shared_ptr<ConditionalData> to the argument type of the lambda
* @tparam T type of the lambda argument
* @see ConditionalData
*/
template <typename T>
T safeCastInput(ConditionalData* data) {
T safeCastInput( std::shared_ptr<ConditionalData> data) {
//cnvertion and type cheking
if (data->isTypeEqualTo<T>()){
return data->getValue<T>();
......@@ -48,14 +48,14 @@ class ConditionalRegisterFunction {
/**
* @brief recaste the output of the lambda to a ConditionalData*
* @brief recaste the output of the lambda to a std::shared_ptr<ConditionalData>
* @tparam T type of the lambda return
* @see ConditionalData
*/
template <typename T>
ConditionalData* safeCastOutput(T data) {
std::shared_ptr<ConditionalData> safeCastOutput(T data) {
ConditionalData* out = new ConditionalData;
std::shared_ptr<ConditionalData> out = std::make_shared<ConditionalData>();
out->setValue<T>(data);
return out;
......@@ -111,11 +111,11 @@ class ConditionalRegisterFunction {
};
/////////////////////
//change the function to ConditionalData*(std::vector<ConditionalData*>)
//change the function to std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>)
/////////////////////
/**
* @brief Converts a function to a ConditionalData*(std::vector<ConditionalData*>).
* @brief Converts a function to a std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>).
* @tparam F The type of the function to convert.
* @tparam ParamsIdx The indices of the function parameters.
* @param f The function to convert.
......@@ -124,25 +124,31 @@ class ConditionalRegisterFunction {
template <class F, std::size_t... ParamsIdx>
auto funcPointer(F f, std::index_sequence<ParamsIdx...>) {
//wrapp the lambda in a new one that as ConditionalData as inputs and output
return [this,f](std::vector<ConditionalData*> &args) {
if (args.size() != sizeof...(ParamsIdx)){
return [this,f](std::vector< std::shared_ptr<ConditionalData>> &args) {
if (args.size() < sizeof...(ParamsIdx)){
std::ostringstream errorMessage;
errorMessage << "bad Number of argument: get " << args.size() << " need " << sizeof...(ParamsIdx) << "\n";
throw std::runtime_error(errorMessage.str());
}
//assert(args.size() == sizeof...(ParamsIdx));//the size of the vector valide
//we used std::vector< std::shared_ptr<ConditionalData>> as a fifo
std::size_t offset = args.size()-sizeof...(ParamsIdx);
using FuncTraits = function_traits<decltype(f)>;
using outType = typename FuncTraits::return_type;
outType result = f(safeCastInput<typename FuncTraits::template argument<ParamsIdx>::type>(args[ParamsIdx])...);
outType result = f(safeCastInput<typename FuncTraits::template argument<ParamsIdx>::type>(args[offset+ParamsIdx])...);
//suppress what we used
for (size_t i = 0; i < sizeof...(ParamsIdx); ++i) {
args.pop_back();
}
//typename
return safeCastOutput<outType>(result);
};
}
/**
* @brief Converts a function pointer to a ConditionalData*(std::vector<ConditionalData*>).
* @brief Converts a function pointer to a std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>).
* @tparam R The return type of the function.
* @tparam Params The parameter types of the function.
* @param f The function pointer to convert.
......@@ -154,7 +160,7 @@ class ConditionalRegisterFunction {
}
/**
* @brief Converts a std::function to a ConditionalData*(std::vector<ConditionalData*>).
* @brief Converts a std::function to a std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>>).
* @tparam R The return type of the function.
* @tparam Params The parameter types of the function.
* @param f The function pointer to convert.
......@@ -196,7 +202,7 @@ class ConditionalRegisterFunction {
* @param datas The vector of input data.
* @return A pointer to the output ConditionalData object.
*/
ConditionalData* run(const std::string key,std::vector<ConditionalData*> & datas);
std::shared_ptr<ConditionalData> run(const std::string key,std::vector< std::shared_ptr<ConditionalData>> & datas);
bool isLambdaRegister(const std::string &key) {
if(mWlambda.find(key) != mWlambda.end()){
......@@ -207,7 +213,7 @@ class ConditionalRegisterFunction {
private:
/// @brief map of name and the converted function.
std::map<const std::string, std::function<ConditionalData*(std::vector<ConditionalData*> &)>> mWlambda;
std::map<const std::string, std::function< std::shared_ptr<ConditionalData>(std::vector< std::shared_ptr<ConditionalData>> &)>> mWlambda;
};
///////////////////
......@@ -237,15 +243,15 @@ class ConditionalInterpreter
ConditionalRegisterFunction mLambdaRegister;
std::vector<ConditionalData*> mResolution ;
std::vector< std::shared_ptr<ConditionalData>> mResolution ;
void clearRes(){
// void clearRes(){
for (std::size_t i = 0; i < mResolution.size(); ++i) {
delete mResolution[i];
}
mResolution.clear();
}
// for (std::size_t i = 0; i < mResolution.size(); ++i) {
// delete mResolution[i];
// }
// mResolution.clear();
// }
public:
......@@ -258,7 +264,7 @@ class ConditionalInterpreter
ConditionalInterpreter(const std::string key,const std::string ConditionalExpressions);
~ConditionalInterpreter(){clearRes();}
~ConditionalInterpreter(){}
/**
* @brief get the condition key
......@@ -293,12 +299,12 @@ class ConditionalInterpreter
* @param NodeOp The node currently being tested
* @param nodes The AST given by the parsing process
*/
std::vector<ConditionalData*> visit(const ASTNodeCh& nodes, const NodePtr NodeOp );
std::vector< std::shared_ptr<ConditionalData>> visit(const ASTNodeCh& nodes, const NodePtr NodeOp );
/**
* @defgroup ASTnodeInterpreterF Functions for interpreting AST nodes
* @brief For each node type in the AST, function defines the processing to be performed
* they return a std::vector<ConditionalData*> which corresponds to the value(s) obtained
* they return a std::vector< std::shared_ptr<ConditionalData>> which corresponds to the value(s) obtained
*/
/**
......@@ -308,38 +314,38 @@ class ConditionalInterpreter
void fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node);
/**
* @ingroup ASTnodeInterpreterF
* @brief Converted the lexeme to a int and to ConditionalData*
* @brief Converted the lexeme to a int and to std::shared_ptr<ConditionalData>
*/
void fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node);
/**
* @ingroup ASTnodeInterpreterF
* @brief Converted the lexeme to a float and to ConditionalData*
* @brief Converted the lexeme to a float and to std::shared_ptr<ConditionalData>
*/
void fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node);
/**
* @ingroup ASTnodeInterpreterF
* @brief Converted the lexeme to a str and to ConditionalData*
* @brief Converted the lexeme to a str and to std::shared_ptr<ConditionalData>
*/
void fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node);
/**
* @ingroup ASTnodeInterpreterF
* @brief makes the == operation between two previously converted ConditionalData*
* @brief makes the == operation between two previously converted std::shared_ptr<ConditionalData>
*/
void fEq(void);
/**
* @ingroup ASTnodeInterpreterF
* @brief makes the != operation between two previously converted ConditionalData*
* @brief makes the != operation between two previously converted std::shared_ptr<ConditionalData>
*/
void fNeq(void);
/**
* @ingroup ASTnodeInterpreterF
* @brief makes the && operation between two previously converted ConditionalData* in bool
* @brief makes the && operation between two previously converted std::shared_ptr<ConditionalData> in bool
*/
void fAnd(void);
/**
* @ingroup ASTnodeInterpreterF
* @brief makes the || operation between two previously converted ConditionalData* in bool
* @brief makes the || operation between two previously converted std::shared_ptr<ConditionalData> in bool
*/
void fOr(void);
......
......@@ -17,6 +17,8 @@
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graphRegex/matchFsm/MatchResult.hpp"
namespace Aidge{
......@@ -27,7 +29,12 @@ namespace Aidge{
*
* @param nodes Strict set of Node to merge.
*/
void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);
//void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);
void fuseMulAdd(std::shared_ptr<MatchSolution> solution);
void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add);
/**
* @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node.
*
......@@ -43,7 +50,11 @@ void fuseMulAdd(std::shared_ptr<GraphView> graphView);
*
* @param nodes Strict set of Node to merge.
*/
void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
void removeFlatten(std::shared_ptr<Node> flatten);
void removeFlatten(std::shared_ptr<MatchSolution> solution);
/**
* @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
*
......@@ -59,7 +70,12 @@ void removeFlatten(std::shared_ptr<GraphView> graphView);
*
* @param nodes Strict set of Node to merge.
*/
void fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes);
void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm);
void fuseBatchNorm(std::shared_ptr<MatchSolution> solution);
/**
* @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
* Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
......
......@@ -10,33 +10,60 @@
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphRegex/GraphRegex.hpp"
namespace py = pybind11;
namespace Aidge {
void init_GRegex(py::module& m){
py::class_<GRegex, std::shared_ptr<GRegex>>(m, "GRegex", "GRegex class combines a Node Regex and a list of Graph Regex that together describes a graph pattern as a graph regular expression. GRegex find patterns in a given graph that matches the graph regular expression.")
.def(py::init<const std::map<std::string,NodeRegex*>&, std::vector<std::string>&>(), py::arg("nodesRegex"), py::arg("seqRegexps"), R"mydelimiter(
Constructor of GRegex
void init_GraphRegex(py::module& m){
:param nodesRegex: Describe the conditions an operator has to fulfill.
:type nodesRegex: Dict[str,:py:class:`aidge_core.NodeRegex`]
:param seqRegexps: Describe the graph topological pattern. List of Graph Regex as strings.
:type seqRegexps: List[str]
py::class_<GraphRegex, std::shared_ptr<GraphRegex>>(m, "GraphRegex", "GraphRegex class describes a regex to test a graph.")
.def(py::init<>())
.def("add_query", &GraphRegex::addQuery, R"mydelimiter(
:rtype: str
)mydelimiter")
.def("match", &GRegex::match, py::arg("graphToMatch"), R"mydelimiter(
Launch the graph matching algorithm on a given graph.
.def("set_key_from_graph", &GraphRegex::setKeyFromGraph, R"mydelimiter(
:param ref: The graph use to define type of Node.
:type ref: :py:class:`aidge_core.GraphView`
)mydelimiter")
// void setNodeKey(const std::string key, const std::string conditionalExpressions );
// void setNodeKey(const std::string key,std::function<bool(NodePtr)> f);
.def("match", &GraphRegex::match, R"mydelimiter(
:param graphToMatch: The graph to perform the matching algorithm on.
:type graphToMatch: :py:class:`aidge_core.GraphView`
)mydelimiter")
:returns: Matched graph patterns.
:rtype: :py:class:`aidge_core.Match`
.def("set_node_key",
(void (GraphRegex::*)(const std::string, const std::string )) &
GraphRegex::setNodeKey,
py::arg("key"), py::arg("conditionalExpressions"),
R"mydelimiter(
Add a node test
:param key: the key of the node test to use in the query.
:param conditionalExpressions: the test to do .
)mydelimiter")
.def("set_node_key",
(void (GraphRegex::*)(const std::string, std::function<bool(NodePtr)>)) &
GraphRegex::setNodeKey,
py::arg("key"), py::arg("f"),
R"mydelimiter(
Add a node test
:param key: the key of the lambda test to use in the conditional expressions.
:param f: bool lambda (nodePtr) .
)mydelimiter")
;
}
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/graphmatching/Match.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Match(py::module& m){
py::class_<Match, std::shared_ptr<Match>>(m, "Match", "Match class stores the matched patterns resulting from a graph matching query. A matched pattern is the combinaison of the graph pattern start nodes and the set of all the nodes in the matched pattern (including the start nodes)")
.def(py::init<>())
.def("get_nb_match", &Match::getNbMatch, R"mydelimiter(
:returns: The number of graph patterns matched
:rtype: int
)mydelimiter")
.def("get_start_nodes", &Match::getStartNodes, R"mydelimiter(
:returns: All matched graph patterns start nodes
:rtype: List[List[:py:class:`aidge_core.Nodes`]]
)mydelimiter")
.def("get_match_nodes", &Match::getMatchNodes, R"mydelimiter(
:returns: All matched graph patterns sets of matched nodes
:rtype: List[Set[:py:class:`aidge_core.Nodes`]]
)mydelimiter");
}
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/graphmatching/NodeRegex.hpp"
namespace py = pybind11;
namespace Aidge {
void init_NodeRegex(py::module& m){
py::class_<NodeRegex, std::shared_ptr<NodeRegex>>(m, "NodeRegex", "NodeRegex class describes a condition to test on any operator. Current version only supports testing the type of the operator.")
.def(py::init<const std::string>(), py::arg("condition"), R"mydelimiter(
Constructor of NodeRegex
:param condition: Condition to be fulfilled by an operator.
:type condition: str
)mydelimiter")
;
}
}
......@@ -45,9 +45,7 @@ void init_GraphView(py::module&);
void init_OpArgs(py::module&);
void init_Connector(py::module&);
void init_Match(py::module&);
void init_NodeRegex(py::module&);
void init_GRegex(py::module&);
void init_GraphRegex(py::module&);
void init_Recipies(py::module&);
......@@ -87,9 +85,8 @@ void init_Aidge(py::module& m){
init_Sub(m);
init_Producer(m);
init_Match(m);
init_NodeRegex(m);
init_GRegex(m);
init_GraphRegex(m);
init_Recipies(m);
init_Scheduler(m);
init_TensorUtils(m);
......
......@@ -28,12 +28,13 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
:param nodes: The MatMul and Add nodes to fuse.
:type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter");
// :param nodes: The MatMul and Add nodes to fuse.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator.
......@@ -41,18 +42,20 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
Recipie to remove a flatten operator.
:param nodes: The flatten operator to remove.
:type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter");
m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
// Recipie to remove a flatten operator.
:param nodes: The MatMul and Add nodes to fuse.
:type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter");
// :param nodes: The flatten operator to remove.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// :param nodes: The MatMul and Add nodes to fuse.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator.
......@@ -60,11 +63,12 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
Recipie to remove a flatten operator.
// m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
// Recipie to remove a flatten operator.
:param nodes: The flatten operator to remove.
:type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter");
// :param nodes: The flatten operator to remove.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graph/GraphView.hpp"
using namespace Aidge;
GRegex::GRegex(const std::map<std::string,NodeRegex*>& nodesRegex,std::vector<std::string>& seqRegexps ):mStmFab(nodesRegex){
//setup all the STM
for (const std::string& sequRegex : seqRegexps) {
mStmInit.push_back(mStmFab.makeNewStm(sequRegex));
}
}
bool GRegex::walk_validation_all_stm_are_valid(const std::vector<std::vector<SeqStm*>> all_stm){
//test if all stm type are in a valid state
std::vector<int> number_of_valid;
number_of_valid.resize(all_stm.size());
for (std::size_t i = 0; i < all_stm.size(); ++i) {
number_of_valid[i] = 0;
for (auto it = all_stm[i].begin(); it != all_stm[i].end(); ++it) {
SeqStm* stm = *it;
if (stm->isValid()){
number_of_valid[i] +=1;
}
}
}
for (std::size_t i = 0; i < number_of_valid.size(); ++i) {
if (number_of_valid[i] == 0) {
//std::cout << "NO MATCH at least one stm are not valid" << std::endl;
return false;
}
if (number_of_valid[i] > 1) {
//std::cout << "NO MATCH multiple brach match of stm (// quantification)" << std::endl;
return false;
}
}
return true;
}
bool GRegex::walk_validation_all_node_read_validate_by_one_stm(const std::vector<std::vector<SeqStm*>> all_stm){
std::set<NodeTmp> all_stm_node_tested;
std::set<NodeTmp> all_stm_node_validated;
for (std::size_t i = 0; i < all_stm.size(); ++i) {
//std::cout << "all stm index " << i << " on dimension 1 of size " << all_stm.size() <<std::endl;
for (std::size_t j = 0; j < all_stm[i].size(); ++j) {
//std::cout << "all stm index " << j << " on dimension 2 of size " << all_stm[i].size() <<std::endl;
std::set<NodeTmp> stm_node_tested = all_stm[i][j]->getAllNodeTested();
std::set<NodeTmp> stm_node_validated = all_stm[i][j]->getAllNodeValidated();
all_stm_node_tested.insert(stm_node_tested.begin(), stm_node_tested.end());
all_stm_node_validated.insert(stm_node_validated.begin(), stm_node_validated.end());
}
}
std::set<NodeTmp> test_but_not_valid;
for (const auto& x : all_stm_node_tested) {
if (all_stm_node_validated.find(x) == all_stm_node_validated.end()) {
test_but_not_valid.insert(x);
}
}
if (!test_but_not_valid.empty()) {
std::cout << "NO MATCH. The node(s) ";
for (const auto& x : test_but_not_valid) {
std::cout << x.get() << ", ";
}
std::cout << " have been tested but not validated." << std::endl;
return false;
}
return true;
}
bool GRegex::walk_validation_common_nodes_same_tag_for_all_stm(const std::vector<std::vector<SeqStm*>> all_stm){
std::map<NodeTmp, std::pair<std::string,int>> node_to_common_tag;
for (std::size_t i = 0; i < all_stm.size(); ++i) {
for (auto it = all_stm[i].begin(); it != all_stm[i].end(); ++it) {
SeqStm* stm = *it;
if (!stm->isValid()){
continue;
}
for (const auto& pair : stm->getAllCommonNode()) {
const NodeTmp node = pair.first;
const std::string common_tag = pair.second;
if (node_to_common_tag.find(node) != node_to_common_tag.end()) {
std::string tag = node_to_common_tag[node].first;
int& occurence = node_to_common_tag[node].second;
if (tag!=common_tag){
std::cout << "NO MATCH. The node " << node << " have two different tags "<< tag << " and " << common_tag << std::endl;
return false;
} else {
occurence += 1;
}
} else {
node_to_common_tag.insert(std::make_pair(node, std::make_pair(common_tag, 1)));
}
}
}
}
/*std::cout << "Node to common tag ";
for (const auto& x : node_to_common_tag) {
std::cout << "(" << x.first << ", " << "[" << x.second.first << ", " << x.second.second << "]" << ") ; ";
}
std::cout << std::endl;*/
for (const auto& pair : node_to_common_tag) {
const std::pair<std::string, int> tag_occurence_pair = pair.second;
if (tag_occurence_pair.second < 1){
//std::cout << "NO MATCH. The common tag " << tag_occurence_pair.first << " did not match " << std::endl;
return false;
}
}
return true;
}
std::set<NodeTmp> GRegex::get_all_validate_nodes(const std::vector<std::vector<SeqStm*>> all_stm){
std::set<NodeTmp> all_stm_node_validated;
for (std::size_t i = 0; i < all_stm.size(); ++i) {
for (std::size_t j = 0; j < all_stm[i].size(); ++j) {
std::set<NodeTmp> stm_node_validated = all_stm[i][j]->getAllNodeValidated();
all_stm_node_validated.insert(stm_node_validated.begin(), stm_node_validated.end());
}
}
return all_stm_node_validated;
}
std::set<NodeTmp> GRegex::matchFromStartNodes(const std::vector<NodeTmp> startNodes,const std::shared_ptr<GraphView> graphToMatch){
std::set<NodeTmp> empty_set_return;
//ASSERT
if(startNodes.size() != mStmInit.size()){
throw std::runtime_error ("bad GRegex start nodes");
}
//init the walk
std::vector<std::vector<SeqStm*>> allStm;
std::vector<std::pair<NodeTmp,SeqStm*>> currentWalk;
for (SeqStm* seqStmPtr : mStmInit) {
SeqStm* newStm = mStmFab.duplicateStm(seqStmPtr);
std::size_t idxStart = newStm->getStmIdx();
currentWalk.push_back(std::make_pair(startNodes[idxStart],newStm));
allStm.push_back(std::vector<SeqStm*>());
}
//walk
while (currentWalk.size()!=0)
{
std::vector<std::pair<NodeTmp,SeqStm*>> newWalk;
for (const auto& pair : currentWalk) {
const NodeTmp node = pair.first;
SeqStm* stmPtr = pair.second;
std::pair<int,std::string> test = stmPtr->testNode(node);
int res = test.first;
std::string commonTag = test.second;
std::set<NodeTmp> next_nodes = graphToMatch->getChildren(node);
/*std::cout << "Next nodes : " ;
for (const auto& x : next_nodes) {
std::cout << x->name() << ", ";
}
std::cout << std::endl;*/
// Test Match
if (commonTag == "" && next_nodes.size() > 1) {
std::cout << "NO MATCH. The node " << node.get() << " is not common and has more than one child" << std::endl;
return empty_set_return;
}
// If there is no more nodes --> Archive the branch
if (res == -1 || next_nodes.empty()) {
int indexToInsert = stmPtr->getStmIdx();
allStm[indexToInsert].push_back(stmPtr);
//std::cout << "No more nodes --> STM archived : " << indexToInsert << std::endl;
continue; // TODEV : replace this with 'else' that encapsulate the rest of the function ?
}
bool first = true;
// Use an iterator to read through the next_nodes
std::set<NodeTmp>::iterator it;
for (it = next_nodes.begin(); it != next_nodes.end(); ++it) {
// Access the current element using the iterator
std::shared_ptr<Aidge::Node> next_node = *it;
if (first){
newWalk.push_back(std::make_pair(next_node, stmPtr));
first = false;
} else {
SeqStm* new_stmPtr = mStmFab.duplicateStm(stmPtr);
newWalk.push_back(std::make_pair(next_node, new_stmPtr));
}
}
}
currentWalk = newWalk;
}
//std::cout << "Walk finished" << std::endl;
if (!walk_validation_all_stm_are_valid(allStm)){
return empty_set_return;
}
//std::cout << "walk_validation_all_stm_are_valid finished" << std::endl;
if (!walk_validation_all_node_read_validate_by_one_stm(allStm)){
return empty_set_return;
}
//std::cout << "walk_validation_all_node_read_validate_by_one_stm finished" << std::endl;
if (!walk_validation_common_nodes_same_tag_for_all_stm(allStm)){
return empty_set_return;
}
//std::cout << "walk_validation_common_nodes_same_tag_for_all_stm finished" << std::endl;
//std::cout << "MATCH" << std::endl;
return get_all_validate_nodes(allStm);
}
Match GRegex::match(const std::shared_ptr<GraphView> graphToMatch){
//std::vector<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matches;
//std::set<std::pair<std::vector<NodeTmp>,std::set<NodeTmp>>> matches;
Match matches;
std::size_t nbStartNodes = mStmInit.size();
std::set<NodeTmp> allNodes = graphToMatch->getNodes();
std::size_t nbAllNodes = allNodes.size();
std::vector<std::size_t> indices(nbStartNodes, 0);
while (true) {
// Generate all permutations of the current combination
do {
std::vector<NodeTmp> startNodes;
//std::cout <<"start nodes :";
for (std::size_t i = 0; i < nbStartNodes; ++i) {
auto it = std::begin(allNodes);
std::advance(it, indices[i]);
//std::cout << (*it).get() << " ";
startNodes.push_back(*it);
}
//std::cout <<"\n";
std::set<NodeTmp> match = matchFromStartNodes(startNodes, graphToMatch);
//std::cout << "match size : " << match.size() << " ";
if(match.size() != 0){
//matches.push_back(std::make_pair(startNodes,match));
//matches.insert(std::make_pair(startNodes,match));
matches.insert(startNodes,match);
}
} while (std::next_permutation(indices.begin(), indices.end()));
// Generate the next combination with replacement
std::size_t i = nbStartNodes - 1;
while (true) {
if (indices[i] < nbAllNodes - 1) {
++indices[i];
break;
}
if (i == 0) {
return matches;
}
--i;
}
std::fill(indices.begin() + i + 1, indices.end(), indices[i]);
}
return matches;
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graphmatching/Match.hpp"
using namespace Aidge;
Match::Match(){
//ctr
}
size_t Match::getNbMatch(){
assert(mStartNodes.size() == mMatchNodes.size() && "Match corrupted");
return mStartNodes.size();
}
void Match::insert(std::vector<NodeTmp> startnodes, std::set<NodeTmp> matchnodes){
assert(mStartNodes.size() == mMatchNodes.size() && "Match corrupted");
mStartNodes.push_back(startnodes);
mMatchNodes.push_back(matchnodes);
}
std::vector<std::vector<NodeTmp>> Match::getStartNodes(){
return mStartNodes;
}
std::vector<std::set<NodeTmp>> Match::getMatchNodes(){
return mMatchNodes;
}
\ No newline at end of file
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graphmatching/NodeRegex.hpp"
// Verification done by the Attribute system
// Version 1 - Only test the type of the node (no need for a lexer)
// Input : Node_op
// Output : bool
// return mCondition == Node_op.type
bool Aidge::NodeRegex::_is(std::shared_ptr<Node> &Node_op){
std::string NodeType = Node_op->type();
return strcmp(NodeType.c_str(), mCondition.c_str()) == 0;
}
bool Aidge::NodeRegex::isA(std::string NodeType){
return strcmp(NodeType.c_str(), mCondition.c_str()) == 0;
}
// Version 2 - Test the node to an advanced condition
// Input : Node_op
// Output : bool
// return mCondition applied on Node
/**bool NodeRegex::_is(string &Node_op){
// Parsing the condition is done in the initialization of the NodeRegex
// assert attributes exist in the node with the attribute function hasAttr()
// get the attributes
}*/
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graphmatching/SeqStm.hpp"
using namespace Aidge;
///////////////////////////////////////////////////////
SeqStm::SeqStm(
const int stmIdx,
const std::vector<std::vector<int>>& transitionMatrix,
const std::map<std::string,NodeRegex*>& nodesRegex,
const std::map<NodeTypeKey,int>& typeToIdxTransition,
int actSt,
std::set<NodeTmp> allNodeValidated,
std::set<NodeTmp> allNodeTested,
std::set<std::pair<NodeTmp,std::string>> allCommonNode,
bool stmIsValid):mStmIdx(stmIdx),
mTransitionMatrix(transitionMatrix),
mNodesRegex(nodesRegex),
mTypeToIdxTransition(typeToIdxTransition)
{
//assert
if (transitionMatrix.size() == 0){
throw std::runtime_error ("no transitionMatrix");
}
if(transitionMatrix[0].size() == 0 || transitionMatrix[0].size() != typeToIdxTransition.size()){
throw std::runtime_error ("bad transitionMatrix");
}
int size = static_cast<int>(transitionMatrix.size());
if (actSt >= size){
throw std::runtime_error ("bad actSt");
}
mActSt = actSt;
mAllNodeValidated = allNodeValidated;
mAllNodeTested = allNodeTested;
mAllCommonNode = allCommonNode;
mStmIsValid = stmIsValid;
}
SeqStm* SeqStm::duplicateStm(){
//deep copy of the set
// std::set<Node> cAllNodeValidated(mAllNodeValidated.begin(), mAllNodeValidated.end());
// std::set<Node> cAllNodeTested(mAllNodeTested.begin(), mAllNodeTested.end());
// std::set<std::pair<Node,std::string>> cAllCommonNode;
// for (const auto& p : mAllCommonNode) {
// cAllCommonNode.insert(p);
// }
auto newStm = new SeqStm(
mStmIdx,
mTransitionMatrix,
mNodesRegex,
mTypeToIdxTransition,
mActSt,
mAllNodeValidated,
mAllNodeTested,
mAllCommonNode,
mStmIsValid
);
return newStm;
}
std::pair<NodeRegex*,std::string> SeqStm::getNodeRegexAndCommonAt(int idxType)
{
//std::cout << "!" << idxType << "\n";
for (auto const& x : mTypeToIdxTransition)
{
//x.second is the value : idx in mTransitionMatrix for the type
//x.first pair of the node regex class and a string that is the common tag '',#,#n
if (x.second == idxType ){
if (mNodesRegex.find(x.first.first) != mNodesRegex.end()){
return std::make_pair(mNodesRegex.find(x.first.first)->second, x.first.second);
}else{
throw std::runtime_error ("a type is not define in NodesRegex");
}
}
}
throw std::runtime_error ("bad idx in mNodesRegex");
return std::make_pair(nullptr,nullptr);
}
NodeType SeqStm::getTheNodeType(NodeTmp node)
{
//the node is a str of '{type}{idx}' and we juste want type
// // std::regex re("([a-zA-Z]+)[0-9]+");
// // std::smatch match;
// // if (std::regex_search(node, match, re) == true) {
// // return match.str(1);
// // }
// // throw std::runtime_error ("Type node not found");
// // return "";
//return node->name();
return node->type();
}
std::string SeqStm::transitionOnNodeType(NodeType nodeType){
if (!isStmBlocked()){
int idxType = 0;
for (auto & nextSt : mTransitionMatrix[mActSt]) {
// There are a next step for this type
//std::cout << "transition matrix next state -> "<< nextSt<<"\n" ;
if (nextSt != -1){
//std::cout << "next -> "<< nextSt<< " "<< isAValidSt(nextSt) <<"\n" ;
auto nodeRegex = getNodeRegexAndCommonAt(idxType);
//std::cout << "-> "<< nodeRegex.second<<"\n" ;
if (nodeRegex.first->isA(nodeType)){
//std::cout << "nodetype tested !"<<"\n" ;
if(isAValidSt(nextSt)){
//std::cout << "Valid state !"<<"\n" ;
mStmIsValid = true;
}
mActSt = nextSt;
return nodeRegex.second;
}
}
idxType += 1;
}
mActSt =-1;
}
return "";
}
std::pair<int,std::string> SeqStm::testNode(const NodeTmp node){
std::string commonTag = "";
//std::cout << "0\n" ;
if (!isStmBlocked()){
bool isNextStEnd = std::all_of(mTransitionMatrix[mActSt].begin(), mTransitionMatrix[mActSt].end(), [&](int x){ return x == -1; });
//std::cout << "1:"<< isNextStEnd <<"\n" ;
//if the next state if full of -1 can we relay add the node test to all node tested
// oker y test it but it sure that not be valid
if(!isNextStEnd){
mAllNodeTested.insert(node);
}
//std::cout << "2\n" ;
//recurtion avoidance
if(mAllNodeValidated.find(node) == mAllNodeValidated.end()){
NodeType nodeType = getTheNodeType(node);
//std::cout << "3 " << nodeType << "\n" ;
commonTag = transitionOnNodeType(nodeType);
//after the transition test, if the node is != -1 the node is valid for the stm
//std::cout << " mActSt = " << mActSt << "\n" ;
if( mActSt != -1 ){
mAllNodeValidated.insert(node);
}
}else{
mActSt = -1;
}
}
if(commonTag != ""){
mAllCommonNode.insert(std::make_pair(node,commonTag));
}
return std::make_pair(mActSt,commonTag);
}
void SeqStm::drawStm(){
//mTransitionMatrix
// Find the maximum width of each column
std::vector<std::size_t> max_widths(mTransitionMatrix[0].size(), 0);
for (std::size_t i = 0; i < mTransitionMatrix.size(); ++i)
{
for (std::size_t j = 0; j < mTransitionMatrix[i].size(); ++j)
{
std::size_t width = std::to_string(mTransitionMatrix[i][j]).length();
if (width > max_widths[j])
{
max_widths[j] = width;
}
}
}
// Print the vector with aligned columns
for (std::size_t i = 0; i < mTransitionMatrix.size(); ++i)
{
for (std::size_t j = 0; j < mTransitionMatrix[i].size(); ++j)
{
int i_int = static_cast<int>(i);
if (mActSt == -1 ){
if(mStmIsValid){
std::cout << "\033[48;5;40m";
}else{
std::cout << "\033[48;5;9m";
}
}
else if (mActSt == i_int){
std::cout << "\033[48;5;30m";
}else{
std::cout << "\033[48;5;27m";
}
// Pad the value with spaces to align it with the maximum width
std::size_t width = std::to_string(mTransitionMatrix[i][j]).length();
std::string padding(max_widths[j] - width, ' ');
std::cout << padding << mTransitionMatrix[i][j] << " ";
std::cout << "\033[0m";
}
std::cout << "\n";
}
std::cout << "mAllNodeTested : ";
for (const auto& x : mAllNodeTested) {
std::cout << x << ", ";
}
std::cout << "\n";
std::cout << "mAllNodeValidated : ";
for (const auto& x : mAllNodeValidated) {
std::cout << x << ", ";
}
std::cout << "\n";
}
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/graphmatching/StmFactory.hpp"
using namespace Aidge;
StmFactory::StmFactory(const std::map<std::string, NodeRegex *> &nodesRegex)
: mNodesRegex(nodesRegex) {}
SeqStm *StmFactory::duplicateStm(SeqStm *stm) { return stm->duplicateStm(); }
SeqStm *StmFactory::makeNewStm(const std::string &sequRegex) {
ParsingReturn parsing = initParsingSequRegex(sequRegex);
std::vector<std::vector<int>> transitionMatrix =
initTransitionMatrix(parsing);
std::set<NodeTmp> allNodeValidated;
std::set<NodeTmp> allNodeTested;
std::set<std::pair<NodeTmp, std::string>> allCommonNode;
SeqStm *newStm = new SeqStm(static_cast<int>(mCmptStm), transitionMatrix, mNodesRegex,
parsing.typeToIdxTransition, 0, allNodeValidated,
allNodeTested, allCommonNode, false);
mCmptStm += 1;
return newStm;
}
ParsingReturn StmFactory::initParsingSequRegex(const std::string &sequRegex) {
std::string toMatch;
std::regex re("\\s*([A-Za-z]+)(#\\d*)?([+*])?\\s*(->|;)");
std::smatch matches;
int idxType = 0;
// return
ParsingReturn parsing;
// std::map<std::pair<NodeType,std::string>,int> typeToIdxTransition;
// std::vector<std::pair<std::pair<NodeType,std::string>,std::string>>
// transition;
// assert
std::map<NodeType, std::string> assertCommonNodeTypes;
for (std::size_t i = 0; i < sequRegex.length(); i++) {
toMatch += sequRegex[i];
if (std::regex_match(toMatch, matches, re)) {
std::string type = matches.str(1);
std::string commonTag = matches.str(2);
std::string quantification = matches.str(3);
if ((commonTag != "") && (quantification != "")) {
throw std::runtime_error("bad commonTag and quantification");
}
// make the typeToIdxTransition
NodeTypeKey typeTag = std::make_pair(type, commonTag);
/*std::cout << " typeTag: " << type << " " << commonTag
<< parsing.typeToIdxTransition.size() << std::endl;*/
if (parsing.typeToIdxTransition.find(typeTag) ==
parsing.typeToIdxTransition.end()) {
parsing.typeToIdxTransition[typeTag] = idxType;
idxType += 1;
}
////////////////////////////////////////////////////////////
// ASSERT
// SAME Common node in the sequ
if (commonTag != "") {
if (assertCommonNodeTypes.find(type) != assertCommonNodeTypes.end()) {
if (assertCommonNodeTypes[type] == commonTag) {
throw std::runtime_error("same common node in the sequ regex");
}
} else {
assertCommonNodeTypes[type] = commonTag;
}
}
// save all transition
parsing.transition.push_back(std::make_pair(typeTag, quantification));
/*std::cout << "Match found: " << matches.str() << std::endl;
std::cout << "Type: " << matches.str(1) << std::endl;
std::cout << "Common tag: " << matches.str(2) << std::endl;
std::cout << "Quantification: " << matches.str(3) << std::endl;*/
toMatch = "";
}
}
if (parsing.transition.size() == 0) {
throw std::runtime_error("Bad Parsing SequRegex ");
}
return parsing;
}
std::vector<std::vector<int>>
StmFactory::initTransitionMatrix(ParsingReturn &parsing) {
// std::pair<NodeTypeKey,std::string>
std::vector<std::vector<int>> transitionMatrix;
std::size_t numberOfType = parsing.typeToIdxTransition.size();
if (numberOfType == 0) {
throw std::runtime_error("Bad number Of Type ");
}
// init start st
transitionMatrix.push_back(std::vector<int>(numberOfType, -1));
std::size_t idxTransition = 0;
int idxState = 0;
for (const auto &pair : parsing.transition) {
const NodeTypeKey &nodeTypeKey = pair.first;
const std::string &quant = pair.second;
/*std::cout << "Key: {" << nodeTypeKey.first << ", " << nodeTypeKey.second
<< "}, Value: " << quant << std::endl;
std::cout << "idxState " << idxState << " TM: " << transitionMatrix.size()
<< std::endl;*/
std::size_t idxType = parsing.typeToIdxTransition[nodeTypeKey];
/*std::cout << "idxType " << idxType << " TM: " << transitionMatrix[0].size()
<< "type" << numberOfType << std::endl;*/
if (quant == "*") {
transitionMatrix[idxTransition][idxType] = idxState;
} else if (quant == "+") {
idxState += 1;
transitionMatrix[idxTransition][idxType] = idxState;
transitionMatrix.push_back(std::vector<int>(numberOfType, -1));
idxTransition += 1;
transitionMatrix[idxTransition][idxType] = idxState;
} else {
idxState += 1;
transitionMatrix[idxTransition][idxType] = idxState;
transitionMatrix.push_back(std::vector<int>(numberOfType, -1));
idxTransition += 1;
}
}
return transitionMatrix;
}
\ No newline at end of file
......@@ -8,7 +8,7 @@ using namespace Aidge;
//ConditionalRegisterFunction
///////////////////////////////
ConditionalData* ConditionalRegisterFunction::run(const std::string key,std::vector<ConditionalData*> & datas){
std::shared_ptr<ConditionalData> ConditionalRegisterFunction::run(const std::string key,std::vector< std::shared_ptr<ConditionalData>> & datas){
auto lambdaIt = mWlambda.find(key);
if (lambdaIt != mWlambda.end()) {
......@@ -45,10 +45,9 @@ using namespace Aidge;
bool ConditionalInterpreter::test( const NodePtr nodeOp)
{
clearRes();
mResolution.clear();
try{
std::vector<ConditionalData*> r = visit({mTree},nodeOp);
std::vector< std::shared_ptr<ConditionalData>> r = visit({mTree},nodeOp);
if (mResolution.size() != 1){
throw std::runtime_error("Multi output interpretation output");
......@@ -72,8 +71,8 @@ using namespace Aidge;
}
/////
std::vector<ConditionalData*> ConditionalInterpreter::visit(const ASTNodeCh& nodes, const NodePtr nodeOp ){
std::vector<ConditionalData*> dataVector;
std::vector< std::shared_ptr<ConditionalData>> ConditionalInterpreter::visit(const ASTNodeCh& nodes, const NodePtr nodeOp ){
std::vector< std::shared_ptr<ConditionalData>> dataVector;
for ( std::shared_ptr<AstNode<ConditionalTokenTypes>> node : nodes) {
try{
......@@ -140,7 +139,7 @@ using namespace Aidge;
case ConditionalTokenTypes::NODE: //TODO
{
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<NodePtr>(nodeOp);
mResolution.push_back(data);
......@@ -157,7 +156,7 @@ using namespace Aidge;
case ConditionalTokenTypes::BOOL: //TODO
{
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
if(node->getValue() == "true"){
data->setValue<bool>(true);
......@@ -195,7 +194,8 @@ using namespace Aidge;
void ConditionalInterpreter::fStrToInteger(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node)
{
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<int>(std::stoi(node->getValue()));
mResolution.push_back(data);
}
......@@ -203,14 +203,14 @@ using namespace Aidge;
void ConditionalInterpreter::fStrToFloat(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node)
{
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<float>(std::stof(node->getValue()));
mResolution.push_back(data);
}
void ConditionalInterpreter::fStrToStr(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node)
{
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<std::string>(node->getValue());
mResolution.push_back(data);
}
......@@ -218,7 +218,7 @@ using namespace Aidge;
void ConditionalInterpreter::fLambda(const std::shared_ptr<AstNode<ConditionalTokenTypes>>& node)
{
//if the lambda have input
ConditionalData* data;
std::shared_ptr<ConditionalData> data;
try {
data = mLambdaRegister.run(node->getValue(),mResolution);
} catch (const std::exception& e) {
......@@ -227,17 +227,20 @@ using namespace Aidge;
throw std::runtime_error(errorMessage.str());
}
clearRes();
//clearRes();
mResolution.push_back(data);
}
void ConditionalInterpreter::fEq(void)
{
if (mResolution.size() != 2){
if (mResolution.size() < 2){
throw std::runtime_error("EQ need 2 arg and get :" + std::to_string(mResolution.size()));
}
auto a = mResolution[0];
auto b = mResolution[1];
auto a = mResolution.back();
mResolution.pop_back();
auto b = mResolution.back();
mResolution.pop_back();
if (a->getType() != b->getType()){
throw std::runtime_error("EQ Unsupported between type :" + a->getType() +" "+ b->getType());
......@@ -245,7 +248,7 @@ using namespace Aidge;
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
if (a->isTypeEqualTo<int>()) {
data->setValue<bool>( a->getValue<int>() == b->getValue<int>());
......@@ -259,23 +262,25 @@ using namespace Aidge;
throw std::runtime_error("EQ Unknown type encountered :" + a->getType() );
}
clearRes();
mResolution.push_back(data);
}
void ConditionalInterpreter::fNeq(void)
{
if (mResolution.size() != 2){
if (mResolution.size() < 2){
throw std::runtime_error("NEQ need 2 arg and get :" + std::to_string(mResolution.size()));
}
auto a = mResolution[0];
auto b = mResolution[1];
auto a = mResolution.back();
mResolution.pop_back();
auto b = mResolution.back();
mResolution.pop_back();
if (a->getType() != b->getType()){
throw std::runtime_error("NEQ Unsupported between type :" + a->getType() +" "+ b->getType());
}
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
if (a->isTypeEqualTo<int>()) {
data->setValue<bool>( a->getValue<int>() != b->getValue<int>());
......@@ -288,67 +293,72 @@ using namespace Aidge;
throw std::runtime_error("NEQ Unknown type encountered :" + a->getType() );
}
clearRes();
mResolution.push_back(data);
}
void ConditionalInterpreter::fAnd(void)
{
if (mResolution.size() != 2){
if (mResolution.size() < 2){
throw std::runtime_error("AND need 2 arg and get :" + std::to_string(mResolution.size()));
}
auto a = mResolution[0];
auto b = mResolution[1];
auto a = mResolution.back();
mResolution.pop_back();
auto b = mResolution.back();
mResolution.pop_back();
if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){
throw std::runtime_error("AND Unknown type encountered need bool get :" + a->getType() );
}
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<bool>( a->getValue<bool>() && b->getValue<bool>());
clearRes();
mResolution.push_back(data);
}
void ConditionalInterpreter::fOr(void)
{
if (mResolution.size() != 2){
if (mResolution.size() < 2){
throw std::runtime_error("OR need 2 arg and get :" + std::to_string(mResolution.size()));
}
auto a = mResolution[0];
auto b = mResolution[1];
auto a = mResolution.back();
mResolution.pop_back();
auto b = mResolution.back();
mResolution.pop_back();
if (a->getType() != typeid(bool).name() || b->getType() != typeid(bool).name()){
throw std::runtime_error("OR Unknown type encountered need bool get :" + a->getType() );
}
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<bool>( a->getValue<bool>() || b->getValue<bool>());
clearRes();
mResolution.push_back(data);
}
void ConditionalInterpreter::fNot()
{
if (mResolution.size() != 1){
if (mResolution.size() < 1){
throw std::runtime_error("NOT need 1 arg and get :" + std::to_string(mResolution.size()));
}
auto a = mResolution[0];
auto a = mResolution.back();
mResolution.pop_back();
if (a->getType() != typeid(bool).name()){
throw std::runtime_error("NOT Unknown type encountered need bool get :" + a->getType() );
}
ConditionalData* data = new ConditionalData;
std::shared_ptr<ConditionalData> data = std::make_shared<ConditionalData>();
data->setValue<bool>( !a->getValue<bool>() );
clearRes();
mResolution.push_back(data);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment