From b3667685c0c56c35f113a5117f0b03bb616d5a0a Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 21 Mar 2024 16:00:53 +0000 Subject: [PATCH] Add a Random Generator handler class and use it for fillers. --- include/aidge/filler/Filler.hpp | 2 -- include/aidge/utils/Random.hpp | 50 ++++++++++++++++++++------ python_binding/pybind_core.cpp | 3 ++ python_binding/utils/pybind_Random.cpp | 24 +++++++++++++ src/data/DataProvider.cpp | 8 ++--- src/filler/HeFiller.cpp | 6 ++-- src/filler/NormalFiller.cpp | 5 ++- src/filler/UniformFiller.cpp | 6 ++-- src/filler/XavierFiller.cpp | 20 +++++------ src/utils/Random.cpp | 22 ++++++++++++ 10 files changed, 108 insertions(+), 38 deletions(-) create mode 100644 python_binding/utils/pybind_Random.cpp create mode 100644 src/utils/Random.cpp diff --git a/include/aidge/filler/Filler.hpp b/include/aidge/filler/Filler.hpp index c7b12a35c..51d01d87f 100644 --- a/include/aidge/filler/Filler.hpp +++ b/include/aidge/filler/Filler.hpp @@ -41,12 +41,10 @@ enum VarianceNorm { FanIn, Average, FanOut }; template <typename T> void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue); -// TODO: Keep template or use switch case depending on Tensor datatype ? template <typename T> void normalFiller(std::shared_ptr<Tensor> tensor, double mean = 0.0, double stdDev = 1.0); -// TODO: Keep template or use switch case depending on Tensor datatype ? 
template <typename T> void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max); diff --git a/include/aidge/utils/Random.hpp b/include/aidge/utils/Random.hpp index 704609c0c..73cbd1453 100644 --- a/include/aidge/utils/Random.hpp +++ b/include/aidge/utils/Random.hpp @@ -9,23 +9,53 @@ * ********************************************************************************/ - #ifndef AIDGE_RANDOM_H_ #define AIDGE_RANDOM_H_ - #include <algorithm> -#include <vector> #include <random> +#include <vector> +namespace Aidge { namespace Random { - void randShuffle(std::vector<unsigned int>& vec) { - std::random_device rd; - std::mt19937 g(rd()); - std::shuffle(vec.begin(), vec.end(), g); - } - +/** + * @brief Generator is a class created to handle only one Mersenne Twister + * pseudo-random number generator for the whole Aidge framework. + * + * All of its methods are static. You can set a random seed and access the + * generator. + * By default the generator is seeded from std::random_device and the stored seed is 0 until ``setSeed`` is called. + * + */ +class Generator { + public: + /** + * @brief Set a seed to the pseudo-random number generator. + * + * @param seed Seed value passed to the generator. + */ + static void setSeed(unsigned int seed); + static unsigned int getSeed() { return seed; }; + /** + * @brief Return a Mersenne Twister pseudo-random number generator. + * You can set the seed of this generator using ``setSeed`` method. 
+ * + * @return std::mt19937& + */ + static std::mt19937& get() { return generator; }; + + private: + // Mersenne Twister pseudo-random number generator + static std::mt19937 generator; + static unsigned int seed; +}; + +inline void randShuffle(std::vector<unsigned int>& vec) { + std::shuffle(vec.begin(), vec.end(), Aidge::Random::Generator::get()); } -#endif //AIDGE_RANDOM_H_ \ No newline at end of file +} // namespace Random +} // namespace Aidge + +#endif // AIDGE_RANDOM_H_ diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 80f0cc44f..5ffa8f6b4 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -16,6 +16,7 @@ namespace py = pybind11; namespace Aidge { +void init_Random(py::module&); void init_Data(py::module&); void init_Database(py::module&); void init_DataProvider(py::module&); @@ -73,6 +74,8 @@ void init_TensorUtils(py::module&); void init_Filler(py::module&); void init_Aidge(py::module& m) { + init_Random(m); + init_Data(m); init_Database(m); init_DataProvider(m); diff --git a/python_binding/utils/pybind_Random.cpp b/python_binding/utils/pybind_Random.cpp new file mode 100644 index 000000000..a1956d2d1 --- /dev/null +++ b/python_binding/utils/pybind_Random.cpp @@ -0,0 +1,24 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include "aidge/utils/Random.hpp" + +namespace py = pybind11; + +namespace Aidge { + +void init_Random(py::module &m) { + auto mRand = m.def_submodule("random", "Random module."); + py::class_<Random::Generator>(mRand, "Generator") + .def_static("set_seed", Random::Generator::setSeed); +} +} // namespace Aidge diff --git a/src/data/DataProvider.cpp b/src/data/DataProvider.cpp index 7783ed86c..5c3d1d7ef 100644 --- a/src/data/DataProvider.cpp +++ b/src/data/DataProvider.cpp @@ -41,8 +41,8 @@ Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::si } // Compute the number of bacthes depending on mDropLast boolean - mNbBatch = (mDropLast) ? - static_cast<std::size_t>(std::floor(mNbItems / mBatchSize)) : + mNbBatch = (mDropLast) ? + static_cast<std::size_t>(std::floor(mNbItems / mBatchSize)) : static_cast<std::size_t>(std::ceil(mNbItems / mBatchSize)); } @@ -98,7 +98,7 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() con void Aidge::DataProvider::setBatches(){ - + mBatches.clear(); mBatches.resize(mNbItems); std::iota(mBatches.begin(), @@ -106,7 +106,7 @@ void Aidge::DataProvider::setBatches(){ 0U); if (mShuffle){ - Random::randShuffle(mBatches); + Aidge::Random::randShuffle(mBatches); } if (mNbItems % mBatchSize !=0){ // The last batch is not full diff --git a/src/filler/HeFiller.cpp b/src/filler/HeFiller.cpp index e49386b49..74d681f1a 100644 --- a/src/filler/HeFiller.cpp +++ b/src/filler/HeFiller.cpp @@ -13,6 +13,7 @@ #include "aidge/data/Tensor.hpp" #include "aidge/filler/Filler.hpp" +#include "aidge/utils/Random.hpp" template <typename T> void Aidge::heFiller(std::shared_ptr<Aidge::Tensor> tensor, @@ -36,9 +37,6 @@ void Aidge::heFiller(std::shared_ptr<Aidge::Tensor> tensor, ? 
meanNorm / ((fanIn + fanOut) / 2.0) : meanNorm / fanOut); - std::random_device rd; - std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator - std::normal_distribution<T> normalDist(mean, stdDev); std::shared_ptr<Tensor> cpyTensor; @@ -48,7 +46,7 @@ void Aidge::heFiller(std::shared_ptr<Aidge::Tensor> tensor, // Setting values for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { - tensorWithValues.set<T>(idx, scaling*normalDist(gen)); + tensorWithValues.set<T>(idx, scaling*normalDist(Aidge::Random::Generator::get())); } // Copy values back to the original tensors (actual copy only if needed) diff --git a/src/filler/NormalFiller.cpp b/src/filler/NormalFiller.cpp index 0fadbd134..f30b32431 100644 --- a/src/filler/NormalFiller.cpp +++ b/src/filler/NormalFiller.cpp @@ -13,6 +13,7 @@ #include "aidge/data/Tensor.hpp" #include "aidge/filler/Filler.hpp" +#include "aidge/utils/Random.hpp" template <typename T> void Aidge::normalFiller(std::shared_ptr<Aidge::Tensor> tensor, double mean, @@ -20,8 +21,6 @@ void Aidge::normalFiller(std::shared_ptr<Aidge::Tensor> tensor, double mean, AIDGE_ASSERT(tensor->getImpl(), "Tensor got no implementation, cannot fill it."); AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); - std::random_device rd; - std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator std::normal_distribution<T> normalDist(mean, stdDev); @@ -32,7 +31,7 @@ void Aidge::normalFiller(std::shared_ptr<Aidge::Tensor> tensor, double mean, // Setting values for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { - tensorWithValues.set<T>(idx, normalDist(gen)); + tensorWithValues.set<T>(idx, normalDist(Aidge::Random::Generator::get())); } // Copy values back to the original tensors (actual copy only if needed) diff --git a/src/filler/UniformFiller.cpp b/src/filler/UniformFiller.cpp index e45d6f13e..a942f59d7 100644 --- a/src/filler/UniformFiller.cpp +++ b/src/filler/UniformFiller.cpp @@ -13,14 
+13,14 @@ #include "aidge/data/Tensor.hpp" #include "aidge/filler/Filler.hpp" +#include "aidge/utils/Random.hpp" template <typename T> void Aidge::uniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T min, T max) { AIDGE_ASSERT(tensor->getImpl(), "Tensor got no implementation, cannot fill it."); AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); - std::random_device rd; - std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator + std::uniform_real_distribution<T> uniformDist(min, max); @@ -31,7 +31,7 @@ void Aidge::uniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T min, T max) { // Setting values for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { - tensorWithValues.set<T>(idx, uniformDist(gen)); + tensorWithValues.set<T>(idx, uniformDist(Aidge::Random::Generator::get())); } // Copy values back to the original tensors (actual copy only if needed) diff --git a/src/filler/XavierFiller.cpp b/src/filler/XavierFiller.cpp index f1c5d17e8..a1de15971 100644 --- a/src/filler/XavierFiller.cpp +++ b/src/filler/XavierFiller.cpp @@ -13,10 +13,11 @@ #include "aidge/data/Tensor.hpp" #include "aidge/filler/Filler.hpp" +#include "aidge/utils/Random.hpp" template <typename T> -void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling, - Aidge::VarianceNorm varianceNorm) { +void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor, + T scaling, Aidge::VarianceNorm varianceNorm) { AIDGE_ASSERT(tensor->getImpl(), "Tensor got no implementation, cannot fill it."); AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); @@ -30,9 +31,6 @@ void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling : fanOut); const T scale(std::sqrt(3.0 / n)); - std::random_device rd; - std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator - std::uniform_real_distribution<T> uniformDist(-scale, scale); std::shared_ptr<Aidge::Tensor> cpyTensor; @@ 
-41,8 +39,8 @@ void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu"); // Setting values for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { - T value = scaling * uniformDist(gen); - tensorWithValues.set<T>(idx, value); + tensorWithValues.set<T>( + idx, scaling * uniformDist(Aidge::Random::Generator::get())); } // Copy values back to the original tensors (actual copy only if needed) @@ -50,7 +48,7 @@ void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling } template <typename T> void Aidge::xavierNormalFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling, - Aidge::VarianceNorm varianceNorm) { + Aidge::VarianceNorm varianceNorm) { AIDGE_ASSERT(tensor->getImpl(), "Tensor got no implementation, cannot fill it."); AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); @@ -64,9 +62,6 @@ void Aidge::xavierNormalFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling, : fanOut); const double stdDev(std::sqrt(1.0 / n)); - std::random_device rd; - std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator - std::normal_distribution<T> normalDist(0.0, stdDev); std::shared_ptr<Aidge::Tensor> cpyTensor; @@ -76,7 +71,8 @@ void Aidge::xavierNormalFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling, // Setting values for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { - tensorWithValues.set<T>(idx, scaling*normalDist(gen)); + tensorWithValues.set<T>( + idx, scaling * normalDist(Aidge::Random::Generator::get())); } // Copy values back to the original tensors (actual copy only if needed) diff --git a/src/utils/Random.cpp b/src/utils/Random.cpp new file mode 100644 index 000000000..e3716f4fc --- /dev/null +++ b/src/utils/Random.cpp @@ -0,0 +1,22 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying 
materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/utils/Random.hpp" + +#include <random> // std::mt19937, std::random_device + +std::mt19937 Aidge::Random::Generator::generator{std::random_device{}()}; +unsigned int Aidge::Random::Generator::seed = 0; + +void Aidge::Random::Generator::setSeed(unsigned int seed) { + Aidge::Random::Generator::seed = seed; // qualified on purpose: the parameter shadows the static member, a bare `seed = seed` is a no-op + generator.seed(seed); +} -- GitLab