Skip to content
Snippets Groups Projects
Commit 5273949f authored by vincent  lorrain's avatar vincent lorrain
Browse files

[GraphRegex] fix

parent 83339ba3
No related branches found
No related tags found
1 merge request!14Graph regex
Pipeline #31923 failed
......@@ -75,10 +75,10 @@ namespace Aidge{
void rmEdge(std::shared_ptr<FsmEdge>);
void addEdge(std::shared_ptr<FsmEdge>);
const std::set<std::shared_ptr<FsmNode>> getChildNodes(void);
//const std::set<std::shared_ptr<FsmNode>> getChildNodes(void);
const std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>> getParentNodes(void);
const std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>> getEdges(void);
const std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>>& getParentNodes(void);
const std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>>& getEdges(void);
void setGroupe(std::size_t groupeIdx);
......
......@@ -55,7 +55,7 @@ namespace Aidge{
* @brief constructor
* @param actState the actual state in the FSM
* @param actOpNode the actual node in the graph
* @param idxRejeced the idx in the global regected node vector init -1 as sentinel value of undefind
* @param idxRejeced the idx in the global regected node vector init max() as sentinel value of undefind
*/
FsmRunTimeContext(std::shared_ptr<FsmNode> actState ,NodePtr actOpNode ,std::size_t idxRejeced =std::numeric_limits<std::size_t>::max() );
FsmRunTimeContext(std::shared_ptr<FsmRunTimeContext> fsmRunTime);
......
......@@ -2,9 +2,12 @@
#define __AIDGE_MATCH_RESULT_H__
#include <memory>
#include <vector>
#include <map>
#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp"
#include "aidge/graphRegex/matchFsm/FsmRunTimeContext.hpp"
#include "aidge/graph/Node.hpp"
namespace Aidge{
......@@ -17,10 +20,20 @@ class MatchResult
private:
/* data */
std::vector<std::shared_ptr<FsmRunTimeContext>> mAllValid;
/*
the Run time of eatch sub FSM , to have a valide match we need a set of one run time per FSM compatible
the id must be contigue
*/
std::vector<std::vector<std::shared_ptr<FsmRunTimeContext>>> mIdToRunTime;
std::vector<std::set<NodePtr>> mSolve;
std::size_t mNbSubStm;
public:
MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid){
mAllValid = allValid;
};
MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm);
virtual ~MatchResult() = default;
/**
......@@ -28,6 +41,10 @@ public:
* @return the set of node of the graph that corresponding to an expression
*/
std::set<NodePtr> getNodes(void);
private:
void _generateCombinationd( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence);
};
......
......@@ -62,7 +62,7 @@ FsmGraph::FsmGraph(/* args */){
}
return std::make_shared<MatchResult>(allValidContext);
return std::make_shared<MatchResult>(allValidContext,getNbSubFsm());
}
......
......@@ -72,23 +72,23 @@ void FsmNode::addEdge(std::shared_ptr<FsmEdge> edge){
}
}
const std::set<std::shared_ptr<FsmNode>> FsmNode::getChildNodes(void){
std::set<std::shared_ptr<FsmNode>> children;
for(auto edge : mEdges){
if (auto sharedEdge = edge.lock()) {
children.insert(sharedEdge->getDestNode());
}else{
throw std::runtime_error("getChildNodes FsmNode weak pointer is expired" );
}
}
return children;
}
const std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>> FsmNode::getParentNodes(void){
// const std::set<std::shared_ptr<FsmNode>> FsmNode::getChildNodes(void){
// std::set<std::shared_ptr<FsmNode>> children;
// for(auto edge : mEdges){
// if (auto sharedEdge = edge.lock()) {
// children.insert(sharedEdge->getDestNode());
// }else{
// throw std::runtime_error("getChildNodes FsmNode weak pointer is expired" );
// }
// }
// return children;
// }
const std::set<std::weak_ptr<FsmNode>,lex_compare<FsmNode>>& FsmNode::getParentNodes(void){
return mParents;
}
const std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>> FsmNode::getEdges(void){
const std::set<std::weak_ptr<FsmEdge>,lex_compare<FsmEdge>>& FsmNode::getEdges(void){
return mEdges;
}
......
......@@ -97,10 +97,11 @@ bool FsmRunTimeContext::areCompatible(std::shared_ptr<FsmRunTimeContext> fsmCont
//valid nodes
std::set<NodePtr> commonElements;
std::set<NodePtr> A = getValidNodesNoCommon();
std::set<NodePtr> B = fsmContext->getValidNodesNoCommon();
std::set_intersection(
getValidNodesNoCommon().begin(), getValidNodesNoCommon().end(),
fsmContext->getValidNodesNoCommon().begin(), fsmContext->getValidNodesNoCommon().end(),
A.begin(),A.end(),
B.begin(), B.end(),
std::inserter(commonElements, commonElements.end())
);
......@@ -187,6 +188,8 @@ std::map<NodePtr,std::size_t> FsmRunTimeContext::getCommon(void){
}
std::set<NodePtr> FsmRunTimeContext::getValidNodes(void){
auto sharedSet = std::make_shared<std::set<NodePtr>>();
// Create a set to store the values from the map
std::set<NodePtr> nodes;
// Iterate over the map and insert values into the set
......@@ -198,7 +201,9 @@ std::set<NodePtr> FsmRunTimeContext::getValidNodes(void){
std::set<NodePtr> FsmRunTimeContext::getValidNodesNoCommon(void){
std::set<NodePtr> differenceSet;
std::set_difference(getValidNodes().begin(), getValidNodes().end(), getCommonNodes().begin(), getCommonNodes().end(),std::inserter(differenceSet, differenceSet.end()));
std::set<NodePtr> valide = getValidNodes();
std::set<NodePtr> common = getCommonNodes();
std::set_difference(valide.begin(), valide.end(), common.begin(), common.end(),std::inserter(differenceSet, differenceSet.end()));
return differenceSet;
}
......
......@@ -2,3 +2,77 @@
using namespace Aidge;
MatchResult::MatchResult(std::vector<std::shared_ptr<FsmRunTimeContext>> allValid, std::size_t nbSubStm):mIdToRunTime(nbSubStm){
mAllValid = allValid;
mNbSubStm = nbSubStm;
//mIdToRunTimm
for (const auto& contextPtr : allValid) {
mIdToRunTime[contextPtr->getSubStmId()].push_back(contextPtr);
}
std::vector<std::shared_ptr<FsmRunTimeContext>> precedence;
_generateCombinationd(0,precedence);
}
void MatchResult::_generateCombinationd( std::size_t idxSubStm, std::vector<std::shared_ptr<FsmRunTimeContext>>& precedence){
//it's end , we are below the number of stm
if (idxSubStm == mNbSubStm)
{
//precedence containe a liste of FSM compatible, we just need to
//check if all the node have been valide by at least one contetext
//1) make the set of all node for the comput graph that are valide in all the FsmRunTimeContext
std::set<NodePtr> validNode;
std::set<NodePtr> rejectNode;
for (const auto& contextPtr : precedence) {
std::set<NodePtr> tmpV = contextPtr->getValidNodes();
validNode.insert(tmpV.begin(), tmpV.end());
std::set<NodePtr> tmpR = contextPtr->getRejectedNodes();
rejectNode.insert(tmpR.begin(),tmpR.end());
}
// 2) all RejectedNodes need to be valide by an others stm
// if it's not the case the match is not valid
if(std::includes(validNode.begin(), validNode.end(), rejectNode.begin(), rejectNode.end())){
//we can save the solution
mSolve.push_back(validNode);
}
precedence.pop_back();
return;
}
for (const auto& contextPtrOneFsm : mIdToRunTime[idxSubStm])
{
if(idxSubStm == 0){
precedence.push_back(contextPtrOneFsm);
_generateCombinationd(idxSubStm+1,precedence);
}else{
//test if the new context is compatible whith all the context in the precedence
//
bool compatibleSolutionFsm = true;
for (const auto& contextPtrOfOtherFsm : precedence) {
if(!(contextPtrOneFsm->areCompatible(contextPtrOfOtherFsm))){
compatibleSolutionFsm = false;
break;
}
}
if(compatibleSolutionFsm){
precedence.push_back(contextPtrOneFsm);
_generateCombinationd(idxSubStm+1,precedence);
}
}
}
if(idxSubStm != 0){
precedence.pop_back();
}
return;
}
......@@ -23,13 +23,13 @@ TEST_CASE("FsmMatch") {
allTest["A"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";});
allTest["B"]->insertLambda("isConv",+[](NodePtr NodeOp){return NodeOp->type() == "Conv";});
//GraphFsmInterpreter("A->B",allTest);
std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A->B",allTest);
std::shared_ptr<GraphFsmInterpreter> fsmGenerator = std::make_shared<GraphFsmInterpreter>("A#->B",allTest);
std::shared_ptr<FsmGraph> fsm = fsmGenerator->interpret();
REQUIRE(fsm->getNodes().size() == 3);
REQUIRE(fsm->getStartNodes().size() == 1);
//REQUIRE(fsm->getNodes().size() == 3);
//REQUIRE(fsm->getStartNodes().size() == 1);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment