Skip to content
Snippets Groups Projects
Commit 19a41f8e authored by vincent  lorrain's avatar vincent lorrain
Browse files

removeFlatten recipies and fix pybind

parent c3b0ad25
No related branches found
No related tags found
No related merge requests found
......@@ -50,7 +50,11 @@ void fuseMulAdd(std::shared_ptr<GraphView> graphView);
*
* @param nodes Strict set of Node to merge.
*/
void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
void removeFlatten(std::shared_ptr<Node> flatten);
void removeFlatten(std::shared_ptr<MatchSolution> solution);
/**
* @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
*
......
......@@ -28,12 +28,13 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
:param nodes: The MatMul and Add nodes to fuse.
:type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter");
// :param nodes: The MatMul and Add nodes to fuse.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator.
......@@ -41,18 +42,20 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
Recipie to remove a flatten operator.
:param nodes: The flatten operator to remove.
:type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter");
m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
// Recipie to remove a flatten operator.
:param nodes: The MatMul and Add nodes to fuse.
:type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter");
// :param nodes: The flatten operator to remove.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// :param nodes: The MatMul and Add nodes to fuse.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator.
......@@ -60,11 +63,12 @@ void init_Recipies(py::module &m) {
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
Recipie to remove a flatten operator.
// m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
// Recipie to remove a flatten operator.
:param nodes: The flatten operator to remove.
:type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter");
// :param nodes: The flatten operator to remove.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
}
} // namespace Aidge
......@@ -18,33 +18,59 @@
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
namespace Aidge {
void removeFlatten(std::set<std::shared_ptr<Node>> nodes) {
assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
std::shared_ptr<Node> flatten;
for (const auto& element : nodes) {
assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace");
if (element->type() == "Flatten"){
flatten = element;
}
}
void removeFlatten(std::shared_ptr<Node> flatten) {
// assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
// std::shared_ptr<Node> flatten;
// for (const auto& element : nodes) {
// assert((element->type() == "FC" || element->type() == "Flatten") && "Wrong type for the nodes to replace");
// if (element->type() == "Flatten"){
// flatten = element;
// }
// }
GraphView::replace({flatten}, {});
}
void removeFlatten(std::shared_ptr<MatchSolution> solution){
assert(solution->at("FC").size() == 1 && "Wrong number of nodes FC to replace\n");
assert(solution->at("Flatten").size() == 1 && "Wrong number of nodes Flatten to replace\n");
for (const auto& flatten : solution->at("Flatten")) {
removeFlatten(flatten);
}
}
void removeFlatten(std::shared_ptr<GraphView> graphView){
std::map<std::string,NodeRegex*> nodesRegex ;
nodesRegex["Flatten"] = new NodeRegex("Flatten");
nodesRegex["FC"] = new NodeRegex("FC");
std::vector<std::string> seqRegex;
seqRegex.push_back("Flatten->FC;");
GRegex GReg(nodesRegex, seqRegex);
Match matches = GReg.match(graphView);
std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
for (size_t i = 0; i < matches.getNbMatch(); ++i) {
removeFlatten(matchNodes[i]);
// std::map<std::string,NodeRegex*> nodesRegex ;
// nodesRegex["Flatten"] = new NodeRegex("Flatten");
// nodesRegex["FC"] = new NodeRegex("FC");
// std::vector<std::string> seqRegex;
// seqRegex.push_back("Flatten->FC;");
// GRegex GReg(nodesRegex, seqRegex);
// Match matches = GReg.match(graphView);
// std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
// for (size_t i = 0; i < matches.getNbMatch(); ++i) {
// removeFlatten(matchNodes[i]);
// }
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("Flatten","getType($) =='Flatten'");
regex->setNodeKey("FC","getType($) =='FC'");
regex->addQuery("Flatten->FC");
for (const auto& solution : regex->match(graphView)) {
removeFlatten(solution);
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment