Skip to content
Snippets Groups Projects
Commit a551865e authored by Cyril Moineau's avatar Cyril Moineau Committed by Olivier BICHLER
Browse files

Remove Flatten now use new graphmatching + remove Flatten beofre MatMul + remove multiple Flatten.

parent 7f1ba66c
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!189Improve Remove Flatten
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Recipes(py::module &m) void init_Recipes(py::module &m)
{ {
...@@ -71,9 +71,10 @@ void init_Recipes(py::module &m) ...@@ -71,9 +71,10 @@ void init_Recipes(py::module &m)
)mydelimiter"); )mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
Recipe to remove a flatten operator. Recipe to remove a Flatten operator if it is followed by a FC or a MatMul.
The recipe can remove multiple Flatten operator if they are one after the other.
:param graph_view: Graph view on which we want to apply the recipe :param graph_view: Graph view on which we want to apply the recipe.
:type graph_view: :py:class:`aidge_core.GraphView` :type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter"); )mydelimiter");
......
...@@ -17,38 +17,33 @@ ...@@ -17,38 +17,33 @@
//Graph Regex //Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp" // #include "aidge/graphRegex/GraphRegex.hpp"
#include "aidge/graph/Matching.hpp"
namespace Aidge { namespace Aidge {
void removeFlatten(std::shared_ptr<Node> flatten) {
GraphView::replace({flatten}, {});
}
void removeFlatten(std::shared_ptr<MatchSolution> solution){
assert(solution->at("FC").size() == 1 && "Wrong number of nodes FC to replace\n"); void removeFlatten(const std::set<NodePtr>& solution){
assert(solution->at("Flatten").size() == 1 && "Wrong number of nodes Flatten to replace\n"); std::set<NodePtr> flattenNodes {};
for (const auto& node : solution) {
for (const auto& flatten : solution->at("Flatten")) { if (node->type() == "Flatten"){
removeFlatten(flatten); printf("Flatten found.\n");
flattenNodes.insert(node);
}
else if (! (node->type() == "MatMul" || node->type() == "FC")){
AIDGE_THROW_OR_ABORT(std::runtime_error, "Node of type {} is not MatMul nor FC, an error during GraphMatching occured !", node->type());
}
} }
GraphView::replace(flattenNodes, {});
} }
void removeFlatten(std::shared_ptr<GraphView> graphView){ void removeFlatten(std::shared_ptr<GraphView> graphView){
const auto matches = SinglePassGraphMatching(graphView).match(
"(FC|MatMul)<-(Flatten)+"
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); );
regex->setNodeKey("Flatten","getType($) =='Flatten'");
regex->setNodeKey("FC","getType($) =='FC'");
regex->addQuery("Flatten->FC");
for (const auto& solution : regex->match(graphView)) { for (const auto& solution : matches) {
removeFlatten(solution); removeFlatten(solution.graph->getNodes());
} }
} }
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment