diff --git a/include/aidge/filler/Filler.hpp b/include/aidge/filler/Filler.hpp index ae58b78f203da127dc2b4a8cca4e95d8c60fa19f..e6de2f930e7d06be47244a3f9f8cd73c3aa18b51 100644 --- a/include/aidge/filler/Filler.hpp +++ b/include/aidge/filler/Filler.hpp @@ -13,13 +13,12 @@ #define AIDGE_CORE_FILLER_H_ #include <memory> -#include <random> // normal_distribution +#include <random> // normal_distribution, uniform_real_distribution #include "aidge/data/Tensor.hpp" namespace Aidge { - template <typename T> void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) { AIDGE_ASSERT(tensor->getImpl(), @@ -39,15 +38,16 @@ void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) { // Copy values back to the original tensors (actual copy only if needed) tensor->copyCastFrom(tensorWithValues); } - -template <typename T> // TODO: Keep template or use switch case depending on Tensor datatype ? -void normalFiller(std::shared_ptr<Tensor> tensor, double mean=0.0, double stdDev= 1.0){ +// TODO: Keep template or use switch case depending on Tensor datatype ? +template <typename T> +void normalFiller(std::shared_ptr<Tensor> tensor, + double mean = 0.0, + double stdDev = 1.0) { AIDGE_ASSERT(tensor->getImpl(), "Tensor got no implementation, cannot fill it."); - + AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); std::random_device rd; - std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator - + std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator std::normal_distribution<T> normalDist(mean, stdDev); @@ -56,17 +56,39 @@ void normalFiller(std::shared_ptr<Tensor> tensor, double mean=0.0, double stdDev Tensor& tensorWithValues = tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu"); + // Setting values + for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { + tensorWithValues.set<T>(idx, normalDist(gen)); + } + // Copy values back to the original tensors (actual copy only if needed) + tensor->copyCastFrom(tensorWithValues); +}; + +// TODO: Keep template or use switch case depending on Tensor datatype ? +template<typename T> +void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) { + AIDGE_ASSERT(tensor->getImpl(), + "Tensor got no implementation, cannot fill it."); + AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); + std::random_device rd; + std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator + + std::uniform_distribution<T> uniformDist(min, max); + + std::shared_ptr<Tensor> cpyTensor; + // Create cpy only if tensor not on CPU + Tensor& tensorWithValues = + tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu"); // Setting values for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { - tensorWithValues.set<T>(idx, normalDist(gen)); + tensorWithValues.set<T>(idx, uniformDist(gen)); } // Copy values back to the original tensors (actual copy only if needed) tensor->copyCastFrom(tensorWithValues); }; -// void uniformFiller(std::shared_ptr<Tensor> tensor); // void xavierFiller(std::shared_ptr<Tensor> tensor); // void heFiller(std::shared_ptr<Tensor> tensor); diff --git a/python_binding/filler/pybind_Filler.cpp b/python_binding/filler/pybind_Filler.cpp index d735c68317c8a1d4a433f5dfdca5a83afc72b247..6938ef2b9c451329607fc2c7630dbee29a88f872 100644 --- a/python_binding/filler/pybind_Filler.cpp +++ b/python_binding/filler/pybind_Filler.cpp @@ -56,7 +56,7 @@ void init_Filler(py::module &m) { py::value_error, "Data type is not supported for Constant filler."); } - }) + }, py::arg("tensor"), py::arg("value")) .def("normal_filler", [](std::shared_ptr<Tensor> tensor, double mean, double stdDev) -> void { @@ -73,7 +73,22 @@ void init_Filler(py::module &m) { "Data type is not supported for Normal filler."); } }, py::arg("tensor"), py::arg("mean")=0.0, py::arg("stdDev")=1.0) - + .def("uniform_filler", + [](std::shared_ptr<Tensor> tensor, double min, + double max) -> void { + switch (tensor->dataType()) { + case DataType::Float64: + uniformFiller<double>(tensor, min, max); + break; + case DataType::Float32: + uniformFiller<float>(tensor, min, max); + break; + default: + AIDGE_THROW_OR_ABORT( + py::value_error, + "Data type is not supported for Uniform filler."); + } + }, py::arg("tensor"), py::arg("min"), py::arg("max")) ; } } // namespace Aidge