Commit 6682b885 authored by Cyril Moineau

[Recipies] Add method FuseBatchNorm.

parent 689a11f5
Merge request !9: Fuse bn
Pipeline #31925 failed
@@ -76,7 +76,6 @@ public:
         if (!mInputs[0]->empty()) {
             for (std::size_t i = nbDataInputs(); i < nbInputs(); ++i) {
                 if (mInputs[i]->size() != mInputs[0]->dims()[1]) {
-                    assert(!mInputs[0]->hasImpl() && "Incompatible size with already implemented learnable parameter");
                     mInputs[i]->resize(std::array<DimSize_t, 1>({mInputs[0]->dims()[1]}));
                 }
             }
@@ -160,4 +159,4 @@ template <>
 const char *const EnumStrings<Aidge::BatchNormParam>::data[] = { "Epsilon", "Momentum" };
 }
-#endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_
\ No newline at end of file
+#endif //AIDGE_CORE_OPERATOR_BATCHNORM_H_
@@ -195,8 +195,14 @@ inline std::shared_ptr<Node> Conv(
 namespace {
 template <>
-const char *const EnumStrings<Aidge::ConvParam>::data[] = {"StrideDims", "DilationDims", "InChannels", "OutChannels",
-                                                           "KernelDims", "PaddingDims"};
+const char *const EnumStrings<Aidge::ConvParam>::data[] = {
+    "StrideDims",
+    "DilationDims",
+    "InChannels",
+    "OutChannels",
+    "KernelDims",
+    "PaddingDims"
+};
 }
 #endif /* AIDGE_CORE_OPERATOR_CONV_H_ */
@@ -20,6 +20,8 @@ namespace py = pybind11;
 namespace Aidge {
 void init_Recipies(py::module &m) {
   m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter(
     Recipe to fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
@@ -32,6 +34,7 @@ void init_Recipies(py::module &m) {
     :param nodes: The MatMul and Add nodes to fuse.
     :type nodes: list of :py:class:`aidge_core.Node`
     )mydelimiter");
   m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
     Recipe to remove a Flatten operator.
@@ -44,6 +47,24 @@ void init_Recipies(py::module &m) {
     :param nodes: The Flatten operator to remove.
     :type nodes: list of :py:class:`aidge_core.Node`
     )mydelimiter");
+  m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
+    Recipe to fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
+    :param nodes: The MatMul and Add nodes to fuse.
+    :type nodes: list of :py:class:`aidge_core.Node`
+    )mydelimiter");
+  m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter(
+    Recipe to fuse a BatchNorm operator into the preceding convolution.
+    :param graph_view: Graph view on which we want to apply the recipe.
+    :type graph_view: :py:class:`aidge_core.GraphView`
+    )mydelimiter");
+  m.def("fuse_batchnorm", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseBatchNorm), py::arg("nodes"), R"mydelimiter(
+    Recipe to fuse a BatchNorm operator into the preceding convolution.
+    :param nodes: The Conv and BatchNorm nodes to fuse.
+    :type nodes: list of :py:class:`aidge_core.Node`
+    )mydelimiter");
 }
 } // namespace Aidge
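For reference, a minimal sketch of how the two new C++ overloads might be invoked (the applyRecipes wrapper and the model variable are illustrative, not part of this commit):

#include <memory>
#include "aidge/graph/GraphView.hpp"
#include "aidge/utils/Recipies.hpp"

void applyRecipes(std::shared_ptr<Aidge::GraphView> model) {
    // Whole-graph form: graph matching locates every Conv -> BatchNorm pair.
    Aidge::fuseBatchNorm(model);
    // Node-set form: fuse one specific pair found by other means, e.g.
    // Aidge::fuseBatchNorm({convNode, batchNormNode});
}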
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cassert>
#include <cmath>   // std::sqrt
#include <cstdio>  // printf
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>
#include "aidge/operator/FC.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/utils/Recipies.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/GenericOperator.hpp"
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
using namespace Aidge;
void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes) {
    assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");

    // Check that the node types are the ones expected by the fusion
    std::shared_ptr<Node> conv;
    std::shared_ptr<Node> batchnorm;
    for (const auto& element : nodes) {
        assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace");
        if (element->type() == "Conv") {
            conv = element;
        } else if (element->type() == "BatchNorm") {
            batchnorm = element;
        }
    }

    // TODO: check that the BatchNorm is the only child of the Conv or FC
    std::shared_ptr<Tensor> scale  = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second);
    std::shared_ptr<Tensor> shift  = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second);
    std::shared_ptr<Tensor> b_mean = batchnorm->input(3).first->getOperator()->getOutput(batchnorm->input(3).second);
    std::shared_ptr<Tensor> b_var  = batchnorm->input(4).first->getOperator()->getOutput(batchnorm->input(4).second);

    // TODO: find a way to remove the template parameter
    const float epsilon = std::static_pointer_cast<BatchNorm_Op<2>>(batchnorm->getOperator())->get<float>("Epsilon");
    const DimSize_t convOutDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->get<DimSize_t>("OutChannels");

    assert(scale->size()  == convOutDims);
    assert(shift->size()  == convOutDims);
    assert(b_mean->size() == convOutDims);
    assert(b_var->size()  == convOutDims);
    assert(epsilon > 0.0);
    // TODO: no no_bias attribute?

    // Mean of the non-negligible variances, used as a fallback for channels
    // whose variance collapsed to ~0
    float meanVariance = 0.0f;
    unsigned int count = 0;
    for (std::size_t output = 0; output < convOutDims; ++output) {
        // TODO: get() assumes the data type is float
        if (b_var->get<float>(output) > 1.0e-12) {
            meanVariance += b_var->get<float>(output);
            ++count;
        } else {
            printf("Zero-variance: %s [%zu]\n", conv->name().c_str(), output);
        }
    }
    if (count > 0) {
        meanVariance /= count;
    } else {
        printf("Variance < 1e-12 for all outputs! Is the network correctly trained?\n");
    }

    const DimSize_t channelsSize = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->get<DimSize_t>("InChannels");
    // TODO: this assumes a Conv2D operator
    const std::array<DimSize_t, 2> kernelDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->get<std::array<DimSize_t, 2>>("KernelDims");

    std::shared_ptr<Tensor> weight = conv->input(1).first->getOperator()->getOutput(conv->input(1).second);
    std::shared_ptr<Tensor> bias   = conv->input(2).first->getOperator()->getOutput(conv->input(2).second);

    for (std::size_t output = 0; output < convOutDims; ++output) {
        // Corrected for the zero-variance issue:
        // "A Quantization-Friendly Separable Convolution for MobileNets"
        // https://arxiv.org/pdf/1803.08607.pdf
        // to help post-training quantization
        const float factor = scale->get<float>(output)
            / std::sqrt(epsilon + ((b_var->get<float>(output) > 1.0e-12 || count == 0)
                                   ? b_var->get<float>(output) : meanVariance));

        // Weight adjustments
        for (std::size_t channel = 0; channel < channelsSize; ++channel) {
            // TODO: this assumes kernelDims.size() == 2
            for (std::size_t k0 = 0; k0 < kernelDims[0]; ++k0) {
                for (std::size_t k1 = 0; k1 < kernelDims[1]; ++k1) {
                    std::vector<DimSize_t> currentIdx = {output, channel, k0, k1};
                    // TODO: get() assumes the weights are float
                    float weightValue = weight->get<float>(currentIdx);
                    weight->set<float>(currentIdx, weightValue * factor); // updates the Conv weights in place
                }
            }
        }

        // TODO: if noBias == true, set biasValue to 0
        float biasValue = bias->get<float>(output);
        biasValue = shift->get<float>(output) + (biasValue - b_mean->get<float>(output)) * factor;
        bias->set<float>(output, biasValue);
    }

    // Remove the BatchNorm node and its parameter Producers from the graph
    auto g = std::make_shared<GraphView>();
    g->add(std::set<std::shared_ptr<Node>>({
        batchnorm,
        batchnorm->input(1).first,
        batchnorm->input(2).first,
        batchnorm->input(3).first,
        batchnorm->input(4).first
    }));
    g->replaceWith({});
}
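The arithmetic above can be summarized independently of the Aidge tensor API: BatchNorm computes y = scale * (x - mean) / sqrt(var + eps) + shift, so each output channel's weights are multiplied by a per-channel factor and the bias is re-centered. A minimal sketch on plain float arrays (all names here are illustrative, not Aidge API), including the zero-variance fallback:

#include <cmath>
#include <cstddef>
#include <vector>

// Fold y = scale * (conv(x) - mean) / sqrt(var + eps) + shift into the
// convolution itself, one multiplicative factor per output channel.
void foldBatchNorm(std::vector<float>& weights,      // [out][in][kh][kw], flattened
                   std::vector<float>& bias,         // [out]
                   const std::vector<float>& scale,  // BN scale (gamma), [out]
                   const std::vector<float>& shift,  // BN shift (beta), [out]
                   const std::vector<float>& mean,   // BN running mean, [out]
                   const std::vector<float>& var,    // BN running variance, [out]
                   std::size_t outChannels, float epsilon) {
    // Mean of the non-negligible variances, substituted for near-zero ones
    // (quantization-friendly correction, arXiv:1803.08607)
    float meanVariance = 0.0f;
    std::size_t count = 0;
    for (std::size_t o = 0; o < outChannels; ++o) {
        if (var[o] > 1.0e-12f) { meanVariance += var[o]; ++count; }
    }
    if (count > 0) meanVariance /= count;

    const std::size_t weightsPerChannel = weights.size() / outChannels;
    for (std::size_t o = 0; o < outChannels; ++o) {
        const float v = (var[o] > 1.0e-12f || count == 0) ? var[o] : meanVariance;
        const float factor = scale[o] / std::sqrt(epsilon + v);
        // Scale every weight feeding output channel o, then re-center its bias
        for (std::size_t k = 0; k < weightsPerChannel; ++k) {
            weights[o * weightsPerChannel + k] *= factor;
        }
        bias[o] = shift[o] + (bias[o] - mean[o]) * factor;
    }
}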
void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView) {
    std::map<std::string, NodeRegex*> nodesRegex;
    nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm");
    nodesRegex["Conv"] = new NodeRegex("Conv");
    nodesRegex["FC"] = new NodeRegex("FC");

    std::vector<std::string> seqRegex;
    seqRegex.push_back("Conv -> BatchNorm;"); // TODO: add (Conv | FC)
    GRegex GReg(nodesRegex, seqRegex);
    Match matches = GReg.match(graphView);
    std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
    for (std::size_t i = 0; i < matches.getNbMatch(); ++i) {
        fuseBatchNorm(matchNodes[i]);
    }
}
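The TODO in the sequence pattern hints at also matching FC layers. A sketch of what that registration might look like, assuming the GRegex sequence syntax accepts multiple patterns (hypothetical, not part of this commit):

#include <string>
#include <vector>

// Hypothetical extension of the sequence patterns registered above.
std::vector<std::string> buildFusionPatterns() {
    std::vector<std::string> seqRegex;
    seqRegex.push_back("Conv -> BatchNorm;");
    seqRegex.push_back("FC -> BatchNorm;"); // would also require fuseBatchNorm(nodes)
                                            // to handle the 2D FC weight layout
    return seqRegex;
}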