Skip to content
Snippets Groups Projects
Commit b3667685 authored by Cyril Moineau's avatar Cyril Moineau Committed by Maxence Naud
Browse files

Add a Random Generator handler class and use it for fillers.

parent 53b0e655
No related branches found
No related tags found
No related merge requests found
......@@ -41,12 +41,10 @@ enum VarianceNorm { FanIn, Average, FanOut };
template <typename T>
void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue);
// TODO: Keep template or use switch case depending on Tensor datatype ?
template <typename T>
void normalFiller(std::shared_ptr<Tensor> tensor, double mean = 0.0,
double stdDev = 1.0);
// TODO: Keep template or use switch case depending on Tensor datatype ?
template <typename T>
void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max);
......
......@@ -9,23 +9,53 @@
*
********************************************************************************/
#ifndef AIDGE_RANDOM_H_
#define AIDGE_RANDOM_H_
#include <algorithm>
#include <vector>
#include <random>
#include <vector>
namespace Aidge {
namespace Random {
void randShuffle(std::vector<unsigned int>& vec) {
std::random_device rd;
std::mt19937 g(rd());
std::shuffle(vec.begin(), vec.end(), g);
}
/**
* @brief Generator is a class created to handle only one Mersenne Twister
* pseudo-random number generator for the whole Aidge framework.
*
* All of its method are static. You can set a random seed and access the
* generator.
* By default, the random seed is set to 0 but selected randomly.
*
*/
class Generator {
public:
/**
* @brief Set a seed to the pseudo-random number generator.
*
* @return std::mt19937&
*/
static void setSeed(unsigned int seed);
static unsigned int getSeed() { return seed; };
/**
* @brief Return a Mersenne Twister pseudo-random number generator.
* You can set the seed of this generator using ``setSeed`` method.
*
* @return std::mt19937&
*/
static std::mt19937& get() { return generator; };
private:
// Mersenne Twister pseudo-random number generator
static std::mt19937 generator;
static unsigned int seed;
};
inline void randShuffle(std::vector<unsigned int>& vec) {
std::shuffle(vec.begin(), vec.end(), Aidge::Random::Generator::get());
}
#endif //AIDGE_RANDOM_H_
\ No newline at end of file
} // namespace Random
} // namespace Aidge
#endif // AIDGE_RANDOM_H_
......@@ -16,6 +16,7 @@
namespace py = pybind11;
namespace Aidge {
void init_Random(py::module&);
void init_Data(py::module&);
void init_Database(py::module&);
void init_DataProvider(py::module&);
......@@ -73,6 +74,8 @@ void init_TensorUtils(py::module&);
void init_Filler(py::module&);
void init_Aidge(py::module& m) {
init_Random(m);
init_Data(m);
init_Database(m);
init_DataProvider(m);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/utils/Random.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Random(py::module &m) {
auto mRand = m.def_submodule("random", "Random module.");
py::class_<Random::Generator>(mRand, "Generator")
.def_static("set_seed", Random::Generator::setSeed);
}
} // namespace Aidge
......@@ -41,8 +41,8 @@ Aidge::DataProvider::DataProvider(const Aidge::Database& database, const std::si
}
// Compute the number of bacthes depending on mDropLast boolean
mNbBatch = (mDropLast) ?
static_cast<std::size_t>(std::floor(mNbItems / mBatchSize)) :
mNbBatch = (mDropLast) ?
static_cast<std::size_t>(std::floor(mNbItems / mBatchSize)) :
static_cast<std::size_t>(std::ceil(mNbItems / mBatchSize));
}
......@@ -98,7 +98,7 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::DataProvider::readBatch() con
void Aidge::DataProvider::setBatches(){
mBatches.clear();
mBatches.resize(mNbItems);
std::iota(mBatches.begin(),
......@@ -106,7 +106,7 @@ void Aidge::DataProvider::setBatches(){
0U);
if (mShuffle){
Random::randShuffle(mBatches);
Aidge::Random::randShuffle(mBatches);
}
if (mNbItems % mBatchSize !=0){ // The last batch is not full
......
......@@ -13,6 +13,7 @@
#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"
#include "aidge/utils/Random.hpp"
template <typename T>
void Aidge::heFiller(std::shared_ptr<Aidge::Tensor> tensor,
......@@ -36,9 +37,6 @@ void Aidge::heFiller(std::shared_ptr<Aidge::Tensor> tensor,
? meanNorm / ((fanIn + fanOut) / 2.0)
: meanNorm / fanOut);
std::random_device rd;
std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator
std::normal_distribution<T> normalDist(mean, stdDev);
std::shared_ptr<Tensor> cpyTensor;
......@@ -48,7 +46,7 @@ void Aidge::heFiller(std::shared_ptr<Aidge::Tensor> tensor,
// Setting values
for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
tensorWithValues.set<T>(idx, scaling*normalDist(gen));
tensorWithValues.set<T>(idx, scaling*normalDist(Aidge::Random::Generator::get()));
}
// Copy values back to the original tensors (actual copy only if needed)
......
......@@ -13,6 +13,7 @@
#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"
#include "aidge/utils/Random.hpp"
template <typename T>
void Aidge::normalFiller(std::shared_ptr<Aidge::Tensor> tensor, double mean,
......@@ -20,8 +21,6 @@ void Aidge::normalFiller(std::shared_ptr<Aidge::Tensor> tensor, double mean,
AIDGE_ASSERT(tensor->getImpl(),
"Tensor got no implementation, cannot fill it.");
AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
std::random_device rd;
std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator
std::normal_distribution<T> normalDist(mean, stdDev);
......@@ -32,7 +31,7 @@ void Aidge::normalFiller(std::shared_ptr<Aidge::Tensor> tensor, double mean,
// Setting values
for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
tensorWithValues.set<T>(idx, normalDist(gen));
tensorWithValues.set<T>(idx, normalDist(Aidge::Random::Generator::get()));
}
// Copy values back to the original tensors (actual copy only if needed)
......
......@@ -13,14 +13,14 @@
#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"
#include "aidge/utils/Random.hpp"
template <typename T>
void Aidge::uniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T min, T max) {
AIDGE_ASSERT(tensor->getImpl(),
"Tensor got no implementation, cannot fill it.");
AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
std::random_device rd;
std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator
std::uniform_real_distribution<T> uniformDist(min, max);
......@@ -31,7 +31,7 @@ void Aidge::uniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T min, T max) {
// Setting values
for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
tensorWithValues.set<T>(idx, uniformDist(gen));
tensorWithValues.set<T>(idx, uniformDist(Aidge::Random::Generator::get()));
}
// Copy values back to the original tensors (actual copy only if needed)
......
......@@ -13,10 +13,11 @@
#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"
#include "aidge/utils/Random.hpp"
template <typename T>
void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling,
Aidge::VarianceNorm varianceNorm) {
void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor,
T scaling, Aidge::VarianceNorm varianceNorm) {
AIDGE_ASSERT(tensor->getImpl(),
"Tensor got no implementation, cannot fill it.");
AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
......@@ -30,9 +31,6 @@ void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling
: fanOut);
const T scale(std::sqrt(3.0 / n));
std::random_device rd;
std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator
std::uniform_real_distribution<T> uniformDist(-scale, scale);
std::shared_ptr<Aidge::Tensor> cpyTensor;
......@@ -41,8 +39,8 @@ void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling
tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");
// Setting values
for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
T value = scaling * uniformDist(gen);
tensorWithValues.set<T>(idx, value);
tensorWithValues.set<T>(
idx, scaling * uniformDist(Aidge::Random::Generator::get()));
}
// Copy values back to the original tensors (actual copy only if needed)
......@@ -50,7 +48,7 @@ void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling
}
template <typename T>
void Aidge::xavierNormalFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling,
Aidge::VarianceNorm varianceNorm) {
Aidge::VarianceNorm varianceNorm) {
AIDGE_ASSERT(tensor->getImpl(),
"Tensor got no implementation, cannot fill it.");
AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
......@@ -64,9 +62,6 @@ void Aidge::xavierNormalFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling,
: fanOut);
const double stdDev(std::sqrt(1.0 / n));
std::random_device rd;
std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator
std::normal_distribution<T> normalDist(0.0, stdDev);
std::shared_ptr<Aidge::Tensor> cpyTensor;
......@@ -76,7 +71,8 @@ void Aidge::xavierNormalFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling,
// Setting values
for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
tensorWithValues.set<T>(idx, scaling*normalDist(gen));
tensorWithValues.set<T>(
idx, scaling * normalDist(Aidge::Random::Generator::get()));
}
// Copy values back to the original tensors (actual copy only if needed)
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/utils/Random.hpp"
#include <random> // normal_distribution, uniform_real_distribution
std::mt19937 Aidge::Random::Generator::generator{std::random_device{}()};
unsigned int Aidge::Random::Generator::seed = 0;
void Aidge::Random::Generator::setSeed(unsigned int seed) {
seed = seed;
generator.seed(seed);
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment