/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <set>
#include <cassert>
#include <memory>
#include <string>
#include "aidge/operator/FC.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/recipies/Recipies.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/GenericOperator.hpp"
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr<Aidge::Node> batchnormNode) {
    // TODO: Find a way to remove the template
    // A feature map with 2 dimensions is assumed
    const std::shared_ptr<BatchNorm_Op<2>> batchOp = std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());
    const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
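    // BatchNorm parameter inputs used below: 1 = scale, 2 = shift, 3 = running mean, 4 = running variance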
    std::shared_ptr<Tensor> scaleBuf, shiftBuf, b_meanBuf, b_varBuf;
    const Tensor& scale = batchOp->getInput(1)->refCastNative(scaleBuf);
    const Tensor& shift = batchOp->getInput(2)->refCastNative(shiftBuf);
    const Tensor& b_mean = batchOp->getInput(3)->refCastNative(b_meanBuf);
    const Tensor& b_var = batchOp->getInput(4)->refCastNative(b_varBuf);
    const float epsilon = batchOp->getAttr<float>("Epsilon");
    const DimSize_t convNbOutChannels = convOp->getAttr<DimSize_t>("OutChannels");
    const DimSize_t channelsSize = convOp->getAttr<DimSize_t>("InChannels");
    const std::array<DimSize_t, 2> kernelDims = convOp->getAttr<std::array<DimSize_t, 2>>("KernelDims");

    assert(scale.size() == convNbOutChannels);
    assert(shift.size() == convNbOutChannels);
    assert(b_mean.size() == convNbOutChannels);
    assert(b_var.size() == convNbOutChannels);
    assert(epsilon > 0.0);
    // TODO : no no_bias attribute ?

    float meanVariance = 0.0;
    unsigned int count = 0;
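    // Average the variance over the channels whose variance is not (near) zero.
    // Channels with a degenerate (< 1e-12) variance use this mean variance as a
    // substitute in the scaling factor below, following the zero-variance
    // correction of the MobileNets quantization paper referenced further down.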
    for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
        if (b_var.get<float>(outChId) > 1.0e-12) {
            meanVariance += b_var.get<float>(outChId);
            ++count;
        }
        else {
            printf("Zero-variance: %s [%lu]\n", convNode->name().c_str(), outChId);
        }
    }
    if (count > 0)
        meanVariance /= count;
    else {
        printf("Warning: variance < 1e-12 for all outputs! Is the network correctly trained?\n");
    }
    std::shared_ptr<Tensor> weightBuf, biasBuf;
    Tensor& weight = convOp->getInput(1)->refCastNative(weightBuf);
    Tensor& bias = convOp->getInput(2)->refCastNative(biasBuf);
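    // Batch normalization computes y = scale * (x - mean) / sqrt(var + epsilon) + shift.
    // Folding it into the preceding convolution therefore amounts to, per output channel:
    //   w' = w * scale / sqrt(var + epsilon)
    //   b' = shift + (b - mean) * scale / sqrt(var + epsilon)
    // which is exactly what the loop below applies to the Conv weights and bias.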
    for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
        // Corrected for zero-variance issue:
        // "A Quantization-Friendly Separable Convolution for MobileNets"
        // https://arxiv.org/pdf/1803.08607.pdf
        // to help post-training quantization
        const float factor = scale.get<float>(outChId)
            / std::sqrt(epsilon + ((b_var.get<float>(outChId) > 1.0e-12 || count == 0)
                        ? b_var.get<float>(outChId) : meanVariance));

        // Weights adjustments
        for (std::size_t channel = 0; channel < channelsSize; ++channel) {
            // TODO : Suppose kerneldims = 2
            for (std::size_t k0 = 0; k0 < kernelDims[0]; ++k0) {
                for (std::size_t k1 = 0; k1 < kernelDims[1]; ++k1) {
                    std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1};
                    float weightValue = weight.get<float>(currentIdx);
                    weight.set<float>(currentIdx, weightValue * factor); // update the Conv weight in place
                }
            }
        }

        // TODO : check if noBias==true is set, then set biasValue to 0
        float biasValue = bias.get<float>(outChId);
        biasValue = shift.get<float>(outChId) + (biasValue - b_mean.get<float>(outChId)) * factor;
        bias.set<float>(outChId, biasValue);
    }
    // Copy values back to the original tensors (actual copy only if needed)
    convOp->getInput(1)->copyCastFrom(weight);
    convOp->getInput(2)->copyCastFrom(bias);

    GraphView::replace(std::set<std::shared_ptr<Node>>({
        batchnormNode,
        batchnormNode->input(1).first,
        batchnormNode->input(2).first,
        batchnormNode->input(3).first,
        batchnormNode->input(4).first
        }), {});
}

void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) {
    assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n");
    assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n");

    for (const auto& op : solution->at("OP")) {
        for (const auto& batchNorm : solution->at("BatchNorm")) {
            fuseBatchNorm(op, batchNorm);
        }
    }
}
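
// Scans the whole GraphView with a GraphRegex query and fuses every Conv node
// directly followed by a BatchNorm node (the FC case is left commented out below).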
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) {
    std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
    regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'");
    regex->setNodeKey("OP", "getType($) =='Conv'"); // || getType($) =='FC'
    regex->addQuery("OP -> BatchNorm");

    for (const auto& solution : regex->match(graphView)) {
        fuseBatchNorm(solution);
    }
}