From 28096a02e29b92fda2d326be76606847fa246709 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Tue, 19 Mar 2024 08:31:15 +0000
Subject: [PATCH] Add basic normal filler.

---
 include/aidge/filler/Filler.hpp         | 33 +++++++++++++++++++++----
 python_binding/filler/pybind_Filler.cpp | 25 ++++++++++++++++---
 2 files changed, 50 insertions(+), 8 deletions(-)

diff --git a/include/aidge/filler/Filler.hpp b/include/aidge/filler/Filler.hpp
index a8419e1ca..ae58b78f2 100644
--- a/include/aidge/filler/Filler.hpp
+++ b/include/aidge/filler/Filler.hpp
@@ -13,15 +13,12 @@
 #define AIDGE_CORE_FILLER_H_
 
 #include <memory>
+#include <random> // normal_distribution
 
 #include "aidge/data/Tensor.hpp"
 
 namespace Aidge {
 
-// void heFiller(std::shared_ptr<Tensor> tensor);
-
-// template <typename T>
-// void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue);
 
 template <typename T>
 void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) {
@@ -43,9 +40,35 @@ void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) {
     tensor->copyCastFrom(tensorWithValues);
 }
 
-void normalFiller(std::shared_ptr<Tensor> tensor, float mean, float var);
+template <typename T> // TODO: Keep template or use switch case depending on Tensor datatype ?
+void normalFiller(std::shared_ptr<Tensor> tensor, double mean=0.0, double stdDev= 1.0){
+    AIDGE_ASSERT(tensor->getImpl(),
+                 "Tensor got no implementation, cannot fill it.");
+
+    std::random_device rd;
+    std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator
+
+
+    std::normal_distribution<T> normalDist(mean, stdDev);
+
+    std::shared_ptr<Tensor> cpyTensor;
+    // Create cpy only if tensor not on CPU
+    Tensor& tensorWithValues =
+        tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");
+
+
+
+    // Setting values
+    for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
+        tensorWithValues.set<T>(idx, normalDist(gen));
+    }
+
+    // Copy values back to the original tensors (actual copy only if needed)
+    tensor->copyCastFrom(tensorWithValues);
+};
 // void uniformFiller(std::shared_ptr<Tensor> tensor);
 // void xavierFiller(std::shared_ptr<Tensor> tensor);
+// void heFiller(std::shared_ptr<Tensor> tensor);
 
 }  // namespace Aidge
 
diff --git a/python_binding/filler/pybind_Filler.cpp b/python_binding/filler/pybind_Filler.cpp
index dacf0af40..d735c6831 100644
--- a/python_binding/filler/pybind_Filler.cpp
+++ b/python_binding/filler/pybind_Filler.cpp
@@ -52,9 +52,28 @@ void init_Filler(py::module &m) {
                           tensor, value.cast<std::uint16_t>());
                       break;
                   default:
-                      AIDGE_THROW_OR_ABORT(py::value_error,
-                                           "Data type is not supported.");
+                      AIDGE_THROW_OR_ABORT(
+                          py::value_error,
+                          "Data type is not supported for Constant filler.");
               }
-          });
+          })
+        .def("normal_filler",
+             [](std::shared_ptr<Tensor> tensor, double mean,
+                double stdDev) -> void {
+                 switch (tensor->dataType()) {
+                     case DataType::Float64:
+                         normalFiller<double>(tensor, mean, stdDev);
+                         break;
+                     case DataType::Float32:
+                         normalFiller<float>(tensor, mean, stdDev);
+                         break;
+                     default:
+                         AIDGE_THROW_OR_ABORT(
+                             py::value_error,
+                             "Data type is not supported for Normal filler.");
+                 }
+             }, py::arg("tensor"), py::arg("mean")=0.0, py::arg("stdDev")=1.0)
+
+        ;
 }
 }  // namespace Aidge
-- 
GitLab