From 75d25d4238fd3f89ed6b7fb042885e1a886579b7 Mon Sep 17 00:00:00 2001 From: LOPEZ MAPE Lucas <lucas.lopezmape@cea.fr> Date: Fri, 15 Nov 2024 14:41:02 +0000 Subject: [PATCH] cast operator pybind --- python_binding/operator/pybind_Cast.cpp | 46 +++++++++++++++++++++++++ python_binding/pybind_core.cpp | 2 ++ 2 files changed, 48 insertions(+) create mode 100644 python_binding/operator/pybind_Cast.cpp diff --git a/python_binding/operator/pybind_Cast.cpp b/python_binding/operator/pybind_Cast.cpp new file mode 100644 index 000000000..960a084ff --- /dev/null +++ b/python_binding/operator/pybind_Cast.cpp @@ -0,0 +1,46 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include <string> +#include "aidge/backend/OperatorImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/Cast.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/Types.h" + +namespace py = pybind11; +namespace Aidge { + +void init_Cast(py::module &m) { + // Binding for CastOp class + auto pyCastOp = py::class_<Cast_Op, std::shared_ptr<Cast_Op>, OperatorTensor>(m, "CastOp", py::multiple_inheritance(),R"mydelimiter( + CastOp is a tensor operator that casts the input tensor to a data type specified by the target_type argument. + :param target_type: data type of the output tensor + :type target_type: Datatype + :param name: name of the node. + )mydelimiter") + .def(py::init<DataType>(), py::arg("target_type")) + .def("target_type", &Cast_Op::targetType, "Get the targeted type, output tensor data type") + .def_static("get_inputs_name", &Cast_Op::getInputsName, "Get the names of the input tensors.") + .def_static("get_outputs_name", &Cast_Op::getOutputsName, "Get the names of the output tensors."); + + // Binding for the Cast function + m.def("Cast", &Cast, py::arg("target_type"), py::arg("name") = "", + R"mydelimiter( + CastOp is a tensor operator that casts the input tensor to a data type specified by the target_type argument. + :param target_type: data type of the output tensor + :type target_type: Datatype + :param name: name of the node. + )mydelimiter"); +} +} // namespace Aidge \ No newline at end of file diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index c287314f2..2602108ad 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -36,6 +36,7 @@ void init_Atan(py::module&); void init_AvgPooling(py::module&); void init_BatchNorm(py::module&); void init_BitShift(py::module&); +void init_Cast(py::module&); void init_Clip(py::module&); void init_Concat(py::module&); void init_ConstantOfShape(py::module&); @@ -127,6 +128,7 @@ void init_Aidge(py::module& m) { init_AvgPooling(m); init_BatchNorm(m); init_BitShift(m); + init_Cast(m); init_Clip(m); init_Concat(m); init_Conv(m); -- GitLab