From d506f102882d1fae499af47fc7a0e7c9663ae42d Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sat, 9 Dec 2023 23:57:17 +0100
Subject: [PATCH] Make FuseBatchNorm work for any type

---
 include/aidge/data/Tensor.hpp  | 30 +++++++++++++++++++--
 src/recipies/FuseBatchNorm.cpp | 48 ++++++++++++++++++----------------
 2 files changed, 54 insertions(+), 24 deletions(-)

diff --git a/include/aidge/data/Tensor.hpp b/include/aidge/data/Tensor.hpp
index 9a9f26bb2..94362c6cf 100644
--- a/include/aidge/data/Tensor.hpp
+++ b/include/aidge/data/Tensor.hpp
@@ -464,17 +464,29 @@ class Tensor : public Data,
         return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx));
     }
 
+    template <typename expectedType>
+    const expectedType& get(std::size_t idx) const {
+        // TODO : add assert expected Type compatible with datatype
+        // TODO : add assert idx < Size
+        return *reinterpret_cast<expectedType *>(mImpl->getRaw(idx));
+    }
+
     template <typename expectedType>
     expectedType& get(std::vector<std::size_t> coordIdx){
         return get<expectedType>(getIdx(coordIdx));
     }
 
+    template <typename expectedType>
+    const expectedType& get(std::vector<std::size_t> coordIdx) const {
+        return get<expectedType>(getIdx(coordIdx));
+    }
+
     template <typename expectedType>
     void set(std::size_t idx, expectedType value){
         // TODO : add assert expected Type compatible with datatype
         // TODO : add assert idx < Size
-        void* dataPtr = mImpl->getRaw(idx);
-        std::memcpy(dataPtr, &value, sizeof(expectedType));
+        expectedType* dataPtr = static_cast<expectedType*>(mImpl->getRaw(idx));
+        *dataPtr = value;
     }
 
     template <typename expectedType>
@@ -695,6 +707,20 @@ class Tensor : public Data,
         return refCast(fallback, target.dataType()).ref(fallback, device.first, device.second);
     }
 
+    /**
+     * Return a reference to a Tensor with float32 type on CPU:
+     * - itself, if already with the right characteristics;
+     * - the provided Tensor, overwritten with the copy-casted data.
+     * @param fallback A shared_ptr to Tensor ready to be overwritten if necessary.
+     * The shared_ptr does not need to be initialized. No new memory allocation
+     * will occur if fallback has already been allocated with the right
+     * type/size/device.
+     * @return Reference to either itself or to fallback.
+     */
+    Tensor& refCastNative(std::shared_ptr<Tensor>& fallback) {
+        return refCast(fallback, DataType::Float32).ref(fallback, "cpu");
+    }
+
 private:
     ///\bug not protected against overflow
     std::size_t computeSize() {
diff --git a/src/recipies/FuseBatchNorm.cpp b/src/recipies/FuseBatchNorm.cpp
index ffb4599d8..8b9b4dba3 100644
--- a/src/recipies/FuseBatchNorm.cpp
+++ b/src/recipies/FuseBatchNorm.cpp
@@ -33,10 +33,11 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
     const std::shared_ptr<BatchNorm_Op<2>> batchOp = std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());
     const std::shared_ptr<Conv_Op<2>> convOp = std::static_pointer_cast<Conv_Op<2>>(convNode->getOperator());
 
-    const std::shared_ptr<Tensor> scale  = batchOp->getInput(1);
-    const std::shared_ptr<Tensor> shift  = batchOp->getInput(2);
-    const std::shared_ptr<Tensor> b_mean = batchOp->getInput(3);
-    const std::shared_ptr<Tensor> b_var  = batchOp->getInput(4);
+    std::shared_ptr<Tensor> scaleBuf, shiftBuf, b_meanBuf, b_varBuf;
+    const Tensor& scale = batchOp->getInput(1)->refCastNative(scaleBuf);
+    const Tensor& shift = batchOp->getInput(2)->refCastNative(shiftBuf);
+    const Tensor& b_mean = batchOp->getInput(3)->refCastNative(b_meanBuf);
+    const Tensor& b_var = batchOp->getInput(4)->refCastNative(b_varBuf);
 
     const float epsilon = batchOp -> getAttr<float>("Epsilon");
     const DimSize_t convNbOutChannels = convOp -> getAttr<DimSize_t>("OutChannels");
@@ -44,10 +45,10 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
     const std::array<DimSize_t, 2> kernelDims = convOp -> getAttr<std::array<DimSize_t, 2>>("KernelDims");
 
 
-    assert(scale->size()  == convNbOutChannels);
-    assert(shift->size()  == convNbOutChannels);
-    assert(b_mean->size() == convNbOutChannels);
-    assert(b_var->size()  == convNbOutChannels);
+    assert(scale.size()  == convNbOutChannels);
+    assert(shift.size()  == convNbOutChannels);
+    assert(b_mean.size() == convNbOutChannels);
+    assert(b_var.size()  == convNbOutChannels);
     assert(epsilon > 0.0);
     // TODO : no no_bias attribute ?
 
@@ -56,9 +57,8 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
     unsigned int count = 0;
 
     for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
-        // TODO: get() assumed dataType is float...
-        if (b_var->get<float>(outChId) > 1.0e-12) {
-            meanVariance += b_var->get<float>(outChId);
+        if (b_var.get<float>(outChId) > 1.0e-12) {
+            meanVariance += b_var.get<float>(outChId);
             ++count;
         }
         else {
@@ -71,39 +71,43 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, std::shared_ptr
         printf("Warning: variance < 1e-12 for all outputs! Is the network correctly trained?\n");
     }
 
-    std::shared_ptr<Tensor> weight = convOp -> getInput(1);
-    std::shared_ptr<Tensor> bias = convOp -> getInput(2);
+    std::shared_ptr<Tensor> weightBuf, biasBuf;
+    Tensor& weight = convOp->getInput(1)->refCastNative(weightBuf);
+    Tensor& bias = convOp->getInput(2)->refCastNative(biasBuf);
 
     for (std::size_t outChId = 0; outChId < convNbOutChannels; ++outChId) {
         // Corrected for zero-variance issue:
         // "A Quantization-Friendly Separable Convolution for MobileNets"
         // https://arxiv.org/pdf/1803.08607.pdf
         // to help post-training quantization
-        const float factor = scale->get<float>(outChId)
-            / std::sqrt(epsilon + ((b_var->get<float>(outChId) > 1.0e-12 || count == 0)
-                        ? b_var->get<float>(outChId) : meanVariance));
+        const float factor = scale.get<float>(outChId)
+            / std::sqrt(epsilon + ((b_var.get<float>(outChId) > 1.0e-12 || count == 0)
+                        ? b_var.get<float>(outChId) : meanVariance));
         // Weights adjustments
         for (std::size_t channel = 0; channel < channelsSize; ++channel) {
             // TODO : Suppose kerneldims = 2
             for(std::size_t k0 = 0; k0 < kernelDims[0]; ++ k0){
                 for(std::size_t k1 = 0; k1 < kernelDims[1]; ++ k1){
                     std::vector<DimSize_t> currentIdx = {outChId, channel, k0, k1};
-                    // TODO : suppose weights are float
-                    float weightValue = weight->get<float>(currentIdx);
-                    weight->set<float>(currentIdx, weightValue*factor); // Update check it update Conv weights
+                    float weightValue = weight.get<float>(currentIdx);
+                    weight.set<float>(currentIdx, weightValue*factor); // update the Conv weight value
                 }
             }
         }
 
         // TODO : check if noBias==true is set, then set biasValue to 0
-        float biasValue = bias->get<float>(outChId);
+        float biasValue = bias.get<float>(outChId);
 
-        biasValue = shift->get<float>(outChId) + (biasValue - b_mean->get<float>(outChId)) * factor;
+        biasValue = shift.get<float>(outChId) + (biasValue - b_mean.get<float>(outChId)) * factor;
 
-        bias->set<float>(outChId, biasValue);
+        bias.set<float>(outChId, biasValue);
 
     }
 
+    // Copy values back to the original tensors (actual copy only if needed)
+    convOp->getInput(1)->copyCastFrom(weight);
+    convOp->getInput(2)->copyCastFrom(bias);
+
     GraphView::replace(std::set<std::shared_ptr<Node>>({
         batchnormNode,
         batchnormNode->input(1).first,
-- 
GitLab
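
Usage note (outside the applied diff): a minimal sketch of the intended pattern
for the refCastNative() helper introduced above, mirroring how the FuseBatchNorm
recipe uses it. `someOp` and the input index are illustrative placeholders, not
names from the patch.

    std::shared_ptr<Aidge::Tensor> buf;
    // Get a float32 / "cpu" view of a tensor of any data type or backend.
    // No copy and no allocation happen if the tensor already matches; otherwise
    // buf is (re)used as the temporary holding the copy-casted data.
    Aidge::Tensor& w = someOp->getInput(1)->refCastNative(buf);
    // Read and write through the typed accessors on the float32 view.
    w.set<float>(0, w.get<float>(0) * 2.0f);
    // Copy the values back to the original tensor; this is an actual
    // copy (with cast) only if a temporary buffer was used.
    someOp->getInput(1)->copyCastFrom(w);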