Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Matching.hpp 9.30 KiB
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#ifndef AIDGE_CORE_GRAPH_MATCHING_H_
#define AIDGE_CORE_GRAPH_MATCHING_H_

#include <algorithm>
#include <cctype>
#include <cstddef>
#include <functional>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>

#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"

namespace Aidge {
/**
 * A simple experimental graph matching class which works by direct, single pass
 * parse and match, without constructing any intermediate representation.
 * Due to its single pass nature, it has some constraints on how the queries must
 * be formulated.
*/
class SinglePassGraphMatching {
public:
    /**
     * State passed by reference to every private match*() helper while a
     * query is being parsed and matched.
     *
     * NOTE(review): the per-member remarks below are inferred from the member
     * names and from the query syntax documented on match(); confirm them
     * against the implementation file.
     */
    struct Context {
        Context();
        Context(const Context&); // explicitly define Context copy constructor
                                 // to avoid automatic inlining
        Context& operator=(const Context&);
        ~Context() noexcept;

        std::string query;           // query text handled by the parser
        bool firstSequence = true;   // presumably: still in the first sequence of the query
        bool firstNode = true;       // presumably: next node is the first one of its sequence
        bool inSequence = false;     // presumably: currently parsing inside a SEQ
        bool lookForChild = true;    // presumably: "->" (child) vs "<-" (parent) edge direction
        bool singleOutput = true;    // presumably: "-" (true) vs "~" (false) edge modifier
        IOIndex_t edgeLeftIdx = 0;   // presumably: left I/O index of the current edge
        IOIndex_t edgeRightIdx = 0;  // presumably: right I/O index of the current edge
        NodePtr startNode;

        // For check & debug purpose:
        size_t depth = 0;
        std::set<std::string> anchors;
    };

    /**
     * One match of a query. Results are stored in std::set containers and
     * ordered by the root node of their graph (see the free operator< declared
     * after this class).
     */
    struct MatchingResult {
        // Mutable is required to allow modifying MatchingResult members with a std::set
        // iterator. Any change should not modify the set ordering.
        // We use graph->rootNode() as the std::set key, which is guaranteed
        // to never change after insertion!
        mutable std::shared_ptr<GraphView> graph;
        // NOTE(review): presumably indexed by node type, then by anchor name -
        // confirm against the implementation.
        mutable std::map<std::string, std::map<std::string, NodePtr>> anchors;
        mutable NodePtr startNode;

        MatchingResult();

        MatchingResult(const MatchingResult& other);
        MatchingResult& operator=(const MatchingResult& other);
        ~MatchingResult() noexcept;
    };

    /**
     * @brief Build a matcher that stores the given graph.
     * @param graph Graph kept in mGraph; taken by value and moved into the
     * member so that no extra shared_ptr copy is made.
     */
    SinglePassGraphMatching(std::shared_ptr<GraphView> graph) : mGraph(std::move(graph)) {}
    SinglePassGraphMatching(const SinglePassGraphMatching& other);
    SinglePassGraphMatching& operator=(const SinglePassGraphMatching& other);
    ~SinglePassGraphMatching() noexcept;

    /**
     * Matches a query by direct, single pass parse and match.
     * The returned matches are non-ordered and therefore stored in a std::set.
     *
     * Some rules:
     * - The first node of the first sequence is the root node and cannot be optional
     *   WRONG: Conv?->ReLU (will throw an error)
     *   GOOD: ReLU<-Conv?
     *
     * - The first node of any further sequence must be an existing anchor
     *   (the anchor cannot be in the middle of the sequence)
     *   WRONG: Conv->ReLU;Pad->Conv (will throw an error)
     *          Pad->Conv;Conv->ReLU (will throw an error)
     *   GOOD: Conv#->ReLU;Conv#<-Pad
     *         Pad->Conv#;Conv#->ReLU
     *
     * - Any node already matched cannot be matched again (except for anchors)
     *
     * - By default, an edge matches the first output to the first input.
     *   EXAMPLE: ReLU->Conv is equivalent to ReLU-0-0>Conv
     *            To match the second input, use ReLU-0-1>Conv (or ReLU-1>Conv)
     *            To match the second output, use ReLU-1-0>Conv
     *            To match any input and/or any output, use *, like ReLU-1-*>Conv
     *            or ReLU-*-0>Conv or ReLU-*-*>Conv
     *            The same is true for the "<-" edge syntax.
     *
     * - When several nodes could match for a given node query, the first one
     *   not already in the matching result is matched, following the
     *   children/parents ordered node list
     *   EXAMPLE: Producer in "Conv<*-Producer" will match the weights Producer first
     *   EXAMPLE: Producer in "Conv#<1-.;Conv#<*-Producer" will match the bias Producer
     *            because the weights Producer has already been matched
     *
     * - One always matches a sub-graph: additional connections can exist anywhere
     *   in the matched sub-graph
     *   EXAMPLE: "Add<*-." will match the Add operator and its first input, any
     *            additional inputs will not be included in the result
     *   EXAMPLE: "(Add#<*-.)+" will match the Add operator and all of its inputs
     *            Note that the anchor is required since we intend to match several
     *            inputs of the same node!
     *
     * - In Aidge, a node output can be connected to multiple other nodes. In
     *   your query, you can allow it or not, with the "~" or "-" modifier.
     *   EXAMPLE: "Conv->ReLU" will match the Conv that are **only** connected
     *            to a ReLU node at their output #0.
     *            "Conv~>ReLU" will match all the Conv connected to a ReLU even
     *            if they are also connected to other nodes at the same output #0.
     *   When implementing a match & replace recipe, beware that you don't break
     *   branches in the middle of your matching result if you use "~"!
     *
     * - The matching results can be overlapping, meaning that some nodes may be
     *   found in multiple results. Some results may be subsets of other results.
     *   EXAMPLE: assume graph Conv#1->ReLU#1->Conv#2->ReLU#2
     *            "Conv->ReLU?->Conv?->ReLU?" will return both
     *            Conv#1->ReLU#1->Conv#2->ReLU#2 and Conv#2->ReLU#2
     *   To avoid this behavior, set the disjoint argument to true. In this case,
     *   only Conv#1->ReLU#1->Conv#2->ReLU#2 will be kept in the example above.
     *
     * - Whitespaces are allowed anywhere in the query
     *
     * QUERY = SEQ | NODE_OR_BLOCK (';' (SEQ | NODE_OR_BLOCK))*
     *
     * @param query The query to search.
     * @param disjoint If true, only keep the longest disjoint (non-overlapping) matches.
     * @return std::set<MatchingResult> Set of matches, each stored in a MatchingResult struct.
    */
    std::set<MatchingResult> match(const std::string& query, bool disjoint = false);

    /**
     * @brief Same as match() but with a mandatory start node.
     *
     * @param startNode Mandatory start node for the query.
     * @param query The query to search.
     * @return MatchingResult MatchingResult struct, with empty graph if query
     * is not found, or the graph corresponding to the query.
     */
    MatchingResult matchFrom(NodePtr startNode, const std::string& query);

    /**
     * Filter to keep only the longest disjoint (non-overlapping) matches.
     *
     * @param matches Set of matches to filter.
     * @return std::set<MatchingResult> The matches that were kept.
    */
    std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches);

    /**
     * @brief Store a node predicate under a name, replacing any predicate
     * previously stored under the same name.
     *
     * NOTE(review): presumably the name is what a query refers to through the
     * '[' LAMBDA ']' part of a NODE (see matchNode()) - confirm.
     *
     * @param name Key of the predicate in mLambda.
     * @param func Predicate taking a node and returning a bool.
     */
    inline void addNodeLambda(const std::string& name, std::function<bool(const NodePtr&)> func) {
        // func is a by-value sink parameter: move it instead of copying it again.
        mLambda[name] = std::move(func);
    }

private:
    std::shared_ptr<GraphView> mGraph;
    // Predicates registered with addNodeLambda(), indexed by name.
    std::map<std::string, std::function<bool(const NodePtr&)>> mLambda;

    // NOTE(review): the match*() helpers below all share the same signature;
    // presumably they return true when their grammar rule matched and update
    // ctx/matches accordingly - confirm against the implementation file.

    /**
     * QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}')
     * NODE_OR_BLOCK = (BLOCK | NODE) QUANTIFIER?
    */
    bool matchNodeOrBlock(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * BLOCK = '(' (SEQ | PAR | ALT | BLOCK | NODE) ')'
    */
    bool matchBlock(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * SEQ = NODE_OR_BLOCK (EDGE NODE_OR_BLOCK)+
    */
    bool matchSequence(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * PAR = NODE_OR_BLOCK ('&' NODE_OR_BLOCK)+
    */
    bool matchParallel(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * ALT = NODE_OR_BLOCK ('|' NODE_OR_BLOCK)+
    */
    bool matchAlternative(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * IO_INDEX_ANY = '*'
     * IO_INDEX = IO_INDEX_ANY | [0-9]+
     * CHILD_EDGE = ('-' | '~') (IO_INDEX '-')? IO_INDEX? '>'
     * PARENT_EDGE = '<' (IO_INDEX '-')? IO_INDEX? ('-' | '~')
     * EDGE = CHILD_EDGE | PARENT_EDGE
    */
    bool matchEdge(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * TYPE = [A-Za-z0-9_]+
     * ANCHOR = [A-Za-z0-9_]+
     * LAMBDA = [A-Za-z0-9_]+
     * NODE = ((TYPE | '.') ('#' ANCHOR)? ('[' LAMBDA ']')?) | '$'
    */
    bool matchNode(Context& ctx, std::set<MatchingResult>& matches);

    /**
     * @brief Remove the leading whitespace of a string, in place.
     *
     * Only the characters located before the first non-space character are
     * erased; whitespace found later in the string is left untouched.
     *
     * @param str String to trim (modified in place).
     */
    inline void removeWhiteSpace(std::string& str) {
        // The predicate takes an unsigned char: calling std::isspace() with a
        // negative value (plain char holding a non-ASCII byte) is undefined
        // behavior.
        const auto firstNonSpace = std::find_if(str.begin(),
                                                str.end(),
                                                [](unsigned char c) { return !std::isspace(c); });
        str.erase(str.begin(), firstNonSpace);
    }

    /**
     * Ordering of matching results by decreasing number of nodes in their
     * graph. Both graph pointers are dereferenced, so they must not be null.
     */
    struct CompareMatchingResultSize {
        bool operator()(const MatchingResult& lhs, const MatchingResult& rhs) const {
            // Query each size once instead of once per comparison.
            const auto lhsSize = lhs.graph->getNodes().size();
            const auto rhsSize = rhs.graph->getNodes().size();

            // Some matches size could be the same
            if (lhsSize == rhsSize) {
                // In this case, use rootNode which is guaranteed to be different!
                return lhs.graph->rootNode() < rhs.graph->rootNode();
            }

            // Largest match first.
            return lhsSize > rhsSize;
        }
    };
};

/**
 * Ordering used to store MatchingResult objects in a std::set: two results are
 * compared through the root node of their graph only.
 * Both graph pointers are dereferenced, so they must not be null.
 */
inline bool operator<(const Aidge::SinglePassGraphMatching::MatchingResult& lhs, const Aidge::SinglePassGraphMatching::MatchingResult& rhs) {
    // The root nodes of distinct matches are guaranteed to be different!
    const auto& lhsRoot = lhs.graph->rootNode();
    const auto& rhsRoot = rhs.graph->rootNode();
    return lhsRoot < rhsRoot;
}
}  // namespace Aidge

#endif /* AIDGE_CORE_GRAPH_MATCHING_H_ */