From 75d25d4238fd3f89ed6b7fb042885e1a886579b7 Mon Sep 17 00:00:00 2001
From: LOPEZ MAPE Lucas <lucas.lopezmape@cea.fr>
Date: Fri, 15 Nov 2024 14:41:02 +0000
Subject: [PATCH] cast operator pybind

---
 python_binding/operator/pybind_Cast.cpp | 46 +++++++++++++++++++++++++
 python_binding/pybind_core.cpp          |  2 ++
 2 files changed, 48 insertions(+)
 create mode 100644 python_binding/operator/pybind_Cast.cpp

diff --git a/python_binding/operator/pybind_Cast.cpp b/python_binding/operator/pybind_Cast.cpp
new file mode 100644
index 000000000..960a084ff
--- /dev/null
+++ b/python_binding/operator/pybind_Cast.cpp
@@ -0,0 +1,46 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <pybind11/pybind11.h>
+
+#include <string>
+#include "aidge/backend/OperatorImpl.hpp"
+#include "aidge/data/Tensor.hpp"
+#include "aidge/operator/Cast.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/utils/Types.h"
+
+namespace py = pybind11;
+namespace Aidge {
+
+void init_Cast(py::module &m) {
+    // Binding for CastOp class
+    auto pyCastOp = py::class_<Cast_Op, std::shared_ptr<Cast_Op>, OperatorTensor>(m, "CastOp", py::multiple_inheritance(),R"mydelimiter(
+        CastOp is a tensor operator that casts the input tensor to a data type specified by the target_type argument.
+        :param target_type: data type of the output tensor 
+        :type target_type: Datatype
+        :param name: name of the node.
+    )mydelimiter")
+        .def(py::init<DataType>(), py::arg("target_type"))
+        .def("target_type", &Cast_Op::targetType, "Get the targeted type, output tensor data type")
+        .def_static("get_inputs_name", &Cast_Op::getInputsName, "Get the names of the input tensors.")
+        .def_static("get_outputs_name", &Cast_Op::getOutputsName, "Get the names of the output tensors.");
+
+    // Binding for the Cast function
+    m.def("Cast", &Cast, py::arg("target_type"), py::arg("name") = "",
+        R"mydelimiter(
+        CastOp is a tensor operator that casts the input tensor to a data type specified by the target_type argument.
+        :param target_type: data type of the output tensor 
+        :type target_type: Datatype
+        :param name: name of the node.
+    )mydelimiter");
+}
+} // namespace Aidge
\ No newline at end of file
diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp
index c287314f2..2602108ad 100644
--- a/python_binding/pybind_core.cpp
+++ b/python_binding/pybind_core.cpp
@@ -36,6 +36,7 @@ void init_Atan(py::module&);
 void init_AvgPooling(py::module&);
 void init_BatchNorm(py::module&);
 void init_BitShift(py::module&);
+void init_Cast(py::module&);
 void init_Clip(py::module&);
 void init_Concat(py::module&);
 void init_ConstantOfShape(py::module&);
@@ -127,6 +128,7 @@ void init_Aidge(py::module& m) {
     init_AvgPooling(m);
     init_BatchNorm(m);
     init_BitShift(m);
+    init_Cast(m);
     init_Clip(m);
     init_Concat(m);
     init_Conv(m);
-- 
GitLab