diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp index ffb4599d83ba922ce5991460810f5d248806617c..b1959325ffc287df359569fd1e201ba6c52d9046 100644 --- a/src/recipies/FuseBatchNorm.cpp +++ b/src/recipies/FuseBatchNorm.cpp @@ -8,50 +8,77 @@ * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ -#include <set> #include <cassert> #include <memory> +#include <set> #include <string> -#include "aidge/operator/FC.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" #include "aidge/operator/BatchNorm.hpp" #include "aidge/operator/Conv.hpp" +#include "aidge/operator/ConvDepthWise.hpp" +#include "aidge/operator/FC.hpp" +#include "aidge/operator/MetaOperator.hpp" #include "aidge/recipies/Recipies.hpp" -#include "aidge/graph/GraphView.hpp" -#include "aidge/graph/Node.hpp" -#include "aidge/operator/Producer.hpp" -#include "aidge/operator/GenericOperator.hpp" - +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" -//Graph Regex +// Graph Regex #include "aidge/graphRegex/GraphRegex.hpp" -void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr<Aidge::Node> batchnormNode) { +void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, + std::shared_ptr<Aidge::Node> batchnormNode) { + // Case: convNode is a MetaOperator ending with a Convolution + // eg. PaddedConv + if (!(convNode -> getOperator() -> isAtomic())) { + const std::shared_ptr<MetaOperator_Op> metaNode = std::static_pointer_cast<MetaOperator_Op>(convNode -> getOperator()); + const std::shared_ptr<GraphView> metanodeGraph = metaNode -> getMicroGraph(); + const std::vector<std::pair<std::shared_ptr<Node>, IOIndex_t>> outputNodes = metanodeGraph -> getOrderedOutputs(); + if (outputNodes.size() != 1) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Bad MetaOperator argument for fuseBatchNorm recipie."); + } + convNode = outputNodes[0].first; + } + + AIDGE_ASSERT(((convNode->type() == Conv_Op<2>::Type) || (convNode->type() == ConvDepthWise_Op<2>::Type)), "Wrong type"); + AIDGE_ASSERT(batchnormNode->type() == BatchNorm_Op<2>::Type, "Wrong type for batchnorm node."); // TODO: Find a way to remove the template // A feature map with 2 dimensions is assumed - const std::shared_ptr<BatchNorm_Op<2>> batchOp = std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator()); - const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator()); + const std::shared_ptr<BatchNorm_Op<2>> batchOp = + std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator()); + + DimSize_t convNbOutChannels; + DimSize_t channelsSize; + std::array<DimSize_t, 2> kernelDims; + std::shared_ptr<OperatorTensor> convOp = std::static_pointer_cast<OperatorTensor>(convNode->getOperator()); + if (convNode->type() == Conv_Op<2>::Type) { + const std::shared_ptr<Conv_Op<2>> convOpPtr = + std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator()); + convNbOutChannels = convOpPtr->getAttr<DimSize_t>("OutChannels"); + channelsSize = convOpPtr->getAttr<DimSize_t>("InChannels"); + kernelDims = convOpPtr->getAttr<std::array<DimSize_t, 2>>("KernelDims"); + } + else if (convNode->type() == ConvDepthWise_Op<2>::Type) { + const std::shared_ptr<ConvDepthWise_Op<2>> convOpPtr = + std::static_pointer_cast<ConvDepthWise_Op<2>>(convNode->getOperator()); + convNbOutChannels = convOpPtr->getAttr<DimSize_t>("Channels"); + channelsSize = 1; + kernelDims = convOpPtr->getAttr<std::array<DimSize_t, 2>>("KernelDims"); + } - const std::shared_ptr<Tensor> scale = batchOp->getInput(1); - const std::shared_ptr<Tensor> shift = batchOp->getInput(2); - const std::shared_ptr<Tensor> b_mean = batchOp->getInput(3); - const std::shared_ptr<Tensor> b_var = batchOp->getInput(4); - const float epsilon = batchOp -> getAttr<float>("Epsilon"); - const DimSize_t convNbOutChannels = convOp -> getAttr<DimSize_t>("OutChannels"); - const DimSize_t channelsSize = convOp -> getAttr<DimSize_t>("InChannels"); - const std::array<DimSize_t, 2> kernelDims = convOp -> getAttr<std::array<DimSize_t, 2>>("KernelDims"); + const std::shared_ptr<Tensor> scale = batchOp->getInput(1); + const std::shared_ptr<Tensor> shift = batchOp->getInput(2); + const std::shared_ptr<Tensor> b_mean = batchOp->getInput(3); + const std::shared_ptr<Tensor> b_var = batchOp->getInput(4); + const float epsilon = batchOp->getAttr<float>("Epsilon"); - assert(scale->size() == convNbOutChannels); - assert(shift->size() == convNbOutChannels); - assert(b_mean->size() == convNbOutChannels); - assert(b_var->size() == convNbOutChannels); assert(epsilon > 0.0); // TODO : no no_bias attribute ? - float meanVariance = 0.0; unsigned int count = 0; @@ -60,8 +87,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr if (b_var->get<float>(outChId) > 1.0e-12) { meanVariance += b_var->get<float>(outChId); ++count; - } - else { + } else { printf("Zero-variance: %s [%lu]\n", convNode->name().c_str(), outChId); } } @@ -71,26 +97,30 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr printf("Warning: variance < 1e-12 for all outputs! Is the network correctly trained?\n"); } - std::shared_ptr<Tensor> weight = convOp -> getInput(1); - std::shared_ptr<Tensor> bias = convOp -> getInput(2); + std::shared_ptr<Tensor> weight = convOp->getInput(1); + std::shared_ptr<Tensor> bias = convOp->getInput(2); for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) { // Corrected for zero-variance issue: // "A Quantization-Friendly Separable Convolution for MobileNets" // https://arxiv.org/pdf/1803.08607.pdf // to help post-training quantization - const float factor = scale->get<float>(outChId) - / std::sqrt(epsilon + ((b_var->get<float>(outChId) > 1.0e-12 || count == 0) - ? b_var->get<float>(outChId) : meanVariance)); + const float factor = + scale->get<float>(outChId) / + std::sqrt(epsilon + ((b_var->get<float>(outChId) > 1.0e-12 || count == 0) + ? b_var->get<float>(outChId) + : meanVariance)); // Weights adjustments for (std::size_t channel = 0; channel < channelsSize; ++channel) { // TODO : Suppose kerneldims = 2 - for(std::size_t k0 = 0; k0 < kernelDims[0]; ++ k0){ - for(std::size_t k1 = 0; k1 < kernelDims[1]; ++ k1){ + for (std::size_t k0 = 0; k0 < kernelDims[0]; ++k0) { + for (std::size_t k1 = 0; k1 < kernelDims[1]; ++k1) { std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1}; // TODO : suppose weights are float float weightValue = weight->get<float>(currentIdx); - weight->set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights + weight->set<float>( + currentIdx, + weightValue * factor); // Update check it update Conv weights } } } @@ -101,38 +131,47 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr biasValue = shift->get<float>(outChId) + (biasValue - b_mean->get<float>(outChId)) * factor; bias->set<float>(outChId, biasValue); - } - GraphView::replace(std::set<std::shared_ptr<Node>>({ - batchnormNode, - batchnormNode->input(1).first, - batchnormNode->input(2).first, - batchnormNode->input(3).first, - batchnormNode->input(4).first - }), {}); - + GraphView::replace( + std::set<std::shared_ptr<Node>>( + {batchnormNode, batchnormNode->input(1).first, batchnormNode->input(2).first, + batchnormNode->input(3).first, batchnormNode->input(4).first}), + {}); } void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::MatchSolution> solution) { - assert(solution->at("BatchNorm").size() == 1 && "Wrong number of nodes BatchNorm to replace\n"); assert(solution->at("OP").size() == 1 && "Wrong number of nodes OP to replace\n"); for (const auto& op : solution->at("OP")) { - for (const auto& batchNorm : solution->at("BatchNorm")) { - fuseBatchNorm(op,batchNorm); + if (op->getOperator()->isAtomic()) { + for (const auto& batchNorm : solution->at("BatchNorm")) { + fuseBatchNorm(op, batchNorm); + } + } else { // op is a MetaOperator + auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(op->getOperator()); + if ((metaOp->getMicroGraph()->getOrderedOutputs().size() == 1) && + ((metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() == + Conv_Op<2>::Type) || + (metaOp->getMicroGraph()->getOrderedOutputs()[0].first->type() == + ConvDepthWise_Op<2>::Type))) { + for (const auto& batchNorm : solution->at("BatchNorm")) { + fuseBatchNorm(op, batchNorm); + } + } } } - } void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::GraphView> graphView) { - - std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); - regex->setNodeKey("BatchNorm","getType($) =='BatchNorm'"); - regex->setNodeKey("OP","getType($) =='Conv'");// || getType($) =='FC' "); + regex->setNodeKey("BatchNorm", "getType($) =='BatchNorm'"); + printf("\n============================\nSearching for solutions\n==============================\n"); + regex->setNodeKey( + "OP", + "getType($) =='Conv' || getType($) =='ConvDepthWise' || getType($) =='PaddedConv' || getType($) =='PaddedConvDepthWise'"); + // || getType($) =='FC' "); regex->addQuery("OP -> BatchNorm");