Skip to content
Snippets Groups Projects
Commit 25e5e447 authored by Olivier BICHLER's avatar Olivier BICHLER Committed by Maxence Naud
Browse files

Fix fuseBatchNorm() using alternate graph matching

parent 3820c97c
No related branches found
No related tags found
No related merge requests found
...@@ -98,10 +98,6 @@ void removeFlatten(std::shared_ptr<GraphView> graphView); ...@@ -98,10 +98,6 @@ void removeFlatten(std::shared_ptr<GraphView> graphView);
*/ */
void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm); void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm);
void fuseBatchNorm(std::shared_ptr<MatchSolution> solution);
/** /**
* @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
* Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/ConvDepthWise.hpp"
...@@ -25,9 +26,6 @@ ...@@ -25,9 +26,6 @@
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
// Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
std::shared_ptr<Aidge::Node> batchnormNode) { std::shared_ptr<Aidge::Node> batchnormNode) {
// Case: convNode is a MetaOperator ending with a Convolution // Case: convNode is a MetaOperator ending with a Convolution
...@@ -191,44 +189,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, ...@@ -191,44 +189,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
} }
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) {
assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n");
assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n");
for (const auto& op : solution->at("OP")) {
if (op->getOperator()->isAtomic()) {
for (const auto& batchNorm : solution->at("BatchNorm")) {
fuseBatchNorm(op, batchNorm);
}
} else { // op is a MetaOperator
auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator());
if ((metaOp->getMicroGraph()->getOrderedOutputs().size() == 1) &&
((metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
Conv_Op<2>::Type) ||
(metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
ConvDepthWise_Op<2>::Type))) {
for (const auto& batchNorm : solution->at("BatchNorm")) {
fuseBatchNorm(op, batchNorm);
}
}
}
}
}
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) { void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) {
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); auto matches = SinglePassGraphMatching(graphView).match("(Conv|ConvDepthWise|PaddedConv|PaddedConvDepthWise)->BatchNorm");
regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'");
fmt::print("\n============================\nSearching for solutions\n==============================\n");
regex->setNodeKey(
"OP",
"getType($) =='Conv' || getType($) =='ConvDepthWise' || getType($) =='PaddedConv' || getType($) =='PaddedConvDepthWise'");
// || getType($) =='FC' ");
regex->addQuery("OP -> BatchNorm");
for (const auto& solution : regex->match(graphView)) {
fuseBatchNorm(solution);
for (auto match : matches) {
auto rootNode = match.graph->rootNode();
fuseBatchNorm(rootNode, *rootNode->getChildren().begin());
} }
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment