Skip to content
Snippets Groups Projects
Commit 6e6117d5 authored by Maxence Naud's avatar Maxence Naud
Browse files

Update fuseBatchNorm to include ConvDepthWise and MetaOperator containing Conv or ConvDepthWise as an output

Update fuseBatchNorm to include ConvDepthWise and MetaOperator containing Conv or ConvDepthWise as an output
parent d895263f
No related branches found
No related tags found
No related merge requests found
......@@ -8,50 +8,77 @@
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <set>
#include <cassert>
#include <memory>
#include <set>
#include <string>
#include "aidge/operator/FC.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/recipies/Recipies.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
//Graph Regex
// Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
                          std::shared_ptr<Aidge::Node> batchnormNode) {
    // Fuse a BatchNorm node into the preceding Conv/ConvDepthWise node by folding
    // the normalization parameters (scale, shift, running mean/variance) into the
    // convolution weights and bias, then removing the BatchNorm node from the graph.
    //
    // Case: convNode is a MetaOperator ending with a Convolution
    // eg. PaddedConv
    if (!(convNode->getOperator()->isAtomic())) {
        const std::shared_ptr<MetaOperator_Op> metaNode =
            std::static_pointer_cast<MetaOperator_Op>(convNode->getOperator());
        const std::shared_ptr<GraphView> metanodeGraph = metaNode->getMicroGraph();
        const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputNodes =
            metanodeGraph->getOrderedOutputs();
        if (outputNodes.size() != 1) {
            AIDGE_THROW_OR_ABORT(std::runtime_error,
                                 "Bad MetaOperator argument for fuseBatchNorm recipie.");
        }
        // From here on, operate on the inner convolution node of the micro-graph.
        convNode = outputNodes[0].first;
    }
    AIDGE_ASSERT(((convNode->type() == Conv_Op<2>::Type) ||
                  (convNode->type() == ConvDepthWise_Op<2>::Type)),
                 "Wrong type");
    AIDGE_ASSERT(batchnormNode->type() == BatchNorm_Op<2>::Type,
                 "Wrong type for batchnorm node.");

    // TODO: Find a way to remove the template
    // A feature map with 2 dimensions is assumed
    const std::shared_ptr<BatchNorm_Op<2>> batchOp =
        std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());

    DimSize_t convNbOutChannels;
    DimSize_t channelsSize;
    std::array<DimSize_t, 2> kernelDims;
    std::shared_ptr<OperatorTensor> convOp =
        std::static_pointer_cast<OperatorTensor>(convNode->getOperator());
    if (convNode->type() == Conv_Op<2>::Type) {
        const std::shared_ptr<Conv_Op<2>> convOpPtr =
            std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
        convNbOutChannels = convOpPtr->getAttr<DimSize_t>("OutChannels");
        channelsSize = convOpPtr->getAttr<DimSize_t>("InChannels");
        kernelDims = convOpPtr->getAttr<std::array<DimSize_t, 2>>("KernelDims");
    } else if (convNode->type() == ConvDepthWise_Op<2>::Type) {
        // Depth-wise convolution: one filter per channel, so the per-filter
        // channel dimension is 1.
        const std::shared_ptr<ConvDepthWise_Op<2>> convOpPtr =
            std::static_pointer_cast<ConvDepthWise_Op<2>>(convNode->getOperator());
        convNbOutChannels = convOpPtr->getAttr<DimSize_t>("Channels");
        channelsSize = 1;
        kernelDims = convOpPtr->getAttr<std::array<DimSize_t, 2>>("KernelDims");
    }

    // BatchNorm learned parameters and statistics (inputs 1..4 by convention here).
    const std::shared_ptr<Tensor> scale = batchOp->getInput(1);
    const std::shared_ptr<Tensor> shift = batchOp->getInput(2);
    const std::shared_ptr<Tensor> b_mean = batchOp->getInput(3);
    const std::shared_ptr<Tensor> b_var = batchOp->getInput(4);
    const float epsilon = batchOp->getAttr<float>("Epsilon");

    assert(scale->size() == convNbOutChannels);
    assert(shift->size() == convNbOutChannels);
    assert(b_mean->size() == convNbOutChannels);
    assert(b_var->size() == convNbOutChannels);
    assert(epsilon > 0.0);
    // TODO : no no_bias attribute ?

    // Average the variance over well-conditioned channels so it can substitute
    // for near-zero variances below (quantization-friendly correction).
    float meanVariance = 0.0;
    unsigned int count = 0;
    // NOTE(review): the loop header here was elided in the diff view; a
    // per-output-channel scan is assumed — confirm against the full file.
    for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
        if (b_var->get<float>(outChId) > 1.0e-12) {
            meanVariance += b_var->get<float>(outChId);
            ++count;
        } else {
            printf("Zero-variance: %s [%lu]\n", convNode->name().c_str(), outChId);
        }
    }
    // NOTE(review): this averaging step was elided in the diff view — confirm
    // against the full file.
    if (count > 0) {
        meanVariance /= count;
    } else {
        printf("Warning: variance < 1e-12 for all outputs! Is the network correctly trained?\n");
    }

    std::shared_ptr<Tensor> weight = convOp->getInput(1);
    std::shared_ptr<Tensor> bias = convOp->getInput(2);
    for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
        // Corrected for zero-variance issue:
        // "A Quantization-Friendly Separable Convolution for MobileNets"
        // https://arxiv.org/pdf/1803.08607.pdf
        // to help post-training quantization
        const float factor =
            scale->get<float>(outChId) /
            std::sqrt(epsilon + ((b_var->get<float>(outChId) > 1.0e-12 || count == 0)
                                     ? b_var->get<float>(outChId)
                                     : meanVariance));
        // Weights adjustments
        for (std::size_t channel = 0; channel < channelsSize; ++channel) {
            // TODO : Suppose kerneldims = 2
            for (std::size_t k0 = 0; k0 < kernelDims[0]; ++k0) {
                for (std::size_t k1 = 0; k1 < kernelDims[1]; ++k1) {
                    std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1};
                    // TODO : suppose weights are float
                    float weightValue = weight->get<float>(currentIdx);
                    weight->set<float>(
                        currentIdx,
                        weightValue * factor); // Update check it update Conv weights
                }
            }
        }
        // NOTE(review): the bias read was elided in the diff view; an existing
        // conv bias tensor is assumed to be present — confirm against the full file.
        float biasValue = bias->get<float>(outChId);
        biasValue = shift->get<float>(outChId) + (biasValue - b_mean->get<float>(outChId)) * factor;
        bias->set<float>(outChId, biasValue);
    }

    // The BatchNorm node and its four parameter producers are now redundant:
    // remove them from the graph (replace with nothing).
    GraphView::replace(
        std::set<std::shared_ptr<Node>>(
            {batchnormNode, batchnormNode->input(1).first, batchnormNode->input(2).first,
             batchnormNode->input(3).first, batchnormNode->input(4).first}),
        {});
}
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) {
    // Graph-regex entry point: apply fuseBatchNorm on each (OP, BatchNorm) pair of
    // the match solution. Atomic OP nodes are fused directly; non-atomic nodes
    // (MetaOperators such as PaddedConv) are only fused when their micro-graph
    // ends on a single Conv or ConvDepthWise output.
    assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n");
    assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n");

    for (const auto& op : solution->at("OP")) {
        if (op->getOperator()->isAtomic()) {
            for (const auto& batchNorm : solution->at("BatchNorm")) {
                fuseBatchNorm(op, batchNorm);
            }
        } else {  // op is a MetaOperator
            auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator());
            // Only fuse when the micro-graph has exactly one output and that
            // output is a (depth-wise) convolution; otherwise skip silently.
            if ((metaOp->getMicroGraph()->getOrderedOutputs().size() == 1) &&
                ((metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
                  Conv_Op<2>::Type) ||
                 (metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() ==
                  ConvDepthWise_Op<2>::Type))) {
                for (const auto& batchNorm : solution->at("BatchNorm")) {
                    fuseBatchNorm(op, batchNorm);
                }
            }
        }
    }
}
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) {
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("BatchNorm","getType($) =='BatchNorm'");
regex->setNodeKey("OP","getType($) =='Conv'");// || getType($) =='FC' ");
regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'");
printf("\n============================\nSearching for solutions\n==============================\n");
regex->setNodeKey(
"OP",
"getType($) =='Conv' || getType($) =='ConvDepthWise' || getType($) =='PaddedConv' || getType($) =='PaddedConvDepthWise'");
// || getType($) =='FC' ");
regex->addQuery("OP -> BatchNorm");
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment