From 72f494313f6948b1f51f22dd951eb7fa4a96afa4 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 5 Nov 2024 14:45:28 +0000 Subject: [PATCH] [add] int32 and int64 types to const_filler and uniform_filler --- python_binding/filler/pybind_Filler.cpp | 50 ++++++++++++++++++------- src/filler/ConstantFiller.cpp | 3 +- src/filler/UniformFiller.cpp | 17 +++++++-- 3 files changed, 52 insertions(+), 18 deletions(-) diff --git a/python_binding/filler/pybind_Filler.cpp b/python_binding/filler/pybind_Filler.cpp index a85c0d6cd..dbf9a4845 100644 --- a/python_binding/filler/pybind_Filler.cpp +++ b/python_binding/filler/pybind_Filler.cpp @@ -30,11 +30,17 @@ void init_Filler(py::module &m) { [](std::shared_ptr<Tensor> tensor, py::object value) -> void { switch (tensor->dataType()) { case DataType::Float64: - constantFiller<double>(tensor, value.cast<double>()); + constantFiller<cpptype_t<DataType::Float64>>(tensor, value.cast<cpptype_t<DataType::Float64>>()); break; case DataType::Float32: - constantFiller<float>(tensor, value.cast<float>()); + constantFiller<cpptype_t<DataType::Float32>>(tensor, value.cast<cpptype_t<DataType::Float32>>()); break; + case DataType::Int64: + constantFiller<cpptype_t<DataType::Int64>>(tensor, value.cast<cpptype_t<DataType::Int64>>()); + break; + case DataType::Int32: + constantFiller<cpptype_t<DataType::Int32>>(tensor, value.cast<cpptype_t<DataType::Int32>>()); + break; default: AIDGE_THROW_OR_ABORT( py::value_error, @@ -44,14 +50,14 @@ void init_Filler(py::module &m) { py::arg("tensor"), py::arg("value")) .def( "normal_filler", - [](std::shared_ptr<Tensor> tensor, double mean, - double stdDev) -> void { + [](std::shared_ptr<Tensor> tensor, py::object mean, + py::object stdDev) -> void { switch (tensor->dataType()) { case DataType::Float64: - normalFiller<double>(tensor, mean, stdDev); + normalFiller<cpptype_t<DataType::Float64>>(tensor, mean.cast<cpptype_t<DataType::Float64>>(), stdDev.cast<cpptype_t<DataType::Float64>>()); break; case DataType::Float32: - normalFiller<float>(tensor, mean, stdDev); + normalFiller<cpptype_t<DataType::Float64>>(tensor, mean.cast<cpptype_t<DataType::Float32>>(), stdDev.cast<cpptype_t<DataType::Float32>>()); break; default: AIDGE_THROW_OR_ABORT( @@ -60,23 +66,39 @@ void init_Filler(py::module &m) { } }, py::arg("tensor"), py::arg("mean") = 0.0, py::arg("stdDev") = 1.0) - .def( - "uniform_filler", - [](std::shared_ptr<Tensor> tensor, double min, double max) -> void { + .def("uniform_filler", [] (std::shared_ptr<Tensor> tensor, py::object min, py::object max) -> void { + if (py::isinstance<py::int_>(min) && py::isinstance<py::int_>(max)) { switch (tensor->dataType()) { - case DataType::Float64: - uniformFiller<double>(tensor, min, max); + case DataType::Int32: + uniformFiller<std::int32_t>(tensor, min.cast<std::int32_t>(), max.cast<std::int32_t>()); + break; + case DataType::Int64: + uniformFiller<std::int64_t>(tensor, min.cast<std::int64_t>(), max.cast<std::int64_t>()); + break; + default: + AIDGE_THROW_OR_ABORT( + py::value_error, + "Data type is not supported for Uniform filler."); break; + } + } else if (py::isinstance<py::float_>(min) && py::isinstance<py::float_>(max)) { + switch (tensor->dataType()) { case DataType::Float32: - uniformFiller<float>(tensor, min, max); + uniformFiller<float>(tensor, min.cast<float>(), max.cast<float>()); + break; + case DataType::Float64: + uniformFiller<double>(tensor, min.cast<double>(), max.cast<double>()); break; default: AIDGE_THROW_OR_ABORT( py::value_error, "Data type is not supported for Uniform filler."); + break; } - }, - py::arg("tensor"), py::arg("min"), py::arg("max")) + } else { + AIDGE_THROW_OR_ABORT(py::value_error,"Input must be either an int or a float."); + } + }, py::arg("tensor"), py::arg("min"), py::arg("max")) .def( "xavier_uniform_filler", [](std::shared_ptr<Tensor> tensor, py::object scaling, diff --git a/src/filler/ConstantFiller.cpp b/src/filler/ConstantFiller.cpp index 1e992f4a1..b2118866f 100644 --- a/src/filler/ConstantFiller.cpp +++ b/src/filler/ConstantFiller.cpp @@ -39,6 +39,7 @@ void Aidge::constantFiller(std::shared_ptr<Aidge::Tensor> tensor, T constantValu tensor->copyCastFrom(tensorWithValues); } - +template void Aidge::constantFiller<std::int32_t>(std::shared_ptr<Aidge::Tensor>, std::int32_t); +template void Aidge::constantFiller<std::int64_t>(std::shared_ptr<Aidge::Tensor>, std::int64_t); template void Aidge::constantFiller<float>(std::shared_ptr<Aidge::Tensor>, float); template void Aidge::constantFiller<double>(std::shared_ptr<Aidge::Tensor>, double); diff --git a/src/filler/UniformFiller.cpp b/src/filler/UniformFiller.cpp index a942f59d7..1951fcc62 100644 --- a/src/filler/UniformFiller.cpp +++ b/src/filler/UniformFiller.cpp @@ -8,8 +8,9 @@ * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ +#include <cstdint> // std::int32_t #include <memory> -#include <random> // normal_distribution, uniform_real_distribution +#include <random> // normal_distribution, uniform_real_distribution #include "aidge/data/Tensor.hpp" #include "aidge/filler/Filler.hpp" @@ -19,10 +20,16 @@ template <typename T> void Aidge::uniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T min, T max) { AIDGE_ASSERT(tensor->getImpl(), "Tensor got no implementation, cannot fill it."); - AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); + AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type {} and {}",NativeType<T>::type, tensor->dataType()); - std::uniform_real_distribution<T> uniformDist(min, max); + using DistType = typename std::conditional< + std::is_integral<T>::value, + std::uniform_int_distribution<T>, + std::uniform_real_distribution<T> + >::type; + + DistType uniformDist(min, max); std::shared_ptr<Aidge::Tensor> cpyTensor; // Create cpy only if tensor not on CPU @@ -42,3 +49,7 @@ template void Aidge::uniformFiller<float>(std::shared_ptr<Aidge::Tensor>, float, float); template void Aidge::uniformFiller<double>(std::shared_ptr<Aidge::Tensor>, double, double); +template void Aidge::uniformFiller<std::int32_t>(std::shared_ptr<Aidge::Tensor>, + std::int32_t, std::int32_t); +template void Aidge::uniformFiller<std::int64_t>(std::shared_ptr<Aidge::Tensor>, + std::int64_t, std::int64_t); \ No newline at end of file -- GitLab