From 72f494313f6948b1f51f22dd951eb7fa4a96afa4 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Tue, 5 Nov 2024 14:45:28 +0000
Subject: [PATCH] [add] int32 and int64 types to const_filler and
 uniform_filler

---
 python_binding/filler/pybind_Filler.cpp | 50 ++++++++++++++++++-------
 src/filler/ConstantFiller.cpp           |  3 +-
 src/filler/UniformFiller.cpp            | 17 +++++++--
 3 files changed, 52 insertions(+), 18 deletions(-)

diff --git a/python_binding/filler/pybind_Filler.cpp b/python_binding/filler/pybind_Filler.cpp
index a85c0d6cd..dbf9a4845 100644
--- a/python_binding/filler/pybind_Filler.cpp
+++ b/python_binding/filler/pybind_Filler.cpp
@@ -30,11 +30,17 @@ void init_Filler(py::module &m) {
          [](std::shared_ptr<Tensor> tensor, py::object value) -> void {
              switch (tensor->dataType()) {
                  case DataType::Float64:
-                     constantFiller<double>(tensor, value.cast<double>());
+                     constantFiller<cpptype_t<DataType::Float64>>(tensor, value.cast<cpptype_t<DataType::Float64>>());
                      break;
                  case DataType::Float32:
-                     constantFiller<float>(tensor, value.cast<float>());
+                     constantFiller<cpptype_t<DataType::Float32>>(tensor, value.cast<cpptype_t<DataType::Float32>>());
                      break;
+                case DataType::Int64:
+                    constantFiller<cpptype_t<DataType::Int64>>(tensor, value.cast<cpptype_t<DataType::Int64>>());
+                    break;
+                case DataType::Int32:
+                    constantFiller<cpptype_t<DataType::Int32>>(tensor, value.cast<cpptype_t<DataType::Int32>>());
+                    break;
                  default:
                      AIDGE_THROW_OR_ABORT(
                          py::value_error,
@@ -44,14 +50,14 @@ void init_Filler(py::module &m) {
          py::arg("tensor"), py::arg("value"))
         .def(
             "normal_filler",
-            [](std::shared_ptr<Tensor> tensor, double mean,
-               double stdDev) -> void {
+            [](std::shared_ptr<Tensor> tensor, py::object mean,
+               py::object stdDev) -> void {
                 switch (tensor->dataType()) {
                     case DataType::Float64:
-                        normalFiller<double>(tensor, mean, stdDev);
+                        normalFiller<cpptype_t<DataType::Float64>>(tensor, mean.cast<cpptype_t<DataType::Float64>>(), stdDev.cast<cpptype_t<DataType::Float64>>());
                         break;
                     case DataType::Float32:
-                        normalFiller<float>(tensor, mean, stdDev);
+                        normalFiller<cpptype_t<DataType::Float32>>(tensor, mean.cast<cpptype_t<DataType::Float32>>(), stdDev.cast<cpptype_t<DataType::Float32>>());
                         break;
                     default:
                         AIDGE_THROW_OR_ABORT(
@@ -60,23 +66,39 @@ void init_Filler(py::module &m) {
                 }
             },
             py::arg("tensor"), py::arg("mean") = 0.0, py::arg("stdDev") = 1.0)
-        .def(
-            "uniform_filler",
-            [](std::shared_ptr<Tensor> tensor, double min, double max) -> void {
+        .def("uniform_filler", [] (std::shared_ptr<Tensor> tensor, py::object min, py::object max) -> void {
+            if (py::isinstance<py::int_>(min) && py::isinstance<py::int_>(max)) {
                 switch (tensor->dataType()) {
-                    case DataType::Float64:
-                        uniformFiller<double>(tensor, min, max);
+                    case DataType::Int32:
+                        uniformFiller<std::int32_t>(tensor, min.cast<std::int32_t>(), max.cast<std::int32_t>());
+                        break;
+                    case DataType::Int64:
+                        uniformFiller<std::int64_t>(tensor, min.cast<std::int64_t>(), max.cast<std::int64_t>());
+                        break;
+                    default:
+                        AIDGE_THROW_OR_ABORT(
+                            py::value_error,
+                            "Data type is not supported for Uniform filler.");
                         break;
+                }
+            } else if (py::isinstance<py::float_>(min) && py::isinstance<py::float_>(max)) {
+                switch (tensor->dataType()) {
                     case DataType::Float32:
-                        uniformFiller<float>(tensor, min, max);
+                        uniformFiller<float>(tensor, min.cast<float>(), max.cast<float>());
+                        break;
+                    case DataType::Float64:
+                        uniformFiller<double>(tensor, min.cast<double>(), max.cast<double>());
                         break;
                     default:
                         AIDGE_THROW_OR_ABORT(
                             py::value_error,
                             "Data type is not supported for Uniform filler.");
+                        break;
                 }
-            },
-            py::arg("tensor"), py::arg("min"), py::arg("max"))
+            } else {
+                AIDGE_THROW_OR_ABORT(py::value_error, "Input must be either an int or a float.");
+            }
+            }, py::arg("tensor"), py::arg("min"), py::arg("max"))
         .def(
             "xavier_uniform_filler",
             [](std::shared_ptr<Tensor> tensor, py::object scaling,
diff --git a/src/filler/ConstantFiller.cpp b/src/filler/ConstantFiller.cpp
index 1e992f4a1..b2118866f 100644
--- a/src/filler/ConstantFiller.cpp
+++ b/src/filler/ConstantFiller.cpp
@@ -39,6 +39,7 @@ void Aidge::constantFiller(std::shared_ptr<Aidge::Tensor> tensor, T constantValu
     tensor->copyCastFrom(tensorWithValues);
 }
 
-
+template void Aidge::constantFiller<std::int32_t>(std::shared_ptr<Aidge::Tensor>, std::int32_t);
+template void Aidge::constantFiller<std::int64_t>(std::shared_ptr<Aidge::Tensor>, std::int64_t);
 template void Aidge::constantFiller<float>(std::shared_ptr<Aidge::Tensor>, float);
 template void Aidge::constantFiller<double>(std::shared_ptr<Aidge::Tensor>, double);
diff --git a/src/filler/UniformFiller.cpp b/src/filler/UniformFiller.cpp
index a942f59d7..1951fcc62 100644
--- a/src/filler/UniformFiller.cpp
+++ b/src/filler/UniformFiller.cpp
@@ -8,8 +8,9 @@
  * SPDX-License-Identifier: EPL-2.0
  *
  ********************************************************************************/
+#include <cstdint>  // std::int32_t
 #include <memory>
-#include <random>  // normal_distribution, uniform_real_distribution
+#include <random>   // normal_distribution, uniform_real_distribution
 
 #include "aidge/data/Tensor.hpp"
 #include "aidge/filler/Filler.hpp"
@@ -19,10 +20,16 @@ template <typename T>
 void Aidge::uniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T min, T max) {
     AIDGE_ASSERT(tensor->getImpl(),
                  "Tensor got no implementation, cannot fill it.");
-    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
+    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type {} and {}", NativeType<T>::type, tensor->dataType());
 
 
-    std::uniform_real_distribution<T> uniformDist(min, max);
+    using DistType = typename std::conditional<
+        std::is_integral<T>::value,
+        std::uniform_int_distribution<T>,
+        std::uniform_real_distribution<T>
+    >::type;
+
+    DistType uniformDist(min, max);
 
     std::shared_ptr<Aidge::Tensor> cpyTensor;
     // Create cpy only if tensor not on CPU
@@ -42,3 +49,7 @@ template void Aidge::uniformFiller<float>(std::shared_ptr<Aidge::Tensor>, float,
                                           float);
 template void Aidge::uniformFiller<double>(std::shared_ptr<Aidge::Tensor>,
                                            double, double);
+template void Aidge::uniformFiller<std::int32_t>(std::shared_ptr<Aidge::Tensor>,
+                                                 std::int32_t, std::int32_t);
+template void Aidge::uniformFiller<std::int64_t>(std::shared_ptr<Aidge::Tensor>,
+                                                 std::int64_t, std::int64_t);
\ No newline at end of file
-- 
GitLab