From d506f102882d1fae499af47fc7a0e7c9663ae42d Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sat, 9 Dec 2023 23:57:17 +0100 Subject: [PATCH] Make FuseBatchNorm work for any type --- include/aidge/data/Tensor.hpp | 30 +++++++++++++++++++-- src/recipies/FuseBatchNorm.cpp | 48 ++++++++++++++++++---------------- 2 files changed, 54 insertions(+), 24 deletions(-) diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp index 9a9f26bb2..94362c6cf 100644 --- a/include/aidge/data/Tensor.hpp +++ b/include/aidge/data/Tensor.hpp @@ -464,17 +464,29 @@ class Tensor : public Data, return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx)); } + template <typename expectedType> + const expectedType& get(std::size_t idx) const { + // TODO : add assert expected Type compatible with datatype + // TODO : add assert idx < Size + return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx)); + } + template <typename expectedType> expectedType& get(std::vector<std::size_t> coordIdx){ return get<expectedType>(getIdx(coordIdx)); } + template <typename expectedType> + const expectedType& get(std::vector<std::size_t> coordIdx) const { + return get<expectedType>(getIdx(coordIdx)); + } + template <typename expectedType> void set(std::size_t idx, expectedType value){ // TODO : add assert expected Type compatible with datatype // TODO : add assert idx < Size - void* dataPtr = mImpl->getRaw(idx); - std::memcpy(dataPtr, &value, sizeof(expectedType)); + expectedType* dataPtr = static_cast<expectedType*>(mImpl->getRaw(idx)); + *dataPtr = value; } template <typename expectedType> @@ -695,6 +707,20 @@ class Tensor : public Data, return refCast(fallback, target.dataType()).ref(fallback, device.first, device.second); } + /** + * Return a reference to a Tensor with float32 type on CPU: + * - itself, if already with the right characteristics; + * - the provided Tensor, overwritten with the copy-casted data. 
+     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary. +     * The shared_ptr does not need to be initialized. No new memory allocation +     * will occur if fallback has already been allocated with the right +     * type/size/device. +     * @return Reference to either itself or to fallback. +     */ +    Tensor& refCastNative(std::shared_ptr<Tensor>& fallback) { +        return refCast(fallback, DataType::Float32).ref(fallback, "cpu"); +    } + private:     ///\bug not protected against overflow     std::size_t computeSize() { diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp index ffb4599d8..8b9b4dba3 100644 --- a/src/recipies/FuseBatchNorm.cpp +++ b/src/recipies/FuseBatchNorm.cpp @@ -33,10 +33,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr     const std::shared_ptr<BatchNorm_Op<2>> batchOp = std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());     const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator()); -    const std::shared_ptr<Tensor> scale = batchOp->getInput(1); -    const std::shared_ptr<Tensor> shift = batchOp->getInput(2); -    const std::shared_ptr<Tensor> b_mean = batchOp->getInput(3); -    const std::shared_ptr<Tensor> b_var = batchOp->getInput(4); +    std::shared_ptr<Tensor> scaleBuf, shiftBuf, b_meanBuf, b_varBuf; +    const Tensor& scale = batchOp->getInput(1)->refCastNative(scaleBuf); +    const Tensor& shift = batchOp->getInput(2)->refCastNative(shiftBuf); +    const Tensor& b_mean = batchOp->getInput(3)->refCastNative(b_meanBuf); +    const Tensor& b_var = batchOp->getInput(4)->refCastNative(b_varBuf);     const float epsilon = batchOp -> getAttr<float>("Epsilon");     const DimSize_t convNbOutChannels = convOp -> getAttr<DimSize_t>("OutChannels");     const std::array<DimSize_t, 2> kernelDims = convOp -> getAttr<std::array<DimSize_t, 2>>("KernelDims"); -
assert(scale->size() == convNbOutChannels); - assert(shift->size() == convNbOutChannels); - assert(b_mean->size() == convNbOutChannels); - assert(b_var->size() == convNbOutChannels); + assert(scale.size() == convNbOutChannels); + assert(shift.size() == convNbOutChannels); + assert(b_mean.size() == convNbOutChannels); + assert(b_var.size() == convNbOutChannels); assert(epsilon > 0.0); // TODO : no no_bias attribute ? @@ -56,9 +57,8 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr unsigned int count = 0; for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) { - // TODO: get() assumed dataType is float... - if (b_var->get<float>(outChId) > 1.0e-12) { - meanVariance += b_var->get<float>(outChId); + if (b_var.get<float>(outChId) > 1.0e-12) { + meanVariance += b_var.get<float>(outChId); ++count; } else { @@ -71,39 +71,43 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr printf("Warning: variance < 1e-12 for all outputs! Is the network correctly trained?\n"); } - std::shared_ptr<Tensor> weight = convOp -> getInput(1); - std::shared_ptr<Tensor> bias = convOp -> getInput(2); + std::shared_ptr<Tensor> weightBuf, biasBuf; + Tensor& weight = convOp->getInput(1)->refCastNative(weightBuf); + Tensor& bias = convOp->getInput(2)->refCastNative(biasBuf); for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) { // Corrected for zero-variance issue: // "A Quantization-Friendly Separable Convolution for MobileNets" // https://arxiv.org/pdf/1803.08607.pdf // to help post-training quantization - const float factor = scale->get<float>(outChId) - / std::sqrt(epsilon + ((b_var->get<float>(outChId) > 1.0e-12 || count == 0) - ? b_var->get<float>(outChId) : meanVariance)); + const float factor = scale.get<float>(outChId) + / std::sqrt(epsilon + ((b_var.get<float>(outChId) > 1.0e-12 || count == 0) + ? 
b_var.get<float>(outChId) : meanVariance)); // Weights adjustments for (std::size_t channel = 0; channel < channelsSize; ++channel) { // TODO : Suppose kerneldims = 2 for(std::size_t k0 = 0; k0 < kernelDims[0]; ++ k0){ for(std::size_t k1 = 0; k1 < kernelDims[1]; ++ k1){ std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1}; - // TODO : suppose weights are float - float weightValue = weight->get<float>(currentIdx); - weight->set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights + float weightValue = weight.get<float>(currentIdx); + weight.set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights } } } // TODO : check if noBias==true is set, then set biasValue to 0 - float biasValue = bias->get<float>(outChId); + float biasValue = bias.get<float>(outChId); - biasValue = shift->get<float>(outChId) + (biasValue - b_mean->get<float>(outChId)) * factor; + biasValue = shift.get<float>(outChId) + (biasValue - b_mean.get<float>(outChId)) * factor; - bias->set<float>(outChId, biasValue); + bias.set<float>(outChId, biasValue); } + // Copy values back to the original tensors (actual copy only if needed) + convOp->getInput(1)->copyCastFrom(weight); + convOp->getInput(2)->copyCastFrom(bias); + GraphView::replace(std::set<std::shared_ptr<Node>>({ batchnormNode, batchnormNode->input(1).first, -- GitLab