/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/
#include <array>
#include <cassert>
#include <cmath>
#include <cstdio>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "aidge/operator/FC.hpp"
#include "aidge/operator/BatchNorm.hpp"
#include "aidge/operator/Conv.hpp"
#include "aidge/recipies/Recipies.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/GenericOperator.hpp"

// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"

using namespace Aidge;

void Aidge::fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes) {
    assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");

    // Check that the node types are the ones expected for the fusion
    std::shared_ptr<Node> conv;
    std::shared_ptr<Node> batchnorm;
    for (const auto& element : nodes) {
        assert((element->type() == "Conv" || element->type() == "BatchNorm") && "Wrong type for the nodes to replace");
        if (element->type() == "Conv") {
            conv = element;
        }
        else if (element->type() == "BatchNorm") {
            batchnorm = element;
        }
    }

    // TODO: check that batchnorm is the only child of the Conv or FC
    std::shared_ptr<Tensor> scale  = batchnorm->input(1).first->getOperator()->getOutput(batchnorm->input(1).second);
    std::shared_ptr<Tensor> shift  = batchnorm->input(2).first->getOperator()->getOutput(batchnorm->input(2).second);
    std::shared_ptr<Tensor> b_mean = batchnorm->input(3).first->getOperator()->getOutput(batchnorm->input(3).second);
    std::shared_ptr<Tensor> b_var  = batchnorm->input(4).first->getOperator()->getOutput(batchnorm->input(4).second);

    // TODO: find a way to remove the hard-coded template parameter
    const float epsilon = std::static_pointer_cast<BatchNorm_Op<2>>(batchnorm->getOperator())->getAttr<float>("Epsilon");
    const DimSize_t convOutDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<DimSize_t>("OutChannels");

    assert(scale->size()  == convOutDims);
    assert(shift->size()  == convOutDims);
    assert(b_mean->size() == convOutDims);
    assert(b_var->size()  == convOutDims);
    assert(epsilon > 0.0);
    // TODO: handle the missing noBias attribute

    // Average the variance over the outputs whose variance is significant;
    // it serves as a fallback value for near-zero-variance outputs below
    float meanVariance = 0.0;
    unsigned int count = 0;
    for (std::size_t output = 0; output < convOutDims; ++output) {
        // TODO: get() assumes the data type is float
        if (b_var->get<float>(output) > 1.0e-12) {
            meanVariance += b_var->get<float>(output);
            ++count;
        }
        else {
            printf("Zero-variance: %s [%zu]\n", conv->name().c_str(), output);
        }
    }
    if (count > 0) {
        meanVariance /= count;
    }
    else {
        printf("variance < 1e-12 for all outputs! Is the network correctly trained?\n");
    }

    const DimSize_t channelsSize = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<DimSize_t>("InChannels");

    // TODO: assumes a Conv2D operator
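    // Folding derivation (this is what the adjustment loop below implements):
    // per output channel, BatchNorm computes
    //     y = scale * (x - mean) / sqrt(var + eps) + shift
    // and the convolution computes x = w . input + b. Substituting x gives
    //     y = (w * factor) . input + (shift + (b - mean) * factor)
    // with factor = scale / sqrt(var + eps), so the BatchNorm is absorbed by
    // rescaling the weights and rewriting the bias.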
    const std::array<DimSize_t, 2> kernelDims = std::static_pointer_cast<Conv_Op<2>>(conv->getOperator())->getAttr<std::array<DimSize_t, 2>>("KernelDims");
    std::shared_ptr<Tensor> weight = conv->input(1).first->getOperator()->getOutput(conv->input(1).second);
    std::shared_ptr<Tensor> bias   = conv->input(2).first->getOperator()->getOutput(conv->input(2).second);

    for (std::size_t output = 0; output < convOutDims; ++output) {
        // Corrected for the zero-variance issue to help post-training quantization:
        // "A Quantization-Friendly Separable Convolution for MobileNets"
        // https://arxiv.org/pdf/1803.08607.pdf
        const float factor = scale->get<float>(output)
            / std::sqrt(epsilon + ((b_var->get<float>(output) > 1.0e-12 || count == 0)
                                       ? b_var->get<float>(output)
                                       : meanVariance));

        // Weights adjustment
        for (std::size_t channel = 0; channel < channelsSize; ++channel) {
            // TODO: assumes a 2D kernel
            for (std::size_t k0 = 0; k0 < kernelDims[0]; ++k0) {
                for (std::size_t k1 = 0; k1 < kernelDims[1]; ++k1) {
                    std::vector<DimSize_t> currentIdx = {output, channel, k0, k1};
                    // TODO: assumes the weights are float
                    const float weightValue = weight->get<float>(currentIdx);
                    // Write the rescaled value back into the Conv weights
                    weight->set<float>(currentIdx, weightValue * factor);
                }
            }
        }

        // Bias adjustment
        // TODO: if noBias==true is set, biasValue should be taken as 0
        float biasValue = bias->get<float>(output);
        biasValue = shift->get<float>(output) + (biasValue - b_mean->get<float>(output)) * factor;
        bias->set<float>(output, biasValue);
    }

    // Remove the BatchNorm node and its parameter Producers from the graph
    auto g = std::make_shared<GraphView>();
    g->add(std::set<std::shared_ptr<Node>>({
        batchnorm,
        batchnorm->input(1).first,
        batchnorm->input(2).first,
        batchnorm->input(3).first,
        batchnorm->input(4).first
    }));
    g->replaceWith({});
}

void Aidge::fuseBatchNorm(std::shared_ptr<GraphView> graphView) {
    // TODO: these raw-pointer NodeRegex objects are never deleted
    std::map<std::string, NodeRegex*> nodesRegex;
    nodesRegex["BatchNorm"] = new NodeRegex("BatchNorm");
    nodesRegex["Conv"] = new NodeRegex("Conv");
    nodesRegex["FC"] = new NodeRegex("FC");

    std::vector<std::string> seqRegex;
    seqRegex.push_back("Conv -> BatchNorm;"); // TODO: add (Conv | FC)

    // Match every Conv -> BatchNorm sequence in the graph and fuse each one
    GRegex GReg(nodesRegex, seqRegex);
    Match matches = GReg.match(graphView);
    std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
    for (std::size_t i = 0; i < matches.getNbMatch(); ++i) {
        fuseBatchNorm(matchNodes[i]);
    }
}
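// Usage sketch (illustrative only; `myGraph` is a hypothetical name, assuming a
// GraphView built or imported elsewhere):
//
//     std::shared_ptr<Aidge::GraphView> myGraph = /* build or import a model */;
//     Aidge::fuseBatchNorm(myGraph);
//
// Every Conv -> BatchNorm sequence matched by the graph regex above is then
// folded in place into the preceding Conv's weights and bias.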