Commit 72f49431 authored by Maxence Naud

[add] int32 and int64 types to const_filler and uniform_filler

parent 44c3441b
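For reference, the integer code paths added by this commit can be exercised directly from C++ roughly as follows. This is a minimal sketch, not part of the commit: it assumes the usual Tensor resize()/setDataType()/setBackend() API and a registered "cpu" backend (e.g. aidge_backend_cpu) supplying the tensor implementation.

#include <cstdint>
#include <memory>

#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"

int main() {
    // Hypothetical setup: a 2x3 Int64 tensor with a CPU implementation.
    auto tensor = std::make_shared<Aidge::Tensor>();
    tensor->resize({2, 3});
    tensor->setDataType(Aidge::DataType::Int64);
    tensor->setBackend("cpu");

    // Integer instantiations introduced by this commit.
    Aidge::constantFiller<std::int64_t>(tensor, 7);     // every element set to 7
    Aidge::uniformFiller<std::int64_t>(tensor, 0, 100); // i.i.d. draws in [0, 100]
    return 0;
}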
@@ -30,11 +30,17 @@ void init_Filler(py::module &m) {
[](std::shared_ptr<Tensor> tensor, py::object value) -> void {
switch (tensor->dataType()) {
case DataType::Float64:
constantFiller<double>(tensor, value.cast<double>());
constantFiller<cpptype_t<DataType::Float64>>(tensor, value.cast<cpptype_t<DataType::Float64>>());
break;
case DataType::Float32:
constantFiller<float>(tensor, value.cast<float>());
constantFiller<cpptype_t<DataType::Float32>>(tensor, value.cast<cpptype_t<DataType::Float32>>());
break;
case DataType::Int64:
constantFiller<cpptype_t<DataType::Int64>>(tensor, value.cast<cpptype_t<DataType::Int64>>());
break;
case DataType::Int32:
constantFiller<cpptype_t<DataType::Int32>>(tensor, value.cast<cpptype_t<DataType::Int32>>());
break;
default:
AIDGE_THROW_OR_ABORT(
py::value_error,
@@ -44,14 +50,14 @@ void init_Filler(py::module &m) {
py::arg("tensor"), py::arg("value"))
.def(
"normal_filler",
[](std::shared_ptr<Tensor> tensor, double mean,
double stdDev) -> void {
[](std::shared_ptr<Tensor> tensor, py::object mean,
py::object stdDev) -> void {
switch (tensor->dataType()) {
case DataType::Float64:
normalFiller<double>(tensor, mean, stdDev);
normalFiller<cpptype_t<DataType::Float64>>(tensor, mean.cast<cpptype_t<DataType::Float64>>(), stdDev.cast<cpptype_t<DataType::Float64>>());
break;
case DataType::Float32:
normalFiller<float>(tensor, mean, stdDev);
normalFiller<cpptype_t<DataType::Float32>>(tensor, mean.cast<cpptype_t<DataType::Float32>>(), stdDev.cast<cpptype_t<DataType::Float32>>());
break;
default:
AIDGE_THROW_OR_ABORT(
@@ -60,23 +66,39 @@ void init_Filler(py::module &m) {
}
},
py::arg("tensor"), py::arg("mean") = 0.0, py::arg("stdDev") = 1.0)
.def(
"uniform_filler",
[](std::shared_ptr<Tensor> tensor, double min, double max) -> void {
.def("uniform_filler", [] (std::shared_ptr<Tensor> tensor, py::object min, py::object max) -> void {
if (py::isinstance<py::int_>(min) && py::isinstance<py::int_>(max)) {
switch (tensor->dataType()) {
case DataType::Float64:
uniformFiller<double>(tensor, min, max);
case DataType::Int32:
uniformFiller<std::int32_t>(tensor, min.cast<std::int32_t>(), max.cast<std::int32_t>());
break;
case DataType::Int64:
uniformFiller<std::int64_t>(tensor, min.cast<std::int64_t>(), max.cast<std::int64_t>());
break;
default:
AIDGE_THROW_OR_ABORT(
py::value_error,
"Data type is not supported for Uniform filler.");
break;
}
} else if (py::isinstance<py::float_>(min) && py::isinstance<py::float_>(max)) {
switch (tensor->dataType()) {
case DataType::Float32:
uniformFiller<float>(tensor, min, max);
uniformFiller<float>(tensor, min.cast<float>(), max.cast<float>());
break;
case DataType::Float64:
uniformFiller<double>(tensor, min.cast<double>(), max.cast<double>());
break;
default:
AIDGE_THROW_OR_ABORT(
py::value_error,
"Data type is not supported for Uniform filler.");
break;
}
},
py::arg("tensor"), py::arg("min"), py::arg("max"))
} else {
AIDGE_THROW_OR_ABORT(py::value_error, "Input must be either an int or a float.");
}
}, py::arg("tensor"), py::arg("min"), py::arg("max"))
.def(
"xavier_uniform_filler",
[](std::shared_ptr<Tensor> tensor, py::object scaling,
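Note on the binding changes above: the lambdas now resolve the C++ scalar type from the DataType tag via cpptype_t<...> instead of hard-coding float/double, and uniform_filler dispatches on whether the Python min/max arguments are ints or floats before switching on the tensor's dtype. cpptype_t itself is defined elsewhere in aidge_core; judging from its use here it behaves roughly like the sketch below (illustrative only, not the actual Aidge definition).

#include <cstdint>
#include <type_traits>

// Illustrative stand-in for Aidge's DataType enum (only the tags used here).
enum class DataType { Float64, Float32, Int64, Int32 };

// Hypothetical DataType -> C++ scalar type mapping, mirroring how
// cpptype_t<DataType::...> is used in the binding code.
template <DataType D> struct cpptype;
template <> struct cpptype<DataType::Float64> { using type = double; };
template <> struct cpptype<DataType::Float32> { using type = float; };
template <> struct cpptype<DataType::Int64>   { using type = std::int64_t; };
template <> struct cpptype<DataType::Int32>   { using type = std::int32_t; };

template <DataType D>
using cpptype_t = typename cpptype<D>::type;

static_assert(std::is_same<cpptype_t<DataType::Int32>, std::int32_t>::value,
              "cpptype_t resolves a DataType tag to its C++ scalar type");

Writing the switch cases in terms of cpptype_t keeps the pybind11 cast and the filler template argument in sync when a new DataType is added.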
@@ -39,6 +39,7 @@ void Aidge::constantFiller(std::shared_ptr<Aidge::Tensor> tensor, T constantValu
tensor->copyCastFrom(tensorWithValues);
}
template void Aidge::constantFiller<std::int32_t>(std::shared_ptr<Aidge::Tensor>, std::int32_t);
template void Aidge::constantFiller<std::int64_t>(std::shared_ptr<Aidge::Tensor>, std::int64_t);
template void Aidge::constantFiller<float>(std::shared_ptr<Aidge::Tensor>, float);
template void Aidge::constantFiller<double>(std::shared_ptr<Aidge::Tensor>, double);
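The extra template void line above matters because the constantFiller definition lives in a .cpp file: every scalar type reachable from the Python binding needs an explicit instantiation, otherwise the new Int64 case would fail to link. A generic illustration of the pattern (not Aidge code):

// lib.cpp -- the template definition is visible only in this translation unit.
#include <cstdint>
#include <vector>

template <typename T>
void fillVector(std::vector<T>& values, T value) {
    for (auto& v : values) { v = value; }
}

// Explicit instantiations emit the symbols other translation units link
// against; without them, callers of fillVector<std::int64_t> would get
// undefined-reference errors at link time.
template void fillVector<std::int32_t>(std::vector<std::int32_t>&, std::int32_t);
template void fillVector<std::int64_t>(std::vector<std::int64_t>&, std::int64_t);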
@@ -8,8 +8,9 @@
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstdint> // std::int32_t
#include <memory>
#include <random> // normal_distribution, uniform_real_distribution
#include <random> // normal_distribution, uniform_real_distribution
#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"
@@ -19,10 +20,16 @@ template <typename T>
void Aidge::uniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T min, T max) {
AIDGE_ASSERT(tensor->getImpl(),
"Tensor got no implementation, cannot fill it.");
AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type: expected {}, got {}", NativeType<T>::type, tensor->dataType());
std::uniform_real_distribution<T> uniformDist(min, max);
using DistType = typename std::conditional<
std::is_integral<T>::value,
std::uniform_int_distribution<T>,
std::uniform_real_distribution<T>
>::type;
DistType uniformDist(min, max);
std::shared_ptr<Aidge::Tensor> cpyTensor;
// Create cpy only if tensor not on CPU
@@ -42,3 +49,7 @@ template void Aidge::uniformFiller<float>(std::shared_ptr<Aidge::Tensor>, float,
float);
template void Aidge::uniformFiller<double>(std::shared_ptr<Aidge::Tensor>,
double, double);
template void Aidge::uniformFiller<std::int32_t>(std::shared_ptr<Aidge::Tensor>,
std::int32_t, std::int32_t);
template void Aidge::uniformFiller<std::int64_t>(std::shared_ptr<Aidge::Tensor>,
std::int64_t, std::int64_t);
\ No newline at end of file
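The key change in uniformFiller is the distribution selection: std::uniform_real_distribution is only specified for floating-point types, so std::conditional swaps in std::uniform_int_distribution when T is integral. The same selection shown in isolation (standalone sketch, not Aidge code):

#include <cstdint>
#include <iostream>
#include <random>
#include <type_traits>

template <typename T>
T sampleUniform(T min, T max) {
    // Pick the distribution that is actually valid for T.
    using DistType = typename std::conditional<
        std::is_integral<T>::value,
        std::uniform_int_distribution<T>,   // closed range [min, max]
        std::uniform_real_distribution<T>   // half-open range [min, max)
    >::type;
    static std::mt19937 gen{std::random_device{}()};
    DistType dist(min, max);
    return dist(gen);
}

int main() {
    std::cout << sampleUniform<std::int32_t>(0, 10) << '\n'; // integer draw
    std::cout << sampleUniform<double>(0.0, 1.0) << '\n';    // real draw
    return 0;
}

Note the slight semantic difference carried over into the filler: the integer distribution includes max, while the real distribution excludes it.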