From 44a64216592b2e9ddeda5b58ab04bb37b41bb206 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Fri, 28 Jun 2024 10:38:21 +0000 Subject: [PATCH] update DataType for python binding --- python_binding/data/pybind_Data.cpp | 4 +++- python_binding/operator/pybind_Shape.cpp | 14 ++++++-------- 2 files changed, 9 insertions(+), 9 deletions(-) diff --git a/python_binding/data/pybind_Data.cpp b/python_binding/data/pybind_Data.cpp index 32b0b0790..46bbcf83d 100644 --- a/python_binding/data/pybind_Data.cpp +++ b/python_binding/data/pybind_Data.cpp @@ -22,9 +22,11 @@ void init_Data(py::module& m){ .value("float32", DataType::Float32) .value("float16", DataType::Float16) .value("int8", DataType::Int8) + .value("int16", DataType::Int16) .value("int32", DataType::Int32) .value("int64", DataType::Int64) - .value("uint8", DataType::UInt8) + .value("int8", DataType::Int8) + .value("uint16", DataType::UInt16) .value("uint32", DataType::UInt32) .value("uint64", DataType::UInt64) ; diff --git a/python_binding/operator/pybind_Shape.cpp b/python_binding/operator/pybind_Shape.cpp index dbae1d95d..4e1d4203e 100644 --- a/python_binding/operator/pybind_Shape.cpp +++ b/python_binding/operator/pybind_Shape.cpp @@ -9,11 +9,10 @@ * ********************************************************************************/ +#include <cstdint> // std::int64_t + #include <pybind11/pybind11.h> -#include <string> -#include <vector> -#include "aidge/data/Tensor.hpp" #include "aidge/operator/Shape.hpp" #include "aidge/operator/OperatorTensor.hpp" @@ -21,14 +20,13 @@ namespace py = pybind11; namespace Aidge { void init_Shape(py::module& m) { - py::class_<Shape_Op, std::shared_ptr<Shape_Op>, Attributes, OperatorTensor>(m, "ShapeOp", py::multiple_inheritance()) - .def(py::init<std::int64_t, - std::int64_t>(), + py::class_<Shape_Op, std::shared_ptr<Shape_Op>, OperatorTensor>(m, "ShapeOp", py::multiple_inheritance()) + .def(py::init<const std::int64_t, + const std::int64_t>(), py::arg("start"), py::arg("end")) .def_static("get_inputs_name", &Shape_Op::getInputsName) - .def_static("get_outputs_name", &Shape_Op::getOutputsName) - .def_static("attributes_name", &Shape_Op::staticGetAttrsName); + .def_static("get_outputs_name", &Shape_Op::getOutputsName); declare_registrable<Shape_Op>(m, "ShapeOp"); -- GitLab