-
Olivier BICHLER authoredOlivier BICHLER authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Matching.hpp 9.30 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_GRAPH_MATCHING_H_
#define AIDGE_CORE_GRAPH_MATCHING_H_
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
namespace Aidge {
/**
* A simple experimental graph matching class which works by direct, single pass
* parse and match, without constructing any intermediate representation.
* Due to its single pass nature, it has some constraints on how the queries must
* be formulated.
*/
class SinglePassGraphMatching {
public:
    /**
     * State handed by reference to every match*() parsing method.
     * NOTE(review): the exact role of each field is defined by the
     * implementation file, which is not visible from this header.
     */
    struct Context {
        Context();
        Context(const Context&); // explicitly define Context copy constructor
                                 // to avoid automatic inlining
        Context& operator=(const Context&);
        ~Context() noexcept;

        std::string query;
        bool firstSequence = true;
        bool firstNode = true;
        bool inSequence = false;
        bool lookForChild = true;
        bool singleOutput = true;
        IOIndex_t edgeLeftIdx = 0;
        IOIndex_t edgeRightIdx = 0;
        NodePtr startNode;

        // For check & debug purpose:
        std::size_t depth = 0;
        std::set<std::string> anchors;
    };

    /**
     * A single match, as returned by match() and matchFrom().
     */
    struct MatchingResult {
        // Mutable is required to allow modifying MatchingResult members with a std::set
        // iterator. Any change should not modify the set ordering.
        // We use graph->rootNode() as the std::set key, which is guaranteed
        // to never change after insertion!
        mutable std::shared_ptr<GraphView> graph;
        mutable std::map<std::string, std::map<std::string, NodePtr>> anchors;
        mutable NodePtr startNode;

        MatchingResult();
        MatchingResult(const MatchingResult& other);
        MatchingResult& operator=(const MatchingResult& other);
        ~MatchingResult() noexcept;
    };

    /**
     * @brief Build a matcher bound to a graph.
     *
     * @param graph Graph kept in mGraph. It is taken by value and moved into
     * place, so no extra shared_ptr copy is made.
     */
    SinglePassGraphMatching(std::shared_ptr<GraphView> graph) : mGraph(std::move(graph)) {}
    SinglePassGraphMatching(const SinglePassGraphMatching& other);
    SinglePassGraphMatching& operator=(const SinglePassGraphMatching& other);
    ~SinglePassGraphMatching() noexcept;

    /**
     * Matches a query by direct, single pass parse and match.
     * The returned matches are non-ordered and therefore stored in a std::set.
     *
     * Some rules:
     * - The first node of the first sequence is the root node and cannot be optional
     *   WRONG: Conv?->ReLU (will throw an error)
     *   GOOD: ReLU<-Conv?
     *
     * - The first node of any further sequence must be an existing anchor
     *   (the anchor cannot be in the middle of the sequence)
     *   WRONG: Conv->ReLU;Pad->Conv (will throw an error)
     *          Pad->Conv;Conv->ReLU (will throw an error)
     *   GOOD: Conv#->ReLU;Conv#<-Pad
     *         Pad->Conv#;Conv#->ReLU
     *
     * - Any node already matched cannot be matched again (except for anchors)
     *
     * - By default, an edge matches the first output to the first input.
     *   EXAMPLE: ReLU->Conv is equivalent to ReLU-0-0>Conv
     *            To match the second input, use ReLU-0-1>Conv (or ReLU-1>Conv)
     *            To match the second output, use ReLU-1-0>Conv
     *            To match any input and/or any output, use *, like ReLU-1-*>Conv
     *            or ReLU-*-0>Conv or ReLU-*-*>Conv
     *            The same is true for the "<-" edge syntax.
     *
     * - When several nodes could match for a given node query, the first one
     *   not already in the matching result is matched, following the
     *   children/parents ordered node list
     *   EXAMPLE: Producer in "Conv<*-Producer" will match the weights Producer first
     *   EXAMPLE: Producer in "Conv#<1-.;Conv#<*-Producer" will match the bias Producer
     *            because the weights Producer has already been matched
     *
     * - One always matches a sub-graph: additional connections can exist anywhere
     *   in the matched sub-graph
     *   EXAMPLE: "Add<*-." will match the Add operator and its first input, any
     *            additional inputs will not be included in the result
     *   EXAMPLE: "(Add#<*-.)+" will match the Add operator and all of its inputs
     *            Note that the anchor is required since we intend to match several
     *            inputs of the same node!
     *
     * - In Aidge, a node output can be connected to multiple other nodes. In
     *   your query, you can allow it or not, with the "~" or "-" modifier.
     *   EXAMPLE: "Conv->ReLU" will match the Conv that are **only** connected
     *            to a ReLU node at their output #0.
     *            "Conv~>ReLU" will match all the Conv connected to a ReLU even
     *            if they are also connected to other nodes at the same output #0.
     *   When implementing a match & replace recipe, beware that you don't break
     *   branches in the middle of your matching result if you use "~"!
     *
     * - The matching results can be overlapping, meaning that some nodes may be
     *   found in multiple results. Some results may be subsets of other results.
     *   EXAMPLE: assume graph Conv#1->ReLU#1->Conv#2->ReLU#2
     *            "Conv->ReLU?->Conv?->ReLU?" will return both
     *            Conv#1->ReLU#1->Conv#2->ReLU#2 and Conv#2->ReLU#2
     *   To avoid this behavior, set the disjoint argument to true. In this case,
     *   only Conv#1->ReLU#1->Conv#2->ReLU#2 will be kept in the example above.
     *
     * - Whitespaces are allowed anywhere in the query
     *
     * QUERY = SEQ | NODE_OR_BLOCK (';' (SEQ | NODE_OR_BLOCK))*
     *
     * @param query The query to search.
     * @param disjoint If true, only keep the longest disjoint (non-overlapping) matches.
     * @return std::set<MatchingResult> Set of matches, each stored in a MatchingResult struct.
     */
    std::set<MatchingResult> match(const std::string& query, bool disjoint = false);

    /**
     * @brief Same as match() but with a mandatory start node.
     *
     * @param startNode Mandatory start node for the query.
     * @param query The query to search.
     * @return MatchingResult MatchingResult struct, with empty graph if query
     * is not found, or the graph corresponding to the query.
     */
    MatchingResult matchFrom(NodePtr startNode, const std::string& query);

    /**
     * Filter to keep only the longest disjoint (non-overlapping) matches.
     */
    std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches);

    /**
     * @brief Store a node predicate under a name, replacing any predicate
     * previously stored under the same name.
     *
     * NOTE(review): presumably this is the name referenced by the
     * '[' LAMBDA ']' part of a NODE query -- confirm in the implementation.
     *
     * @param name Key under which the predicate is stored in mLambda.
     * @param func Predicate taking a node and returning a bool.
     */
    inline void addNodeLambda(const std::string& name, std::function<bool(const NodePtr&)> func) {
        // func is a by-value sink parameter: move it instead of copying it again.
        mLambda[name] = std::move(func);
    }

private:
    std::shared_ptr<GraphView> mGraph;
    std::map<std::string, std::function<bool(const NodePtr&)>> mLambda;

    /**
     * QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}')
     * NODE_OR_BLOCK = (BLOCK | NODE) QUANTIFIER?
     */
    bool matchNodeOrBlock(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * BLOCK = '(' (SEQ | PAR | ALT | BLOCK | NODE) ')'
     */
    bool matchBlock(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * SEQ = NODE_OR_BLOCK (EDGE NODE_OR_BLOCK)+
     */
    bool matchSequence(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * PAR = NODE_OR_BLOCK ('&' NODE_OR_BLOCK)+
     */
    bool matchParallel(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * ALT = NODE_OR_BLOCK ('|' NODE_OR_BLOCK)+
     */
    bool matchAlternative(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * IO_INDEX_ANY = '*'
     * IO_INDEX = IO_INDEX_ANY | [0-9]+
     * CHILD_EDGE = ('-' | '~') (IO_INDEX '-')? IO_INDEX? '>'
     * PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? ('-' | '~')
     * EDGE = CHILD_EDGE | PARENT_EDGE
     */
    bool matchEdge(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * TYPE = [A-Za-z0-9_]+
     * ANCHOR = [A-Za-z0-9_]+
     * LAMBDA = [A-Za-z0-9_]+
     * NODE = ((TYPE | '.') ('#' ANCHOR)? ('[' LAMBDA ']')?) | '$'
     */
    bool matchNode(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * @brief Erase the leading whitespace characters of a string, in place.
     *
     * Only the beginning of the string is trimmed; whitespace found after the
     * first non-space character is left untouched.
     *
     * @param str String to trim.
     */
    inline void removeWhiteSpace(std::string& str) {
        str.erase(str.begin(),
                  std::find_if(str.begin(),
                               str.end(),
                               // std::isspace has undefined behavior for negative
                               // values, which a plain char can hold: convert first.
                               [](char c) { return !std::isspace(static_cast<unsigned char>(c)); }));
    }

    /**
     * Strict ordering of matches by decreasing number of nodes. Matches with
     * the same number of nodes are ordered by their graph rootNode().
     */
    struct CompareMatchingResultSize {
        bool operator()(const MatchingResult& lhs, const MatchingResult& rhs) const {
            const auto lhsSize = lhs.graph->getNodes().size();
            const auto rhsSize = rhs.graph->getNodes().size();

            // Some matches size could be the same
            if (lhsSize == rhsSize) {
                // In this case, use rootNode which is guaranteed to be different!
                return lhs.graph->rootNode() < rhs.graph->rootNode();
            }

            return lhsSize > rhsSize;
        }
    };
};
/**
 * Default ordering of matching results, based on the root node of their graph.
 * The comparison relies on every match having a different root node.
 */
inline bool operator<(const Aidge::SinglePassGraphMatching::MatchingResult& lhs, const Aidge::SinglePassGraphMatching::MatchingResult& rhs) {
    const auto& leftRoot = lhs.graph->rootNode();
    const auto& rightRoot = rhs.graph->rootNode();
    return leftRoot < rightRoot;
}
} // namespace Aidge
#endif /* AIDGE_CORE_GRAPH_MATCHING_H_ */