Skip to content
Snippets Groups Projects

Alternative graph matching

Merged Olivier BICHLER requested to merge experimental into dev
3 files
+ 1292
0
Compare changes
  • Side-by-side
  • Inline
Files
3
+ 225
0
 
/********************************************************************************
 
* Copyright (c) 2023 CEA-List
 
*
 
* This program and the accompanying materials are made available under the
 
* terms of the Eclipse Public License 2.0 which is available at
 
* http://www.eclipse.org/legal/epl-2.0.
 
*
 
* SPDX-License-Identifier: EPL-2.0
 
*
 
********************************************************************************/
 
 
#ifndef AIDGE_CORE_GRAPH_MATCHING_H_
 
#define AIDGE_CORE_GRAPH_MATCHING_H_
 
 
#include <map>
 
#include <memory>
 
#include <set>
 
 
#include "aidge/graph/Node.hpp"
 
#include "aidge/graph/GraphView.hpp"
 
 
namespace Aidge {
 
/**
 
* A simple experimental graph matching class which works by direct, single pass
 
* parse and match, without constructing any intermediate representation.
 
* Due to its single pass nature, it has some constrains on how the queries must
 
* be formulated.
 
*/
 
class SinglePassGraphMatching {
 
public:
 
struct Context {
 
std::string query;
 
bool firstSequence = true;
 
bool firstNode = true;
 
bool inSequence = false;
 
bool lookForChild = true;
 
bool singleOutput = true;
 
IOIndex_t edgeLeftIdx = 0;
 
IOIndex_t edgeRightIdx = 0;
 
 
// For check & debug purpose:
 
size_t depth = 0;
 
std::set<std::string> anchors;
 
};
 
 
struct MatchingResult {
 
// Mutable is required to allow modifying MatchingResult members with a std::set
 
// iterator. Any change should not modify the set ordering.
 
// We use graph->rootNode() as the std::set key, which is garanteed
 
// to never change after insertion!
 
mutable std::shared_ptr<GraphView> graph;
 
mutable std::map<std::string, std::map<std::string, NodePtr>> anchors;
 
mutable NodePtr startNode;
 
 
MatchingResult() {
 
graph = std::make_shared<GraphView>();
 
}
 
 
MatchingResult(const MatchingResult& result) {
 
graph = std::make_shared<GraphView>(*(result.graph.get()));
 
anchors = result.anchors;
 
startNode = result.startNode;
 
}
 
 
MatchingResult& operator=(const MatchingResult& result) {
 
graph = std::make_shared<GraphView>(*(result.graph.get()));
 
anchors = result.anchors;
 
startNode = result.startNode;
 
return *this;
 
}
 
};
 
 
SinglePassGraphMatching(std::shared_ptr<GraphView> graph) : mGraph(graph) {}
 
 
/**
 
* Matches a query by direct, single pass parse and match.
 
* The returned matches are non-ordered and therefore stored in a std::set.
 
*
 
* Some rules:
 
* - The first node of the first sequence is the root node and cannot be optional
 
* WRONG: Conv?->ReLU (will throw an error)
 
* GOOD: ReLU<-Conv?
 
*
 
* - The first node of any further sequence must be an existing anchor
 
* (the anchor cannot be in the middle of the sequence)
 
* WRONG: Conv->ReLU;Pad->Conv (will throw an error)
 
* Pad->Conv;Conv->ReLU (will throw an error)
 
* GOOD: Conv#->ReLU;Conv#<-Pad
 
* Pad->Conv#;Conv#->ReLU
 
*
 
* - Any node already matched cannot be matched again (except for anchors)
 
*
 
* - By default, an edge matches the first output to the first input.
 
* EXAMPLE: ReLU->Conv is equivalent to ReLU-0-0>Conv
 
* To match the second input, use ReLU-0-1>Conv (or ReLU-1>Conv)
 
* To match the second output, use ReLU-1-0>Conv
 
* To match any input and/or any output, use *, like ReLU-1-*>Conv
 
* or ReLU-*-0>Conv or ReLU-*-*>Conv
 
* The same is true for the "<-" edge syntax.
 
*
 
* - When several nodes could match for a given node query, the first one
 
* not already in the matching result is matched, following the
 
* childs/parents ordered node list
 
* EXAMPLE: Producer in "Conv<*-Producer" will match the weights Producer first
 
* EXAMPLE: Producer in "Conv#<1-.;Conv#<*-Producer" will match the bias Producer
 
* because the weights Producer has already been matched
 
*
 
* - One always matches a sub-graph: additional connections can exist anywhere
 
* in the matched sub-graph
 
* EXAMPLE: "Add<*-." will match the Add operator and its first input, any
 
* additional inputs will not be included in the result
 
* EXAMPLE: "(Add#<*-.)+" will match the Add operator and all of its inputs
 
* Note that the anchor is required since we intend to match several
 
* inputs of the same node!
 
*
 
* - In Aidge, a node output can be connected to multiple other nodes. In
 
* your query, you can allow it or not, with the "~" or "-" modifier.
 
* EXAMPLE: "Conv->ReLU" will match the Conv that are **only** connected
 
* to a ReLU node at their output #0.
 
* "Conv~>ReLU" will match all the Conv connected to a ReLU even
 
* if they are also connected to other nodes at the same output #0.
 
* When implementing a match & replace recipe, beware that you don't break
 
* branches in the middle of your matching result if you use "~"!
 
*
 
* - The matching results can be overlapping, meaning that some nodes may be
 
* found in multiple results. Some results may be subsets of other results.
 
* EXAMPLE: assume graph Conv#1->ReLU#1->Conv#2->ReLU#2
 
* "Conv->ReLU?->Conv?->ReLU?" will return both
 
* Conv#1->ReLU#1->Conv#2->ReLU#2 and Conv#2->ReLU#2
 
* To avoid this behavior, set the disjoint argument to true. In this case,
 
* only Conv#1->ReLU#1->Conv#2->ReLU#2 will be kept in the example above.
 
*
 
* - Whitespaces are allowed anywhere in the query
 
*
 
* QUERY = SEQ | NODE_OR_BLOCK (';' (SEQ | NODE_OR_BLOCK))*
 
*
 
* @param query The query to search.
 
* @param disjoint If true, only keep the longuest disjoint (non-overlapping) matches.
 
* @return Set of matches, each stored in a MatchingResult struct.
 
*/
 
std::set<MatchingResult> match(const std::string& query, bool disjoint = false);
 
 
/**
 
* Filter to keep only the longuest disjoint (non-overlapping) matches.
 
*/
 
std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches);
 
 
inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) {
 
mLambda[name] = func;
 
}
 
 
private:
 
std::shared_ptr<GraphView> mGraph;
 
std::map<std::string, bool(*)(const NodePtr&)> mLambda;
 
 
/**
 
* QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}')
 
* NODE_OR_BLOCK = (BLOCK | NODE) QUANTIFIER?
 
*/
 
bool matchNodeOrBlock(Context& ctx, std::set<MatchingResult>& matches);
 
 
/**
 
* BLOCK = '(' SEQ | PAR | BLOCK | ALT | NODE ')'
 
*/
 
bool matchBlock(Context& ctx, std::set<MatchingResult>& matches);
 
 
/**
 
* SEQ = NODE_OR_BLOCK (EDGE NODE_OR_BLOCK)+
 
*/
 
bool matchSequence(Context& ctx, std::set<MatchingResult>& matches);
 
 
/**
 
* PAR = NODE_OR_BLOCK ('&' NODE_OR_BLOCK)+
 
*/
 
bool matchParallel(Context& ctx, std::set<MatchingResult>& matches);
 
 
/**
 
* ALT = NODE_OR_BLOCK ('|' NODE_OR_BLOCK)+
 
*/
 
bool matchAlternative(Context& ctx, std::set<MatchingResult>& matches);
 
 
/**
 
* IO_INDEX_ANY = '*'
 
* IO_INDEX = IO_INDEX_ANY | [0-9]+
 
* CHILD_EDGE = ('-' | '~') (IO_INDEX '-')? IO_INDEX? '>'
 
* PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? ('-' | '~')
 
* EDGE = CHILD_EDGE | PARENT_EDGE
 
*/
 
bool matchEdge(Context& ctx, std::set<MatchingResult>& matches);
 
 
/**
 
* TYPE = [A-Za-z0-9_]+
 
* ANCHOR = [A-Za-z0-9_]+
 
* LAMBDA = [A-Za-z0-9_]+
 
* NODE = (TYPE | '.') ('#' ANCHOR)? ('[' LAMBDA ']')?
 
*/
 
bool matchNode(Context& ctx, std::set<MatchingResult>& matches);
 
 
inline void removeWhiteSpace(std::string& str) {
 
str.erase(str.begin(),
 
std::find_if(str.begin(),
 
str.end(),
 
[](char c) { return !std::isspace(c); }));
 
}
 
 
struct CompareMatchingResultSize {
 
bool operator()(const MatchingResult& lhs, const MatchingResult& rhs) const {
 
// Some matches size could be the same
 
if (lhs.graph->getNodes().size() == rhs.graph->getNodes().size()) {
 
// In this case, use rootNode which is garanteed to be different!
 
return lhs.graph->rootNode() < rhs.graph->rootNode();
 
}
 
 
return lhs.graph->getNodes().size() > rhs.graph->getNodes().size();
 
}
 
};
 
};
 
 
inline bool operator<(const Aidge::SinglePassGraphMatching::MatchingResult& lhs, const Aidge::SinglePassGraphMatching::MatchingResult& rhs) {
 
// Matches rootNode are garanteed to be different!
 
return lhs.graph->rootNode() < rhs.graph->rootNode();
 
}
 
} // namespace Aidge
 
 
#endif /* AIDGE_CORE_GRAPH_MATCHING_H_ */
Loading