Skip to content
Snippets Groups Projects
Commit 2f0b5ba2 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'fix_fuse_batchnorm' into 'dev'

Fix fuseBatchNorm() using alternate graph matching

See merge request !183
parents 02917b8d 3be0f2d3
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!183Fix fuseBatchNorm() using alternate graph matching
Pipeline #53250 passed
...@@ -98,10 +98,6 @@ void removeFlatten(std::shared_ptr<GraphView> graphView); ...@@ -98,10 +98,6 @@ void removeFlatten(std::shared_ptr<GraphView> graphView);
*/ */
void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm); void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm);
void fuseBatchNorm(std::shared_ptr<MatchSolution> solution);
/** /**
* @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
* Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/ConvDepthWise.hpp"
...@@ -25,9 +26,6 @@ ...@@ -25,9 +26,6 @@
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
// Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
std::shared_ptr<Aidge::Node> batchnormNode) { std::shared_ptr<Aidge::Node> batchnormNode) {
// Case: convNode is a MetaOperator ending with a Convolution // Case: convNode is a MetaOperator ending with a Convolution
...@@ -191,44 +189,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, ...@@ -191,44 +189,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
} }
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) {
assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n");
assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n");
for (const auto& op : solution->at("OP")) {
if (op->getOperator()->isAtomic()) {
for (const auto& batchNorm : solution->at("BatchNorm")) {
fuseBatchNorm(op, batchNorm);
}
} else { // op is a MetaOperator
auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator());
if ((metaOp->getMicroGraph()->getOrderedOutputs().size() == 1) &&
((metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
Conv_Op<2>::Type) ||
(metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
ConvDepthWise_Op<2>::Type))) {
for (const auto& batchNorm : solution->at("BatchNorm")) {
fuseBatchNorm(op, batchNorm);
}
}
}
}
}
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) { void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) {
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); auto matches = SinglePassGraphMatching(graphView).match("(Conv|ConvDepthWise|PaddedConv|PaddedConvDepthWise)->BatchNorm");
regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'");
fmt::print("\n============================\nSearching for solutions\n==============================\n");
regex->setNodeKey(
"OP",
"getType($) =='Conv' || getType($) =='ConvDepthWise' || getType($) =='PaddedConv' || getType($) =='PaddedConvDepthWise'");
// || getType($) =='FC' ");
regex->addQuery("OP -> BatchNorm");
for (const auto& solution : regex->match(graphView)) {
fuseBatchNorm(solution);
for (auto match : matches) {
auto rootNode = match.graph->rootNode();
fuseBatchNorm(rootNode, *rootNode->getChildren().begin());
} }
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment