From d08841260fd6daecf86f1cf0bf0c02ee13a3fce9 Mon Sep 17 00:00:00 2001
From: IKucher <Inna.KUCHER@cea.fr>
Date: Wed, 15 Nov 2023 09:05:48 +0000
Subject: [PATCH] fixing the scaling for PTQ using round and saturation

---
 .../operator/ScalingImpl_forward_kernels.hpp  | 59 +++++++++++++++++++
 1 file changed, 59 insertions(+)

diff --git a/include/aidge/backend/cpu/operator/ScalingImpl_forward_kernels.hpp b/include/aidge/backend/cpu/operator/ScalingImpl_forward_kernels.hpp
index 8fe13bce..7d31dd5a 100644
--- a/include/aidge/backend/cpu/operator/ScalingImpl_forward_kernels.hpp
+++ b/include/aidge/backend/cpu/operator/ScalingImpl_forward_kernels.hpp
@@ -16,7 +16,60 @@
 
 #include "aidge/backend/cpu/operator/ScalingImpl.hpp"
 
+//TODO : improve propagate, n2d2 :
+/*
+template<typename T>
+void N2D2::floatingPointScaling_propagate(const Tensor<T>& input, Tensor<T>& output,
+                                          std::size_t batchSize, std::size_t nbChannels,
+                                          std::size_t height, std::size_t width,
+                                          bool isClipped,
+                                          const std::vector<Float_T>& clippingFactorPerChannel,
+                                          const std::vector<Float_T>& scalingFactorPerChannel,
+                                          std::size_t quantizedNbBits, bool isOutputUnsigned)
+{
+    std::size_t index = 0;
+    for (std::size_t batch = 0; batch < batchSize; batch++) {
+        for(std::size_t ch = 0; ch < nbChannels; ch++) {
+            for(std::size_t y = 0; y < height; y++) {
+                for(std::size_t x = 0; x < width; x++) {
+
+                    T res = isClipped ? Clip(input(index), clippingFactorPerChannel[ch])
+                                    : input(index);
+                    res = Scale(res, scalingFactorPerChannel[ch]);
+
+                    if(quantizedNbBits > 0) {
+                        res = saturate(std::round(res), quantizedNbBits, isOutputUnsigned);
+                    }
+                    output(index) = (T) res;
+                    index++;
+                }
+            }
+        }
+    }
+}
+*/
+
+
 namespace Aidge {
+
+template <class O>
+const O& clamp(const O& x, const O& min, const O& max)
+{
+    return (x < min) ? min : (x > max) ? max : x;
+}
+
+template<class O>
+O saturate(O value, std::size_t quantizedNbBits, bool isOutputUnsigned) {
+    assert(quantizedNbBits > 0);
+    
+    const O min = isOutputUnsigned?0:
+                                  -(1ll << (quantizedNbBits - 1ll));
+    const O max = isOutputUnsigned?(1ll << quantizedNbBits) - 1ll:
+                                   (1ll << (quantizedNbBits - 1ll)) - 1ll;
+
+    return clamp(value, min, max);
+}
+
 template <class I, class O>
 void ScalingImpl_cpu_forward_kernel(const Scaling_Op::Attrs& attrs,
                                      std::size_t inputLenght,
@@ -26,9 +79,15 @@ void ScalingImpl_cpu_forward_kernel(const Scaling_Op::Attrs& attrs,
     const I* input = static_cast<const I*>(input_);
     O* output = static_cast<O*>(output_);
     const I& scalingFactor = static_cast<const I&>(std::get<0>(attrs));
+    std::size_t quantizedNbBits = static_cast<std::size_t>(std::get<1>(attrs));
+    bool isOutputUnsigned = static_cast<bool>(std::get<2>(attrs));
 
     for (std::size_t i = 0; i < inputLenght; ++i) {
         output[i] = input[i] * scalingFactor;
+
+        if(quantizedNbBits > 0) {
+                output[i] = saturate(std::round(output[i]), quantizedNbBits, isOutputUnsigned);
+        }
     }
 }
 
-- 
GitLab