Skip to content
Snippets Groups Projects
Commit b2b97cfc authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fix fuseBatchNorm() using alternate graph matching

parent 5a68b2db
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!183Fix fuseBatchNorm() using alternate graph matching
Pipeline #53226 canceled
...@@ -98,10 +98,6 @@ void removeFlatten(std::shared_ptr<GraphView> graphView); ...@@ -98,10 +98,6 @@ void removeFlatten(std::shared_ptr<GraphView> graphView);
*/ */
void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm); void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm);
void fuseBatchNorm(std::shared_ptr<MatchSolution> solution);
/** /**
* @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
* Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp" #include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp" #include "aidge/operator/ConvDepthWise.hpp"
...@@ -25,9 +26,6 @@ ...@@ -25,9 +26,6 @@
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
// Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
std::shared_ptr<Aidge::Node> batchnormNode) { std::shared_ptr<Aidge::Node> batchnormNode) {
// Case: convNode is a MetaOperator ending with a Convolution // Case: convNode is a MetaOperator ending with a Convolution
...@@ -191,44 +189,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, ...@@ -191,44 +189,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
} }
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) {
assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n");
assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n");
for (const auto& op : solution->at("OP")) {
if (op->getOperator()->isAtomic()) {
for (const auto& batchNorm : solution->at("BatchNorm")) {
fuseBatchNorm(op, batchNorm);
}
} else { // op is a MetaOperator
auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator());
if ((metaOp->getMicroGraph()->getOrderedOutputs().size() == 1) &&
((metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
Conv_Op<2>::Type) ||
(metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
ConvDepthWise_Op<2>::Type))) {
for (const auto& batchNorm : solution->at("BatchNorm")) {
fuseBatchNorm(op, batchNorm);
}
}
}
}
}
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) { void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) {
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); auto matches = SinglePassGraphMatching(graphView).match("(Conv|ConvDepthWise|PaddedConv|PaddedConvDepthWise)->BatchNorm");
regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'");
fmt::print("\n============================\nSearching for solutions\n==============================\n");
regex->setNodeKey(
"OP",
"getType($) =='Conv' || getType($) =='ConvDepthWise' || getType($) =='PaddedConv' || getType($) =='PaddedConvDepthWise'");
// || getType($) =='FC' ");
regex->addQuery("OP -> BatchNorm");
for (const auto& solution : regex->match(graphView)) {
fuseBatchNorm(solution);
for (auto match : matches) {
auto rootNode = match.graph->rootNode();
fuseBatchNorm(rootNode, *rootNode->getChildren().begin());
} }
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment