Skip to content
Snippets Groups Projects
Commit f050ac31 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Improved binding

parent e90eeb92
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!186Refactor OperatorImpl for backend/export
Pipeline #53759 passed
...@@ -32,7 +32,7 @@ class Operator; ...@@ -32,7 +32,7 @@ class Operator;
*/ */
struct ImplSpec { struct ImplSpec {
struct IOSpec { struct IOSpec {
IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, std::vector<std::pair<int, int>> dims_ = {}): IOSpec(DataType type_, DataFormat format_ = DataFormat::Any, const std::vector<std::pair<int, int>>& dims_ = {}):
type(type_), type(type_),
format(format_), format(format_),
dims(dims_) dims(dims_)
...@@ -43,9 +43,9 @@ struct ImplSpec { ...@@ -43,9 +43,9 @@ struct ImplSpec {
std::vector<std::pair<int, int>> dims; std::vector<std::pair<int, int>> dims;
}; };
ImplSpec(DynamicAttributes attrs_ = DynamicAttributes()); ImplSpec(const DynamicAttributes& attrs_ = DynamicAttributes());
ImplSpec(IOSpec io, DynamicAttributes attrs_ = DynamicAttributes()); ImplSpec(const IOSpec& io, const DynamicAttributes& attrs_ = DynamicAttributes());
ImplSpec(IOSpec i, IOSpec o, DynamicAttributes attrs_ = DynamicAttributes()); ImplSpec(const IOSpec& i, const IOSpec& o, const DynamicAttributes& attrs_ = DynamicAttributes());
ImplSpec(const Aidge::ImplSpec&); ImplSpec(const Aidge::ImplSpec&);
~ImplSpec() noexcept; ~ImplSpec() noexcept;
......
...@@ -74,6 +74,8 @@ public: ...@@ -74,6 +74,8 @@ public:
static const std::vector<std::string> getOutputsName() { static const std::vector<std::string> getOutputsName() {
return {"data_output"}; return {"data_output"};
} }
virtual ~ReduceMean_Op() noexcept;
}; };
/** /**
......
...@@ -71,14 +71,14 @@ public: ...@@ -71,14 +71,14 @@ public:
}; };
void init_OperatorImpl(py::module& m){ void init_OperatorImpl(py::module& m){
py::class_<ImplSpec>(m, "ImplSpec") py::class_<ImplSpec::IOSpec>(m, "IOSpec")
.def(py::init<DynamicAttributes>()) .def(py::init<DataType, DataFormat, const std::vector<std::pair<int, int>>&>(), py::arg("type"), py::arg("format") = DataFormat::Any, py::arg("dims") = std::vector<std::pair<int, int>>{})
.def(py::init<ImplSpec::IOSpec, DynamicAttributes>())
.def(py::init<ImplSpec::IOSpec, ImplSpec::IOSpec, DynamicAttributes>())
; ;
py::class_<ImplSpec::IOSpec>(m, "IOSpec") py::class_<ImplSpec>(m, "ImplSpec")
.def(py::init<DataType, DataFormat, std::vector<std::pair<int, int>>>()) .def(py::init<const DynamicAttributes&>(), py::arg("attr") = DynamicAttributes())
.def(py::init<const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("io"), py::arg("attr") = DynamicAttributes())
.def(py::init<const ImplSpec::IOSpec&, const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("i"), py::arg("o"), py::arg("attr") = DynamicAttributes())
; ;
py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr()) py::class_<OperatorImpl, std::shared_ptr<OperatorImpl>, pyOperatorImpl>(m, "OperatorImpl", py::dynamic_attr())
......
...@@ -16,46 +16,51 @@ ...@@ -16,46 +16,51 @@
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Data(py::module& m){ template <class T>
// Define enumeration names for python as lowercase dtype name void bindEnum(py::module& m, const std::string& name) {
// This defined enum names compatible with basic numpy dtype // Define enumeration names for python as lowercase type name
// This defined enum names compatible with basic numpy type
// name such as: float32, flot64, [u]int32, [u]int64, ... // name such as: float32, flot64, [u]int32, [u]int64, ...
auto python_enum_name = [](const DataType& dtype) { auto python_enum_name = [](const T& type) {
auto str_lower = [](std::string& str) { auto str_lower = [](std::string& str) {
std::transform(str.begin(), str.end(), str.begin(), std::transform(str.begin(), str.end(), str.begin(),
[](unsigned char c){ [](unsigned char c){
return std::tolower(c); return std::tolower(c);
}); });
}; };
auto dtype_name = std::string(Aidge::format_as(dtype)); auto type_name = std::string(Aidge::format_as(type));
str_lower(dtype_name); str_lower(type_name);
return dtype_name; return type_name;
}; };
// Auto generate enumeration names from lowercase dtype strings // Auto generate enumeration names from lowercase type strings
std::vector<std::string> enum_names; std::vector<std::string> enum_names;
for (auto dtype_str : EnumStrings<Aidge::DataType>::data) { for (auto type_str : EnumStrings<T>::data) {
auto dtype = static_cast<DataType>(enum_names.size()); auto type = static_cast<T>(enum_names.size());
auto enum_name = python_enum_name(dtype); auto enum_name = python_enum_name(type);
enum_names.push_back(enum_name); enum_names.push_back(enum_name);
} }
// Define python side enumeration aidge_core.dtype // Define python side enumeration aidge_core.type
auto e_dtype = py::enum_<DataType>(m, "dtype"); auto e_type = py::enum_<T>(m, name.c_str());
// Add enum value for each enum name // Add enum value for each enum name
for (std::size_t idx = 0; idx < enum_names.size(); idx++) { for (std::size_t idx = 0; idx < enum_names.size(); idx++) {
e_dtype.value(enum_names[idx].c_str(), static_cast<DataType>(idx)); e_type.value(enum_names[idx].c_str(), static_cast<T>(idx));
} }
// Define str() to return the bare enum name value, it allows // Define str() to return the bare enum name value, it allows
// to compare directly for instance str(tensor.dtype()) // to compare directly for instance str(tensor.type())
// with str(nparray.dtype) // with str(nparray.type)
e_dtype.def("__str__", [enum_names](const DataType& dtype) { e_type.def("__str__", [enum_names](const T& type) {
return enum_names[static_cast<int>(dtype)]; return enum_names[static_cast<int>(type)];
}, py::prepend());; }, py::prepend());;
}
py::class_<Data, std::shared_ptr<Data>>(m,"Data"); void init_Data(py::module& m){
bindEnum<DataType>(m, "dtype");
bindEnum<DataFormat>(m, "dformat");
py::class_<Data, std::shared_ptr<Data>>(m,"Data");
} }
} }
...@@ -21,8 +21,8 @@ void init_Data(py::module&); ...@@ -21,8 +21,8 @@ void init_Data(py::module&);
void init_Database(py::module&); void init_Database(py::module&);
void init_DataProvider(py::module&); void init_DataProvider(py::module&);
void init_Tensor(py::module&); void init_Tensor(py::module&);
void init_OperatorImpl(py::module&);
void init_Attributes(py::module&); void init_Attributes(py::module&);
void init_OperatorImpl(py::module&);
void init_Log(py::module&); void init_Log(py::module&);
void init_Operator(py::module&); void init_Operator(py::module&);
void init_OperatorTensor(py::module&); void init_OperatorTensor(py::module&);
...@@ -89,6 +89,7 @@ void init_Aidge(py::module& m) { ...@@ -89,6 +89,7 @@ void init_Aidge(py::module& m) {
init_Database(m); init_Database(m);
init_DataProvider(m); init_DataProvider(m);
init_Tensor(m); init_Tensor(m);
init_Attributes(m);
init_Node(m); init_Node(m);
init_GraphView(m); init_GraphView(m);
...@@ -96,7 +97,6 @@ void init_Aidge(py::module& m) { ...@@ -96,7 +97,6 @@ void init_Aidge(py::module& m) {
init_Connector(m); init_Connector(m);
init_OperatorImpl(m); init_OperatorImpl(m);
init_Attributes(m);
init_Log(m); init_Log(m);
init_Operator(m); init_Operator(m);
init_OperatorTensor(m); init_OperatorTensor(m);
......
...@@ -22,11 +22,11 @@ ...@@ -22,11 +22,11 @@
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
Aidge::ImplSpec::ImplSpec(DynamicAttributes attrs_): Aidge::ImplSpec::ImplSpec(const DynamicAttributes& attrs_):
attrs(attrs_) {} attrs(attrs_) {}
Aidge::ImplSpec::ImplSpec(IOSpec io, DynamicAttributes attrs_): Aidge::ImplSpec::ImplSpec(const IOSpec& io, const DynamicAttributes& attrs_):
inputs(1, io), outputs(1, io), attrs(attrs_) {} inputs(1, io), outputs(1, io), attrs(attrs_) {}
Aidge::ImplSpec::ImplSpec(IOSpec i, IOSpec o, DynamicAttributes attrs_): Aidge::ImplSpec::ImplSpec(const IOSpec& i, const IOSpec& o, const DynamicAttributes& attrs_):
inputs(1, i), outputs(1, o), attrs(attrs_) {} inputs(1, i), outputs(1, o), attrs(attrs_) {}
Aidge::ImplSpec::ImplSpec(const Aidge::ImplSpec&) = default; Aidge::ImplSpec::ImplSpec(const Aidge::ImplSpec&) = default;
Aidge::ImplSpec::~ImplSpec() noexcept = default; Aidge::ImplSpec::~ImplSpec() noexcept = default;
......
...@@ -80,6 +80,8 @@ void Aidge::ReduceMean_Op::setBackend(const std::string& name, Aidge::DeviceIdx_ ...@@ -80,6 +80,8 @@ void Aidge::ReduceMean_Op::setBackend(const std::string& name, Aidge::DeviceIdx_
mOutputs[0]->setBackend(name, device); mOutputs[0]->setBackend(name, device);
} }
Aidge::ReduceMean_Op::~ReduceMean_Op() noexcept = default;
//////////////////////////////////////////// ////////////////////////////////////////////
std::shared_ptr<Aidge::Node> Aidge::ReduceMean(const std::vector<std::int32_t> &axes, std::shared_ptr<Aidge::Node> Aidge::ReduceMean(const std::vector<std::int32_t> &axes,
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment