Commit 72f49431 authored by Maxence Naud

[add] int32 and int64 types to const_filler and uniform_filler

parent 44c3441b
3 merge requests: !279 v0.4.0, !253 v0.4.0, !238 Upd python binding
@@ -30,11 +30,17 @@ void init_Filler(py::module &m) {
         [](std::shared_ptr<Tensor> tensor, py::object value) -> void {
             switch (tensor->dataType()) {
                 case DataType::Float64:
-                    constantFiller<double>(tensor, value.cast<double>());
+                    constantFiller<cpptype_t<DataType::Float64>>(tensor, value.cast<cpptype_t<DataType::Float64>>());
                     break;
                 case DataType::Float32:
-                    constantFiller<float>(tensor, value.cast<float>());
+                    constantFiller<cpptype_t<DataType::Float32>>(tensor, value.cast<cpptype_t<DataType::Float32>>());
                     break;
+                case DataType::Int64:
+                    constantFiller<cpptype_t<DataType::Int64>>(tensor, value.cast<cpptype_t<DataType::Int64>>());
+                    break;
+                case DataType::Int32:
+                    constantFiller<cpptype_t<DataType::Int32>>(tensor, value.cast<cpptype_t<DataType::Int32>>());
+                    break;
                 default:
                     AIDGE_THROW_OR_ABORT(
                         py::value_error,
@@ -44,14 +50,14 @@ void init_Filler(py::module &m) {
             py::arg("tensor"), py::arg("value"))
         .def(
             "normal_filler",
-            [](std::shared_ptr<Tensor> tensor, double mean,
-               double stdDev) -> void {
+            [](std::shared_ptr<Tensor> tensor, py::object mean,
+               py::object stdDev) -> void {
                 switch (tensor->dataType()) {
                     case DataType::Float64:
-                        normalFiller<double>(tensor, mean, stdDev);
+                        normalFiller<cpptype_t<DataType::Float64>>(tensor, mean.cast<cpptype_t<DataType::Float64>>(), stdDev.cast<cpptype_t<DataType::Float64>>());
                         break;
                     case DataType::Float32:
-                        normalFiller<float>(tensor, mean, stdDev);
+                        normalFiller<cpptype_t<DataType::Float32>>(tensor, mean.cast<cpptype_t<DataType::Float32>>(), stdDev.cast<cpptype_t<DataType::Float32>>());
                         break;
                     default:
                         AIDGE_THROW_OR_ABORT(
@@ -60,23 +66,39 @@ void init_Filler(py::module &m) {
                 }
             },
             py::arg("tensor"), py::arg("mean") = 0.0, py::arg("stdDev") = 1.0)
-        .def(
-            "uniform_filler",
-            [](std::shared_ptr<Tensor> tensor, double min, double max) -> void {
-                switch (tensor->dataType()) {
-                    case DataType::Float64:
-                        uniformFiller<double>(tensor, min, max);
-                        break;
-                    case DataType::Float32:
-                        uniformFiller<float>(tensor, min, max);
-                        break;
-                    default:
-                        AIDGE_THROW_OR_ABORT(
-                            py::value_error,
-                            "Data type is not supported for Uniform filler.");
-                }
-            },
-            py::arg("tensor"), py::arg("min"), py::arg("max"))
+        .def("uniform_filler", [] (std::shared_ptr<Tensor> tensor, py::object min, py::object max) -> void {
+            if (py::isinstance<py::int_>(min) && py::isinstance<py::int_>(max)) {
+                switch (tensor->dataType()) {
+                    case DataType::Int32:
+                        uniformFiller<std::int32_t>(tensor, min.cast<std::int32_t>(), max.cast<std::int32_t>());
+                        break;
+                    case DataType::Int64:
+                        uniformFiller<std::int64_t>(tensor, min.cast<std::int64_t>(), max.cast<std::int64_t>());
+                        break;
+                    default:
+                        AIDGE_THROW_OR_ABORT(
+                            py::value_error,
+                            "Data type is not supported for Uniform filler.");
+                        break;
+                }
+            } else if (py::isinstance<py::float_>(min) && py::isinstance<py::float_>(max)) {
+                switch (tensor->dataType()) {
+                    case DataType::Float32:
+                        uniformFiller<float>(tensor, min.cast<float>(), max.cast<float>());
+                        break;
+                    case DataType::Float64:
+                        uniformFiller<double>(tensor, min.cast<double>(), max.cast<double>());
+                        break;
+                    default:
+                        AIDGE_THROW_OR_ABORT(
+                            py::value_error,
+                            "Data type is not supported for Uniform filler.");
+                        break;
+                }
+            } else {
+                AIDGE_THROW_OR_ABORT(py::value_error, "Input must be either an int or a float.");
+            }
+        }, py::arg("tensor"), py::arg("min"), py::arg("max"))
         .def(
             "xavier_uniform_filler",
             [](std::shared_ptr<Tensor> tensor, py::object scaling,
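
Note: the binding now names the filler element types through cpptype_t<DataType::...> rather than raw float/double. As an illustration of what such a DataType-to-native-type trait typically looks like, here is a minimal sketch with hypothetical local names; it is not the actual Aidge definition.

#include <cstdint>
#include <type_traits>

// Hypothetical stand-in for Aidge's DataType enum and cpptype_t trait,
// shown only to illustrate the mapping used in the binding above.
enum class DataType { Float32, Float64, Int32, Int64 };

template <DataType D> struct cpptype;
template <> struct cpptype<DataType::Float32> { using type = float; };
template <> struct cpptype<DataType::Float64> { using type = double; };
template <> struct cpptype<DataType::Int32>   { using type = std::int32_t; };
template <> struct cpptype<DataType::Int64>   { using type = std::int64_t; };

template <DataType D>
using cpptype_t = typename cpptype<D>::type;

// constantFiller<cpptype_t<DataType::Int64>> therefore resolves to
// constantFiller<std::int64_t>.
static_assert(std::is_same<cpptype_t<DataType::Int64>, std::int64_t>::value,
              "Int64 maps to std::int64_t");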
@@ -39,6 +39,7 @@ void Aidge::constantFiller(std::shared_ptr<Aidge::Tensor> tensor, T constantValu
     tensor->copyCastFrom(tensorWithValues);
 }
 
+template void Aidge::constantFiller<std::int32_t>(std::shared_ptr<Aidge::Tensor>, std::int32_t);
+template void Aidge::constantFiller<std::int64_t>(std::shared_ptr<Aidge::Tensor>, std::int64_t);
 template void Aidge::constantFiller<float>(std::shared_ptr<Aidge::Tensor>, float);
 template void Aidge::constantFiller<double>(std::shared_ptr<Aidge::Tensor>, double);
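
Note: constantFiller is a function template whose definition lives in this translation unit, so each element type exposed to callers needs an explicit instantiation like the two added above. A self-contained sketch of the pattern, using a hypothetical fill() function rather than Aidge code:

#include <cstdint>
#include <vector>

// In a header, callers would only see the declaration:
//   template <typename T> void fill(std::vector<T>& v, T value);
template <typename T>
void fill(std::vector<T>& v, T value) {   // definition kept in the .cpp
    for (auto& element : v) {
        element = value;
    }
}

// Explicit instantiations emit object code for each supported type,
// so code that only includes the declaration still links.
template void fill<std::int32_t>(std::vector<std::int32_t>&, std::int32_t);
template void fill<std::int64_t>(std::vector<std::int64_t>&, std::int64_t);
template void fill<float>(std::vector<float>&, float);
template void fill<double>(std::vector<double>&, double);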
@@ -8,8 +8,9 @@
  * SPDX-License-Identifier: EPL-2.0
  *
  ********************************************************************************/
+#include <cstdint>  // std::int32_t
 #include <memory>
 #include <random>   // normal_distribution, uniform_real_distribution
 
 #include "aidge/data/Tensor.hpp"
 #include "aidge/filler/Filler.hpp"
@@ -19,10 +20,16 @@ template <typename T>
 void Aidge::uniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T min, T max) {
     AIDGE_ASSERT(tensor->getImpl(),
                  "Tensor got no implementation, cannot fill it.");
-    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
+    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type {} and {}", NativeType<T>::type, tensor->dataType());
 
-    std::uniform_real_distribution<T> uniformDist(min, max);
+    using DistType = typename std::conditional<
+        std::is_integral<T>::value,
+        std::uniform_int_distribution<T>,
+        std::uniform_real_distribution<T>
+    >::type;
+    DistType uniformDist(min, max);
 
     std::shared_ptr<Aidge::Tensor> cpyTensor;
     // Create cpy only if tensor not on CPU
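
Note: std::uniform_real_distribution has undefined behaviour when T is an integer type, which is why the hunk above selects the distribution with std::conditional. A standalone sketch of the same selection, with a hypothetical uniformSample() helper (it assumes <type_traits> and <random> are available in the translation unit):

#include <cstdint>
#include <iostream>
#include <random>
#include <type_traits>

// Draw one value uniformly from the requested range; the distribution type is
// chosen at compile time so integral T uses uniform_int_distribution and
// floating-point T uses uniform_real_distribution.
template <typename T>
T uniformSample(T min, T max) {
    using DistType = typename std::conditional<
        std::is_integral<T>::value,
        std::uniform_int_distribution<T>,
        std::uniform_real_distribution<T>
    >::type;
    static std::mt19937 gen(std::random_device{}());
    DistType dist(min, max);
    return dist(gen);
}

int main() {
    std::cout << uniformSample<std::int32_t>(0, 10) << '\n';  // integer draw
    std::cout << uniformSample<double>(0.0, 1.0) << '\n';     // real draw
    return 0;
}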
@@ -42,3 +49,7 @@ template void Aidge::uniformFiller<float>(std::shared_ptr<Aidge::Tensor>, float,
                                           float);
 template void Aidge::uniformFiller<double>(std::shared_ptr<Aidge::Tensor>,
                                            double, double);
+template void Aidge::uniformFiller<std::int32_t>(std::shared_ptr<Aidge::Tensor>,
+                                                 std::int32_t, std::int32_t);
+template void Aidge::uniformFiller<std::int64_t>(std::shared_ptr<Aidge::Tensor>,
+                                                 std::int64_t, std::int64_t);
\ No newline at end of file