Skip to content
Snippets Groups Projects
Commit 450b1329 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added simplify_graph()

parent f5574480
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!219Initial version of hybrid C++/Python static analysis
Pipeline #59030 failed
import numpy as np
import aidge_core
def simplify_graph(graph: aidge_core.GraphView):
"""
Simplify a graph loaded from ONNX.
:param graph: The GraphView to simplify.
:type graph: aidge_core.GraphView
"""
def check_constant_producer(value):
def _check_constant_producer(node):
out = node.get_operator().get_output(0)
return (len(out) == 1 and np.isclose(out[0], value))
return _check_constant_producer
gm = aidge_core.SinglePassGraphMatching(graph)
gm.add_node_lambda("Constant_sqrt2", check_constant_producer(np.sqrt(2)))
gm.add_node_lambda("Constant_1", check_constant_producer(1))
gm.add_node_lambda("Constant_0_5", check_constant_producer(0.5))
# Linear [from PyTorch ONNX]
aidge_core.fuse_to_metaops(gm, "MatMul-*>Add", "Linear")
# LayerNorm [from PyTorch ONNX]
aidge_core.fuse_to_metaops(gm, "ReduceMean-*>Sub#1~>(Pow#1->ReduceMean-*>Add#1->Sqrt)-*>Div#1-*>Mul#1-*>Add#2;"
"Sub#1~*>Div#1;"
"Pow#1<1~Producer;"
"Add#1<*~Producer;"
"Mul#1<*~Producer;"
"Add#2<*~Producer;"
"Sub#1~>$", "LayerNorm")
# ScaledDotProductAttention [from PyTorch ONNX]
aidge_core.fuse_to_metaops(gm, "MatMul->Div#1->Softmax-*>MatMul;"
"Div#1<1~Producer", "ScaledDotProductAttention")
# MultiHeadAttention [from PyTorch ONNX]
aidge_core.fuse_to_metaops(gm, "ScaledDotProductAttention#1->Transpose->Reshape#1->Linear;"
"Reshape#1<1~Producer;"
"ScaledDotProductAttention#1<0-(Transpose<-Reshape#2<-Add#1);"
"ScaledDotProductAttention#1<1-(Transpose<-Reshape#3<-Add#2);"
"ScaledDotProductAttention#1<2-(Transpose<-Reshape#4<-Add#3);"
"Reshape#2<1~Producer;"
"Add#1<*-0-Split#1;"
"Add#2<*-1-Split#1;"
"Add#3<*-2-Split#1;"
"Split#1<-MatMul;"
"Split#1<1~Producer", "MultiHeadAttention")
# GeLU [from PyTorch ONNX]
aidge_core.fuse_to_metaops(gm, "Div#1->Erf->Add#1-*>Mul->Mul#2;"
"Div#1<1~Producer[Constant_sqrt2];"
"Add#1<*~Producer[Constant_1];"
"Mul#2<*~Producer[Constant_0_5]", "GeLU")
...@@ -154,13 +154,13 @@ public: ...@@ -154,13 +154,13 @@ public:
*/ */
std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches); std::set<MatchingResult> filterLonguestDisjoint(const std::set<MatchingResult>& matches);
inline void addNodeLambda(const std::string& name, bool(func)(const NodePtr&)) { inline void addNodeLambda(const std::string& name, std::function<bool(const NodePtr&)> func) {
mLambda[name] = func; mLambda[name] = func;
} }
private: private:
std::shared_ptr<GraphView> mGraph; std::shared_ptr<GraphView> mGraph;
std::map<std::string, bool(*)(const NodePtr&)> mLambda; std::map<std::string, std::function<bool(const NodePtr&)>> mLambda;
/** /**
* QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}') * QUANTIFIER = '?' | '*' | '+' | ('{' [0-9]+ '}')
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graphRegex/matchFsm/MatchResult.hpp" #include "aidge/graph/Matching.hpp"
namespace Aidge { namespace Aidge {
...@@ -81,9 +81,6 @@ size_t removeIdentity(std::shared_ptr<GraphView> graph); ...@@ -81,9 +81,6 @@ size_t removeIdentity(std::shared_ptr<GraphView> graph);
*/ */
void removeFlatten(std::shared_ptr<Node> flatten); void removeFlatten(std::shared_ptr<Node> flatten);
void removeFlatten(std::shared_ptr<MatchSolution> solution);
/** /**
* @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
* *
...@@ -151,6 +148,15 @@ void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false); ...@@ -151,6 +148,15 @@ void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false);
*/ */
void matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims); void matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims);
/**
* Fuse each sub-graph matching a query in a Meta Operator.
* @param gm SinglePassGraphMatching containing the graph to manipulate
* @param query Sub-graph matching query
* @param type Type name of the resulting meta operators
* @return size_t Number of replacement
*/
size_t fuseToMetaOps(SinglePassGraphMatching& gm, const std::string& query, const std::string& type = "");
/** /**
* Fuse each sub-graph matching a query in a Meta Operator. * Fuse each sub-graph matching a query in a Meta Operator.
* @param graph Graph to manipulate * @param graph Graph to manipulate
......
...@@ -10,6 +10,7 @@ ...@@ -10,6 +10,7 @@
********************************************************************************/ ********************************************************************************/
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <pybind11/functional.h>
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include <memory> #include <memory>
#include <string> #include <string>
...@@ -31,21 +32,20 @@ void init_SinglePassGraphMatching(py::module& m) { ...@@ -31,21 +32,20 @@ void init_SinglePassGraphMatching(py::module& m) {
py::class_<Aidge::SinglePassGraphMatching>(m, "SinglePassGraphMatching") py::class_<Aidge::SinglePassGraphMatching>(m, "SinglePassGraphMatching")
.def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph")) .def(py::init<std::shared_ptr<GraphView>>(), py::arg("graph"))
.def("match", .def("match",
[](Aidge::SinglePassGraphMatching& self, const std::string& query, bool disjoint){ [](Aidge::SinglePassGraphMatching& self, const std::string& query, bool disjoint){
// Note: Need to convert set to vector has MatchingResult is not hashable and // Note: Need to convert set to vector has MatchingResult is not hashable and
// set<MatchingResult> cannot be binded // set<MatchingResult> cannot be binded
std::set<Aidge::SinglePassGraphMatching::MatchingResult> set_res = self.match(query, disjoint); std::set<Aidge::SinglePassGraphMatching::MatchingResult> set_res = self.match(query, disjoint);
std::vector<Aidge::SinglePassGraphMatching::MatchingResult> vec_res(set_res.begin(), set_res.end()); std::vector<Aidge::SinglePassGraphMatching::MatchingResult> vec_res(set_res.begin(), set_res.end());
return vec_res; return vec_res;
}, },
py::arg("query"), py::arg("disjoint") = false, py::arg("query"), py::arg("disjoint") = false,
R"mydelimiter( Matches a query by direct, single-pass parse and match. R"mydelimiter( Matches a query by direct, single-pass parse and match.
:param query: The query string to search. :param query: The query string to search.
:param disjoint: If true, only keep the longest disjoint matches. :param disjoint: If true, only keep the longest disjoint matches.
:return: A set of MatchingResult instances. :return: A set of MatchingResult instances.
)mydelimiter"); )mydelimiter")
.def("add_node_lambda", &SinglePassGraphMatching::addNodeLambda, py::arg("name"), py::arg("func"));
} }
} // namespace Aidge } // namespace Aidge
...@@ -112,7 +112,20 @@ void init_Recipes(py::module &m) ...@@ -112,7 +112,20 @@ void init_Recipes(py::module &m)
:type recursive: bool :type recursive: bool
)mydelimiter"); )mydelimiter");
m.def("fuse_to_metaops", fuseToMetaOps, py::arg("graph_view"), py::arg("query"), py::arg("type") = "", R"mydelimiter( m.def("fuse_to_metaops", py::overload_cast<SinglePassGraphMatching&, const std::string&, const std::string&>(fuseToMetaOps), py::arg("gm"), py::arg("query"), py::arg("type") = "", R"mydelimiter(
Fuse each sub-graph matching a query in a Meta Operator.
:param gm: SinglePassGraphMatching containing the graph to manipulate
:type gm: :py:class:`aidge_core.SinglePassGraphMatching`
:param query: Sub-graph matching query
:type query: str
:param type: Type name of the resulting meta operators
:type type: str, optional
:return: Number of sub-graph actually fused in a Meta Operator.
:rtype: int
)mydelimiter");
m.def("fuse_to_metaops", py::overload_cast<std::shared_ptr<GraphView>, const std::string&, const std::string&>(fuseToMetaOps), py::arg("graph_view"), py::arg("query"), py::arg("type") = "", R"mydelimiter(
Fuse each sub-graph matching a query in a Meta Operator. Fuse each sub-graph matching a query in a Meta Operator.
:param graph_view: Graph view on which we want to apply the recipe :param graph_view: Graph view on which we want to apply the recipe
......
...@@ -17,9 +17,9 @@ ...@@ -17,9 +17,9 @@
#include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/MetaOperator.hpp"
#include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/Recipes.hpp"
size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) { size_t Aidge::fuseToMetaOps(SinglePassGraphMatching& gm, const std::string& query, const std::string& type) {
const auto metaType = (!type.empty()) ? type : query; const auto metaType = (!type.empty()) ? type : query;
const auto matches = SinglePassGraphMatching(graphView).match(query); const auto matches = gm.match(query);
size_t nbReplaced = 0; size_t nbReplaced = 0;
for (const auto& match : matches) { for (const auto& match : matches) {
...@@ -48,3 +48,8 @@ size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::str ...@@ -48,3 +48,8 @@ size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::str
Log::info("Replaced {} (out of {}) matching sub-graph with meta operators", nbReplaced, matches.size()); Log::info("Replaced {} (out of {}) matching sub-graph with meta operators", nbReplaced, matches.size());
return nbReplaced; return nbReplaced;
} }
size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::string& query, const std::string& type) {
SinglePassGraphMatching gm(graphView);
return fuseToMetaOps(gm, query, type);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment