Skip to content
Snippets Groups Projects
Commit 25e5e447 authored by Olivier BICHLER's avatar Olivier BICHLER Committed by Maxence Naud
Browse files

Fix fuseBatchNorm() using alternate graph matching

parent 3820c97c
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!181Feat 145 grid sample
......@@ -98,10 +98,6 @@ void removeFlatten(std::shared_ptr<GraphView> graphView);
*/
void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm);
void fuseBatchNorm(std::shared_ptr<MatchSolution> solution);
/**
* @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
* Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
......
......@@ -16,6 +16,7 @@
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
......@@ -25,9 +26,6 @@
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
// Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
std::shared_ptr<Aidge::Node> batchnormNode) {
// Case: convNode is a MetaOperator ending with a Convolution
......@@ -191,44 +189,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
}
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) {
assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n");
assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n");
for (const auto& op : solution->at("OP")) {
if (op->getOperator()->isAtomic()) {
for (const auto& batchNorm : solution->at("BatchNorm")) {
fuseBatchNorm(op, batchNorm);
}
} else { // op is a MetaOperator
auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator());
if ((metaOp->getMicroGraph()->getOrderedOutputs().size() == 1) &&
((metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
Conv_Op<2>::Type) ||
(metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
ConvDepthWise_Op<2>::Type))) {
for (const auto& batchNorm : solution->at("BatchNorm")) {
fuseBatchNorm(op, batchNorm);
}
}
}
}
}
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) {
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'");
fmt::print("\n============================\nSearching for solutions\n==============================\n");
regex->setNodeKey(
"OP",
"getType($) =='Conv' || getType($) =='ConvDepthWise' || getType($) =='PaddedConv' || getType($) =='PaddedConvDepthWise'");
// || getType($) =='FC' ");
regex->addQuery("OP -> BatchNorm");
for (const auto& solution : regex->match(graphView)) {
fuseBatchNorm(solution);
auto matches = SinglePassGraphMatching(graphView).match("(Conv|ConvDepthWise|PaddedConv|PaddedConvDepthWise)->BatchNorm");
for (auto match : matches) {
auto rootNode = match.graph->rootNode();
fuseBatchNorm(rootNode, *rootNode->getChildren().begin());
}
}
\ No newline at end of file
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment