Newer
Older
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <memory>

Maxence Naud
committed
#include <set>
#include "aidge/data/Tensor.hpp"

Maxence Naud
committed
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp"

Maxence Naud
committed
#include "aidge/operator/ConvDepthWise.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MetaOperator.hpp"

Maxence Naud
committed
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"

Maxence Naud
committed
// Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"

Maxence Naud
committed
/**
 * Fold a BatchNorm node into the convolution node that feeds it.
 *
 * The convolution weights are multiplied, per output channel, by
 *   factor = scale / sqrt(epsilon + variance)
 * and the bias becomes
 *   shift + (bias - mean) * factor.
 * Afterwards, the BatchNorm node and the nodes on its inputs #1..#4 are
 * passed to GraphView::replace() with an empty replacement set.
 *
 * @param convNode      A Conv_Op<2> or ConvDepthWise_Op<2> node, or a
 *                      non-atomic (MetaOperator) node whose micro-graph has a
 *                      single ordered output that is such a convolution
 *                      (e.g. PaddedConv).
 * @param batchnormNode A BatchNorm_Op<2> node.
 *
 * If the convolution has no bias tensor (input #2), a zero-filled bias
 * producer is added first. When convNode is a MetaOperator, that producer is
 * attached to the meta-node input that maps to the inner convolution's
 * input #2.
 *
 * Channels whose variance is <= 1e-12 use the mean of the other channels'
 * variances instead (unless every channel is in that case), following
 * "A Quantization-Friendly Separable Convolution for MobileNets".
 */
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
                          std::shared_ptr<Aidge::Node> batchnormNode) {
    // Variance at or below this value is treated as zero. Kept as a double so
    // the float-vs-threshold comparison matches the original literal exactly.
    constexpr double kZeroVarianceThreshold = 1.0e-12;

    // Case: convNode is a MetaOperator ending with a Convolution
    // eg. PaddedConv. Remember the meta node (a missing bias must be added
    // on it, outside the micro-graph) and work on the inner convolution.
    std::shared_ptr<Node> metaNode;
    if (!(convNode->getOperator()->isAtomic())) {
        metaNode = convNode;
        const auto metaOp = std::static_pointer_cast<MetaOperator_Op>(convNode->getOperator());
        const std::shared_ptr<GraphView> metaOpGraph = metaOp->getMicroGraph();
        const auto outputNodes = metaOpGraph->getOrderedOutputs();
        if (outputNodes.size() != 1) {
            AIDGE_THROW_OR_ABORT(std::runtime_error, "Bad MetaOperator argument for fuseBatchNorm recipie.");
        }
        convNode = outputNodes[0].first;
    }

    AIDGE_ASSERT(((convNode->type() == Conv_Op<2>::Type) || (convNode->type() == ConvDepthWise_Op<2>::Type)), "Wrong type");
    AIDGE_ASSERT(batchnormNode->type() == BatchNorm_Op<2>::Type, "Wrong type for batchnorm node.");

    // TODO: Find a way to remove the template
    // A feature map with 2 dimensions is assumed
    const std::shared_ptr<BatchNorm_Op<2>> batchOp =
            std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());

    DimSize_t convNbOutChannels = 1;
    DimSize_t channelsSize = 1;   // left at 1 for ConvDepthWise (not set below)
    std::array<DimSize_t, 2> kernelDims = {1,1};
    AIDGE_ASSERT(convNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
    std::shared_ptr<OperatorTensor> convOp = std::static_pointer_cast<OperatorTensor>(convNode->getOperator());
    if (convNode->type() == Conv_Op<2>::Type) {
        const std::shared_ptr<Conv_Op<2>> convOpPtr =
            std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
        convNbOutChannels = convOpPtr->outChannels();
        channelsSize = convOpPtr->inChannels();
        kernelDims = convOpPtr->getAttr<std::array<DimSize_t, 2>>("KernelDims");
    }
    else if (convNode->type() == ConvDepthWise_Op<2>::Type) {
        const std::shared_ptr<ConvDepthWise_Op<2>> convOpPtr =
            std::static_pointer_cast<ConvDepthWise_Op<2>>(convNode->getOperator());
        convNbOutChannels = convOpPtr->nbChannels();
        kernelDims = convOpPtr->getAttr<std::array<DimSize_t, 2>>("KernelDims");
    }
    AIDGE_ASSERT(kernelDims.size() == 2, "fuseBatchNorm(): only 2D convolutions are supported");

    // Float32/CPU views of the BatchNorm parameters:
    // input #1 scale, #2 shift, #3 mean, #4 variance.
    std::shared_ptr<Tensor> scaleBuf, shiftBuf, b_meanBuf, b_varBuf;
    const Tensor& scale = batchOp->getInput(1)->refCastFrom(scaleBuf, DataType::Float32, "cpu");
    const Tensor& shift = batchOp->getInput(2)->refCastFrom(shiftBuf, DataType::Float32, "cpu");
    const Tensor& b_mean = batchOp->getInput(3)->refCastFrom(b_meanBuf, DataType::Float32, "cpu");
    const Tensor& b_var = batchOp->getInput(4)->refCastFrom(b_varBuf, DataType::Float32, "cpu");

    const float epsilon = batchOp->getAttr<float>("Epsilon");
    // Checked in every build type (a plain assert() disappears with NDEBUG).
    AIDGE_ASSERT(epsilon > 0.0f, "fuseBatchNorm(): BatchNorm epsilon must be strictly positive");
    // TODO : no no_bias attribute ?

    // Mean of the non-zero variances, substituted for zero-variance channels.
    float meanVariance = 0.0;
    unsigned int count = 0;
    for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
        const float variance = b_var.get<float>(outChId);
        if (variance > kZeroVarianceThreshold) {
            meanVariance += variance;
            ++count;
        } else {
            fmt::print("Zero-variance: {} [{}]\n", convNode->name(), outChId);
        }
    }
    if (count > 0) {
        meanVariance /= count;
    } else {
        fmt::print("Warning: variance < 1e-12 for all outputs! Is the network correctly trained?\n");
    }

    // Add bias if it is non existant, as there will be a bias after the fuse.
    // An existing bias is kept: it is folded into the new bias below.
    if (!convOp->getInput(2)) {
        if (metaNode) {
            // Conv is inside a meta-operator, we add bias outside it.
            // Find the meta-operator input index corresponding to the bias
            // (input #2) of the inner convolution.
            const auto metaOp = std::static_pointer_cast<MetaOperator_Op>(metaNode->getOperator());
            const auto orderedInputs = metaOp->getMicroGraph()->getOrderedInputs();
            IOIndex_t inputIdx = 0;
            for (const auto& input : orderedInputs) {
                if (input.first == convNode && input.second == 2) {
                    break;
                }
                ++inputIdx;
            }
            AIDGE_ASSERT(inputIdx < orderedInputs.size(),
                "fuseBatchNorm(): the convolution bias is not an input of the MetaOperator");
            auto prod = addProducer(metaNode, inputIdx, {convNbOutChannels}, "b");
            // Add the new bias node to the same views as the meta node
            for (const auto& g : metaNode->views()) {
                g->add(prod);
            }
        }
        else {
            auto prod = addProducer(convNode, 2, {convNbOutChannels}, "b");
            // Add the new bias node to the same views as the weights node
            // if possible, otherwise to the same views as the conv node
            const std::shared_ptr<Node> weightsNode = convNode->input(1).first;
            const std::shared_ptr<Node> viewsSource = weightsNode ? weightsNode : convNode;
            for (const auto& g : viewsSource->views()) {
                g->add(prod);
            }
        }
        AIDGE_INTERNAL_ASSERT(convOp->getInput(2) != nullptr);
        // Use the same backend for the bias than for the weights
        convOp->getInput(2)->setBackend(convOp->getInput(1)->backend());
        convOp->getInput(2)->zeros();
    }

    std::shared_ptr<Tensor> weightBuf, biasBuf;
    Tensor& weight = convOp->getInput(1)->refCastFrom(weightBuf, DataType::Float32, "cpu");
    Tensor& bias = convOp->getInput(2)->refCastFrom(biasBuf, DataType::Float32, "cpu");

    for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
        // Corrected for zero-variance issue:
        // "A Quantization-Friendly Separable Convolution for MobileNets"
        // https://arxiv.org/pdf/1803.08607.pdf
        // to help post-training quantization
        const float variance = b_var.get<float>(outChId);
        const float factor = scale.get<float>(outChId)
            / std::sqrt(epsilon + ((variance > kZeroVarianceThreshold || count == 0)
                                   ? variance : meanVariance));

        // Weights adjustments: w <- w * factor, indexed as
        // {outChId, channel, k0, k1}
        for (std::size_t channel = 0; channel < channelsSize; ++channel) {
            for (std::size_t k0 = 0; k0 < kernelDims[0]; ++k0) {
                for (std::size_t k1 = 0; k1 < kernelDims[1]; ++k1) {
                    const std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1};
                    const float weightValue = weight.get<float>(currentIdx);
                    weight.set<float>(currentIdx, weightValue * factor);
                }
            }
        }

        // Bias adjustment: b <- shift + (b - mean) * factor
        const float biasValue = shift.get<float>(outChId)
            + (bias.get<float>(outChId) - b_mean.get<float>(outChId)) * factor;
        bias.set<float>(outChId, biasValue);
    }

    // Copy values back to the original tensors (actual copy only if needed)
    convOp->getInput(1)->copyCastFrom(weight);
    convOp->getInput(2)->copyCastFrom(bias);

    // Replace the BatchNorm node and the nodes on its inputs #1..#4 with an
    // empty set of nodes.
    GraphView::replace(std::set<std::shared_ptr<Node>>({
        batchnormNode,
        batchnormNode->input(1).first,
        batchnormNode->input(2).first,
        batchnormNode->input(3).first,
        batchnormNode->input(4).first
        }), {});
}
/**
 * Apply fuseBatchNorm(convNode, batchnormNode) to one graph-regex match.
 *
 * @param solution Match whose "OP" and "BatchNorm" keys must each hold
 *                 exactly one node.
 *
 * An atomic "OP" node is fused directly. A MetaOperator "OP" node is fused
 * only when its micro-graph has a single ordered output of type
 * Conv_Op<2> or ConvDepthWise_Op<2>. Any other operator is skipped.
 */
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) {
    // Checked in every build type (a plain assert() disappears with NDEBUG).
    AIDGE_ASSERT(solution->at("BatchNorm").size() == 1, "Wrong number of nodes BatchNorm to replace\n");
    AIDGE_ASSERT(solution->at("OP").size() == 1, "Wrong number of nodes OP to replace\n");

    for (const auto& op : solution->at("OP")) {
        bool fusable = op->getOperator()->isAtomic();
        if (!fusable) { // op is a MetaOperator
            const auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator());
            // Guard against a non-atomic operator that is not a MetaOperator_Op
            if (metaOp) {
                const auto outputs = metaOp->getMicroGraph()->getOrderedOutputs();
                fusable = (outputs.size() == 1) &&
                          ((outputs[0].first->type() == Conv_Op<2>::Type) ||
                           (outputs[0].first->type() == ConvDepthWise_Op<2>::Type));
            }
        }
        if (!fusable) {
            continue;
        }
        for (const auto& batchNorm : solution->at("BatchNorm")) {
            fuseBatchNorm(op, batchNorm);
        }
    }
}
void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) {
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();

Maxence Naud
committed
regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'");
fmt::print("\n============================\nSearching for solutions\n==============================\n");

Maxence Naud
committed
regex->setNodeKey(
"OP",
"getType($) =='Conv' || getType($) =='ConvDepthWise' || getType($) =='PaddedConv' || getType($) =='PaddedConvDepthWise'");
// || getType($) =='FC' ");
regex->addQuery("OP -> BatchNorm");
for (const auto& solution : regex->match(graphView)) {
fuseBatchNorm(solution);