/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <array>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "aidge/data/Tensor.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/recipies/Recipies.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/GenericOperator.hpp"
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
using namespace Aidge;
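/**
 * Fuse a (Conv, BatchNorm) pair of nodes: the BatchNorm statistics are folded
 * into the Conv weights and bias, after which the BatchNorm node (and the
 * Producers holding its parameters) becomes redundant and is removed.
 *
 * With the BatchNorm output y = scale * (x - mean) / sqrt(var + epsilon) + shift
 * applied to the Conv output x = w * in + b, the fused parameters are, for each
 * output channel:
 *     factor = scale / sqrt(var + epsilon)
 *     w' = w * factor
 *     b' = shift + (b - mean) * factor
 */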
void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes) {
    assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");

    // Check that the node types are the ones expected for this fusion
    std::shared_ptr<Node> conv;
    std::shared_ptr<Node> batchnorm;
    for (const auto& element : nodes) {
        assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace");
        if (element->type() == "Conv") {
            conv = element;
        }
        else if (element->type() == "BatchNorm") {
            batchnorm = element;
        }
    }
    // TODO: check that the BatchNorm is the only child of the Conv (or FC)
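    // BatchNorm input layout: input(0) is the data; inputs 1 to 4 are the
    // Producers holding scale (gamma), shift (beta), running mean and running
    // variance, all per-channel tensors of size OutChannels.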
    std::shared_ptr<Tensor> scale  = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second);
    std::shared_ptr<Tensor> shift  = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second);
    std::shared_ptr<Tensor> b_mean = batchnorm->input(3).first->getOperator()->getOutput(batchnorm->input(3).second);
    std::shared_ptr<Tensor> b_var  = batchnorm->input(4).first->getOperator()->getOutput(batchnorm->input(4).second);
    // TODO: find a way to remove the template parameter
    const float epsilon = std::static_pointer_cast<BatchNorm_Op<2>>(batchnorm->getOperator())->getAttr<float>("Epsilon");
    const DimSize_t convOutDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<DimSize_t>("OutChannels");
    assert(scale->size() == convOutDims);
    assert(shift->size() == convOutDims);
    assert(b_mean->size() == convOutDims);
    assert(b_var->size() == convOutDims);
    assert(epsilon > 0.0);
    // TODO: no no_bias attribute?
    float meanVariance = 0.0f;
    unsigned int count = 0;

    for (std::size_t output = 0; output < convOutDims; ++output) {
        // TODO: we assume the stored data type is float here
        if (b_var->get<float>(output) > 1.0e-12) {
            meanVariance += b_var->get<float>(output);
            ++count;
        }
        else {
            printf("Zero-variance: %s [%zu]\n", conv->name().c_str(), output);
        }
    }

    if (count > 0) {
        meanVariance /= count;
    }
    else {
        printf("variance < 1e-12 for all outputs! Is the network correctly trained?\n");
    }
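    // meanVariance (the average variance over the well-conditioned channels)
    // serves below as a substitute for near-zero per-channel variances.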
    const DimSize_t channelsSize = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<DimSize_t>("InChannels");
    // TODO: assumes a 2-D convolution (Conv_Op<2>)
    const std::array<DimSize_t, 2> kernelDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<std::array<DimSize_t, 2>>("KernelDims");
    std::shared_ptr<Tensor> weight = conv->input(1).first->getOperator()->getOutput(conv->input(1).second);
    std::shared_ptr<Tensor> bias   = conv->input(2).first->getOperator()->getOutput(conv->input(2).second);
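    // Fold the BatchNorm statistics into the Conv weights and bias,
    // one output channel at a time.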
    for (std::size_t output = 0; output < convOutDims; ++output) {
        // Corrected for the zero-variance issue:
        // "A Quantization-Friendly Separable Convolution for MobileNets"
        // https://arxiv.org/pdf/1803.08607.pdf
        // to help post-training quantization
        const float factor = scale->get<float>(output)
            / std::sqrt(epsilon + ((b_var->get<float>(output) > 1.0e-12 || count == 0)
                        ? b_var->get<float>(output) : meanVariance));

        // Weights adjustment
        for (std::size_t channel = 0; channel < channelsSize; ++channel) {
            // TODO: assumes 2-D kernels
            for (std::size_t k0 = 0; k0 < kernelDims[0]; ++k0) {
                for (std::size_t k1 = 0; k1 < kernelDims[1]; ++k1) {
                    std::vector<DimSize_t> currentIdx = {output, channel, k0, k1};
                    // TODO: assumes float weights
                    float weightValue = weight->get<float>(currentIdx);
                    weight->set<float>(currentIdx, weightValue * factor); // update the Conv weight in place
                }
            }
        }

        // TODO: if noBias==true is set, biasValue should be 0
        float biasValue = bias->get<float>(output);
        biasValue = shift->get<float>(output) + (biasValue - b_mean->get<float>(output)) * factor;
        bias->set<float>(output, biasValue);
    }
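    // The BatchNorm is now redundant: gather it together with its four
    // parameter Producers and replace this sub-graph with nothing.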
    auto g = std::make_shared<GraphView>();
    g->add(std::set<std::shared_ptr<Node>>({
        batchnorm,
        batchnorm->input(1).first,
        batchnorm->input(2).first,
        batchnorm->input(3).first,
        batchnorm->input(4).first
    }));
    g->replaceWith({});
}
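/**
 * Graph-level variant: match every Conv -> BatchNorm sequence in the GraphView
 * using the graph regex engine, then fuse each matched pair. Typical usage,
 * once a model has been imported into a GraphView `graphView`:
 *
 *     Aidge::fuseBatchNorm(graphView);
 */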
void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView) {
    std::map<std::string, NodeRegex*> nodesRegex;
    nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm");
    nodesRegex["Conv"] = new NodeRegex("Conv");
    nodesRegex["FC"] = new NodeRegex("FC");

    std::vector<std::string> seqRegex;
    seqRegex.push_back("Conv -> BatchNorm;"); // TODO: add (Conv | FC)

    GRegex GReg(nodesRegex, seqRegex);
    Match matches = GReg.match(graphView);
    std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
    for (std::size_t i = 0; i < matches.getNbMatch(); ++i) {
        fuseBatchNorm(matchNodes[i]);
    }
    // TODO: the NodeRegex objects allocated above are never freed
}