Skip to content
Snippets Groups Projects
Commit 19a41f8e authored by vincent  lorrain's avatar vincent lorrain
Browse files

removeFlatten recipies and fix pybind

parent c3b0ad25
No related branches found
No related tags found
1 merge request!48Refactor/recipies
...@@ -50,7 +50,11 @@ void fuseMulAdd(std::shared_ptr<GraphView> graphView); ...@@ -50,7 +50,11 @@ void fuseMulAdd(std::shared_ptr<GraphView> graphView);
* *
* @param nodes Strict set of Node to merge. * @param nodes Strict set of Node to merge.
*/ */
void removeFlatten(std::set<std::shared_ptr<Node>> nodes); void removeFlatten(std::shared_ptr<Node> flatten);
void removeFlatten(std::shared_ptr<MatchSolution> solution);
/** /**
* @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
* *
......
...@@ -28,12 +28,13 @@ void init_Recipies(py::module &m) { ...@@ -28,12 +28,13 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie :param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView` :type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter"); )mydelimiter");
m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. // m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
:param nodes: The MatMul and Add nodes to fuse. // :param nodes: The MatMul and Add nodes to fuse.
:type nodes: list of :py:class:`aidge_core.Node` // :type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter"); // )mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator. Recipie to remove a flatten operator.
...@@ -41,18 +42,20 @@ void init_Recipies(py::module &m) { ...@@ -41,18 +42,20 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie :param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView` :type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter"); )mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
Recipie to remove a flatten operator.
:param nodes: The flatten operator to remove. // m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
:type nodes: list of :py:class:`aidge_core.Node` // Recipie to remove a flatten operator.
)mydelimiter");
m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
:param nodes: The MatMul and Add nodes to fuse. // :param nodes: The flatten operator to remove.
:type nodes: list of :py:class:`aidge_core.Node` // :type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter"); // )mydelimiter");
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// :param nodes: The MatMul and Add nodes to fuse.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter( m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator. Recipie to remove a flatten operator.
...@@ -60,11 +63,12 @@ void init_Recipies(py::module &m) { ...@@ -60,11 +63,12 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie :param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView` :type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter"); )mydelimiter");
m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
Recipie to remove a flatten operator. // m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
// Recipie to remove a flatten operator.
:param nodes: The flatten operator to remove. // :param nodes: The flatten operator to remove.
:type nodes: list of :py:class:`aidge_core.Node` // :type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter"); // )mydelimiter");
} }
} // namespace Aidge } // namespace Aidge
...@@ -18,33 +18,59 @@ ...@@ -18,33 +18,59 @@
// Graph Regex // Graph Regex
#include "aidge/graphmatching/GRegex.hpp" #include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp" #include "aidge/graphmatching/NodeRegex.hpp"
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
namespace Aidge { namespace Aidge {
void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { void removeFlatten(std::shared_ptr<Node> flatten) {
assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); // assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
std::shared_ptr<Node> flatten; // std::shared_ptr<Node> flatten;
for (const auto& element : nodes) { // for (const auto& element : nodes) {
assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace"); // assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace");
if (element->type() == "Flatten"){ // if (element->type() == "Flatten"){
flatten = element; // flatten = element;
} // }
} // }
GraphView::replace({flatten}, {}); GraphView::replace({flatten}, {});
} }
void removeFlatten(std::shared_ptr<MatchSolution> solution){
assert(solution->at("FC").size() == 1 && "Wrong number of nodes FC to replace\n");
assert(solution->at("Flatten").size() == 1 && "Wrong number of nodes Flatten to replace\n");
for (const auto& flatten : solution->at("Flatten")) {
removeFlatten(flatten);
}
}
void removeFlatten(std::shared_ptr<GraphView> graphView){ void removeFlatten(std::shared_ptr<GraphView> graphView){
std::map<std::string,NodeRegex*> nodesRegex ; // std::map<std::string,NodeRegex*> nodesRegex ;
nodesRegex["Flatten"] = new NodeRegex("Flatten"); // nodesRegex["Flatten"] = new NodeRegex("Flatten");
nodesRegex["FC"] = new NodeRegex("FC"); // nodesRegex["FC"] = new NodeRegex("FC");
std::vector<std::string> seqRegex; // std::vector<std::string> seqRegex;
seqRegex.push_back("Flatten->FC;"); // seqRegex.push_back("Flatten->FC;");
GRegex GReg(nodesRegex, seqRegex); // GRegex GReg(nodesRegex, seqRegex);
Match matches = GReg.match(graphView); // Match matches = GReg.match(graphView);
std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes(); // std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
for (size_t i = 0; i < matches.getNbMatch(); ++i) { // for (size_t i = 0; i < matches.getNbMatch(); ++i) {
removeFlatten(matchNodes[i]); // removeFlatten(matchNodes[i]);
// }
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("Flatten","getType($) =='Flatten'");
regex->setNodeKey("FC","getType($) =='FC'");
regex->addQuery("Flatten->FC");
for (const auto& solution : regex->match(graphView)) {
removeFlatten(solution);
} }
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment