diff --git a/include/aidge/filler/Filler.hpp b/include/aidge/filler/Filler.hpp index a8419e1caf3f8639e067892501b8b4c8996952ae..ae58b78f203da127dc2b4a8cca4e95d8c60fa19f 100644 --- a/include/aidge/filler/Filler.hpp +++ b/include/aidge/filler/Filler.hpp @@ -13,15 +13,12 @@ #define AIDGE_CORE_FILLER_H_ #include <memory> +#include <random> // normal_distribution #include "aidge/data/Tensor.hpp" namespace Aidge { -// void heFiller(std::shared_ptr<Tensor> tensor); - -// template <typename T> -// void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue); template <typename T> void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) { @@ -43,9 +40,35 @@ void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) { tensor->copyCastFrom(tensorWithValues); } -void normalFiller(std::shared_ptr<Tensor> tensor, float mean, float var); +template <typename T> // TODO: Keep template or use switch case depending on Tensor datatype ? +void normalFiller(std::shared_ptr<Tensor> tensor, double mean=0.0, double stdDev= 1.0){ + AIDGE_ASSERT(tensor->getImpl(), + "Tensor got no implementation, cannot fill it."); + + std::random_device rd; + std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator + + + std::normal_distribution<T> normalDist(mean, stdDev); + + std::shared_ptr<Tensor> cpyTensor; + // Create cpy only if tensor not on CPU + Tensor& tensorWithValues = + tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu"); + + + + // Setting values + for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { + tensorWithValues.set<T>(idx, normalDist(gen)); + } + + // Copy values back to the original tensors (actual copy only if needed) + tensor->copyCastFrom(tensorWithValues); +}; // void uniformFiller(std::shared_ptr<Tensor> tensor); // void xavierFiller(std::shared_ptr<Tensor> tensor); +// void heFiller(std::shared_ptr<Tensor> tensor); } // namespace Aidge diff --git a/python_binding/filler/pybind_Filler.cpp b/python_binding/filler/pybind_Filler.cpp index dacf0af40012fdb64655e80c6949e688d3a5d131..d735c68317c8a1d4a433f5dfdca5a83afc72b247 100644 --- a/python_binding/filler/pybind_Filler.cpp +++ b/python_binding/filler/pybind_Filler.cpp @@ -52,9 +52,28 @@ void init_Filler(py::module &m) { tensor, value.cast<std::uint16_t>()); break; default: - AIDGE_THROW_OR_ABORT(py::value_error, - "Data type is not supported."); + AIDGE_THROW_OR_ABORT( + py::value_error, + "Data type is not supported for Constant filler."); } - }); + }) + .def("normal_filler", + [](std::shared_ptr<Tensor> tensor, double mean, + double stdDev) -> void { + switch (tensor->dataType()) { + case DataType::Float64: + normalFiller<double>(tensor, mean, stdDev); + break; + case DataType::Float32: + normalFiller<float>(tensor, mean, stdDev); + break; + default: + AIDGE_THROW_OR_ABORT( + py::value_error, + "Data type is not supported for Normal filler."); + } + }, py::arg("tensor"), py::arg("mean")=0.0, py::arg("stdDev")=1.0) + + ; } } // namespace Aidge