Skip to content
Snippets Groups Projects
Commit 5e9239f3 authored by Cyril Moineau's avatar Cyril Moineau Committed by Maxence Naud
Browse files

Add Normal Filler.

parent 28096a02
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!99Adding Filler to aidge_core
......@@ -13,13 +13,12 @@
#define AIDGE_CORE_FILLER_H_
#include <memory>
#include <random> // normal_distribution
#include <random> // normal_distribution, uniform_real_distribution
#include "aidge/data/Tensor.hpp"
namespace Aidge {
template <typename T>
void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) {
AIDGE_ASSERT(tensor->getImpl(),
......@@ -39,15 +38,16 @@ void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) {
// Copy values back to the original tensors (actual copy only if needed)
tensor->copyCastFrom(tensorWithValues);
}
template <typename T> // TODO: Keep template or use switch case depending on Tensor datatype ?
void normalFiller(std::shared_ptr<Tensor> tensor, double mean=0.0, double stdDev= 1.0){
// TODO: Keep template or use switch case depending on Tensor datatype ?
template <typename T>
void normalFiller(std::shared_ptr<Tensor> tensor,
double mean = 0.0,
double stdDev = 1.0) {
AIDGE_ASSERT(tensor->getImpl(),
"Tensor got no implementation, cannot fill it.");
AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
std::random_device rd;
std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator
std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator
std::normal_distribution<T> normalDist(mean, stdDev);
......@@ -56,17 +56,39 @@ void normalFiller(std::shared_ptr<Tensor> tensor, double mean=0.0, double stdDev
Tensor& tensorWithValues =
tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");
// Setting values
for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
tensorWithValues.set<T>(idx, normalDist(gen));
}
// Copy values back to the original tensors (actual copy only if needed)
tensor->copyCastFrom(tensorWithValues);
};
// TODO: Keep template or use switch case depending on Tensor datatype ?
template<typename T>
void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) {
AIDGE_ASSERT(tensor->getImpl(),
"Tensor got no implementation, cannot fill it.");
AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
std::random_device rd;
std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator
std::uniform_distribution<T> uniformDist(min, max);
std::shared_ptr<Tensor> cpyTensor;
// Create cpy only if tensor not on CPU
Tensor& tensorWithValues =
tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");
// Setting values
for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
tensorWithValues.set<T>(idx, normalDist(gen));
tensorWithValues.set<T>(idx, uniformDist(gen));
}
// Copy values back to the original tensors (actual copy only if needed)
tensor->copyCastFrom(tensorWithValues);
};
// void uniformFiller(std::shared_ptr<Tensor> tensor);
// void xavierFiller(std::shared_ptr<Tensor> tensor);
// void heFiller(std::shared_ptr<Tensor> tensor);
......
......@@ -56,7 +56,7 @@ void init_Filler(py::module &m) {
py::value_error,
"Data type is not supported for Constant filler.");
}
})
}, py::arg("tensor"), py::arg("value"))
.def("normal_filler",
[](std::shared_ptr<Tensor> tensor, double mean,
double stdDev) -> void {
......@@ -73,7 +73,22 @@ void init_Filler(py::module &m) {
"Data type is not supported for Normal filler.");
}
}, py::arg("tensor"), py::arg("mean")=0.0, py::arg("stdDev")=1.0)
.def("uniform_filler",
[](std::shared_ptr<Tensor> tensor, double min,
double max) -> void {
switch (tensor->dataType()) {
case DataType::Float64:
uniformFiller<double>(tensor, min, max);
break;
case DataType::Float32:
uniformFiller<float>(tensor, min, max);
break;
default:
AIDGE_THROW_OR_ABORT(
py::value_error,
"Data type is not supported for Uniform filler.");
}
}, py::arg("tensor"), py::arg("min"), py::arg("max"))
;
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment