Skip to content
Snippets Groups Projects
Commit e78ca222 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'upd_python_binding' into 'dev'

Upd python binding

See merge request !238
parents 40b8dddd 871476a2
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!238Upd python binding
Pipeline #58412 passed
......@@ -315,6 +315,7 @@ void init_Tensor(py::module& m){
.def(py::self - py::self)
.def(py::self * py::self)
.def(py::self / py::self)
.def("clone", &Tensor::clone)
.def("sqrt", &Tensor::sqrt)
.def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true)
.def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true)
......@@ -334,8 +335,8 @@ void init_Tensor(py::module& m){
.def("cpy_transpose", (void (Tensor::*)(const Tensor& src, const std::vector<DimSize_t>& transpose)) &Tensor::copyTranspose, py::arg("src"), py::arg("transpose"))
.def("__str__", [](Tensor& b) {
if (b.empty()) {
return std::string("{}");
if (b.empty() && b.undefined()) {
return std::string("{}");
} else {
return b.toString();
}
......
......@@ -30,11 +30,17 @@ void init_Filler(py::module &m) {
[](std::shared_ptr<Tensor> tensor, py::object value) -> void {
switch (tensor->dataType()) {
case DataType::Float64:
constantFiller<double>(tensor, value.cast<double>());
constantFiller<cpptype_t<DataType::Float64>>(tensor, value.cast<cpptype_t<DataType::Float64>>());
break;
case DataType::Float32:
constantFiller<float>(tensor, value.cast<float>());
constantFiller<cpptype_t<DataType::Float32>>(tensor, value.cast<cpptype_t<DataType::Float32>>());
break;
case DataType::Int64:
constantFiller<cpptype_t<DataType::Int64>>(tensor, value.cast<cpptype_t<DataType::Int64>>());
break;
case DataType::Int32:
constantFiller<cpptype_t<DataType::Int32>>(tensor, value.cast<cpptype_t<DataType::Int32>>());
break;
default:
AIDGE_THROW_OR_ABORT(
py::value_error,
......@@ -44,14 +50,14 @@ void init_Filler(py::module &m) {
py::arg("tensor"), py::arg("value"))
.def(
"normal_filler",
[](std::shared_ptr<Tensor> tensor, double mean,
double stdDev) -> void {
[](std::shared_ptr<Tensor> tensor, py::object mean,
py::object stdDev) -> void {
switch (tensor->dataType()) {
case DataType::Float64:
normalFiller<double>(tensor, mean, stdDev);
normalFiller<cpptype_t<DataType::Float64>>(tensor, mean.cast<cpptype_t<DataType::Float64>>(), stdDev.cast<cpptype_t<DataType::Float64>>());
break;
case DataType::Float32:
normalFiller<float>(tensor, mean, stdDev);
normalFiller<cpptype_t<DataType::Float32>>(tensor, mean.cast<cpptype_t<DataType::Float32>>(), stdDev.cast<cpptype_t<DataType::Float32>>());
break;
default:
AIDGE_THROW_OR_ABORT(
......@@ -60,23 +66,39 @@ void init_Filler(py::module &m) {
}
},
py::arg("tensor"), py::arg("mean") = 0.0, py::arg("stdDev") = 1.0)
.def(
"uniform_filler",
[](std::shared_ptr<Tensor> tensor, double min, double max) -> void {
.def("uniform_filler", [] (std::shared_ptr<Tensor> tensor, py::object min, py::object max) -> void {
if (py::isinstance<py::int_>(min) && py::isinstance<py::int_>(max)) {
switch (tensor->dataType()) {
case DataType::Float64:
uniformFiller<double>(tensor, min, max);
case DataType::Int32:
uniformFiller<std::int32_t>(tensor, min.cast<std::int32_t>(), max.cast<std::int32_t>());
break;
case DataType::Int64:
uniformFiller<std::int64_t>(tensor, min.cast<std::int64_t>(), max.cast<std::int64_t>());
break;
default:
AIDGE_THROW_OR_ABORT(
py::value_error,
"Data type is not supported for Uniform filler.");
break;
}
} else if (py::isinstance<py::float_>(min) && py::isinstance<py::float_>(max)) {
switch (tensor->dataType()) {
case DataType::Float32:
uniformFiller<float>(tensor, min, max);
uniformFiller<float>(tensor, min.cast<float>(), max.cast<float>());
break;
case DataType::Float64:
uniformFiller<double>(tensor, min.cast<double>(), max.cast<double>());
break;
default:
AIDGE_THROW_OR_ABORT(
py::value_error,
"Data type is not supported for Uniform filler.");
break;
}
},
py::arg("tensor"), py::arg("min"), py::arg("max"))
} else {
AIDGE_THROW_OR_ABORT(py::value_error,"Input must be either an int or a float.");
}
}, py::arg("tensor"), py::arg("min"), py::arg("max"))
.def(
"xavier_uniform_filler",
[](std::shared_ptr<Tensor> tensor, py::object scaling,
......
......@@ -34,6 +34,11 @@ void init_Node(py::module& m) {
Type of the node.
)mydelimiter")
.def("attributes", &Node::attributes,
R"mydelimiter(
Get attributes.
)mydelimiter")
.def("get_operator", &Node::getOperator,
R"mydelimiter(
Get the Operator object of the Node.
......@@ -48,7 +53,7 @@ void init_Node(py::module& m) {
:rtype: str
)mydelimiter")
.def("create_unique_name", &Node::createUniqueName, py::arg("base_name"),
.def("create_unique_name", &Node::createUniqueName, py::arg("base_name"),
R"mydelimiter(
Given a base name, generate a new name which is unique in all the GraphViews containing this node.
......
......@@ -39,6 +39,7 @@ void Aidge::constantFiller(std::shared_ptr<Aidge::Tensor> tensor, T constantValu
tensor->copyCastFrom(tensorWithValues);
}
template void Aidge::constantFiller<std::int32_t>(std::shared_ptr<Aidge::Tensor>, std::int32_t);
template void Aidge::constantFiller<std::int64_t>(std::shared_ptr<Aidge::Tensor>, std::int64_t);
template void Aidge::constantFiller<float>(std::shared_ptr<Aidge::Tensor>, float);
template void Aidge::constantFiller<double>(std::shared_ptr<Aidge::Tensor>, double);
......@@ -8,8 +8,9 @@
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <cstdint> // std::int32_t
#include <memory>
#include <random> // normal_distribution, uniform_real_distribution
#include <random> // normal_distribution, uniform_real_distribution
#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"
......@@ -19,10 +20,16 @@ template <typename T>
void Aidge::uniformFiller(std::shared_ptr<Aidge::Tensor> tensor, T min, T max) {
AIDGE_ASSERT(tensor->getImpl(),
"Tensor got no implementation, cannot fill it.");
AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type {} and {}",NativeType<T>::type, tensor->dataType());
std::uniform_real_distribution<T> uniformDist(min, max);
using DistType = typename std::conditional<
std::is_integral<T>::value,
std::uniform_int_distribution<T>,
std::uniform_real_distribution<T>
>::type;
DistType uniformDist(min, max);
std::shared_ptr<Aidge::Tensor> cpyTensor;
// Create cpy only if tensor not on CPU
......@@ -42,3 +49,7 @@ template void Aidge::uniformFiller<float>(std::shared_ptr<Aidge::Tensor>, float,
float);
template void Aidge::uniformFiller<double>(std::shared_ptr<Aidge::Tensor>,
double, double);
template void Aidge::uniformFiller<std::int32_t>(std::shared_ptr<Aidge::Tensor>,
std::int32_t, std::int32_t);
template void Aidge::uniformFiller<std::int64_t>(std::shared_ptr<Aidge::Tensor>,
std::int64_t, std::int64_t);
\ No newline at end of file
0% — Loading, or loading failed.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment