From ec2c2e4380a4b4b609c06dd3bb2038a9b5dca430 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Fri, 9 Feb 2024 18:08:44 +0100 Subject: [PATCH] Added Python binding for new operators --- .../operator/pybind_MetaOperatorDefs.cpp | 14 ++++++++++ python_binding/operator/pybind_Sigmoid.cpp | 27 +++++++++++++++++++ python_binding/operator/pybind_Tanh.cpp | 27 +++++++++++++++++++ 3 files changed, 68 insertions(+) create mode 100644 python_binding/operator/pybind_Sigmoid.cpp create mode 100644 python_binding/operator/pybind_Tanh.cpp diff --git a/python_binding/operator/pybind_MetaOperatorDefs.cpp b/python_binding/operator/pybind_MetaOperatorDefs.cpp index b043ac23c..9eb11b6ba 100644 --- a/python_binding/operator/pybind_MetaOperatorDefs.cpp +++ b/python_binding/operator/pybind_MetaOperatorDefs.cpp @@ -108,6 +108,19 @@ template <DimIdx_t DIM> void declare_PaddedMaxPoolingOp(py::module &m) { } +void declare_LSTMOp(py::module &m) { + m.def("LSTM", [](DimSize_t in_channels, + DimSize_t hidden_channels, + DimSize_t seq_length, + const std::string& name) + { + return LSTM(in_channels, hidden_channels, seq_length, name); + }, py::arg("in_channels"), + py::arg("hidden_channels"), + py::arg("seq_length"), + py::arg("name") = ""); +} + void init_MetaOperatorDefs(py::module &m) { declare_PaddedConvOp<1>(m); declare_PaddedConvOp<2>(m); @@ -121,6 +134,7 @@ void init_MetaOperatorDefs(py::module &m) { declare_PaddedMaxPoolingOp<1>(m); declare_PaddedMaxPoolingOp<2>(m); declare_PaddedMaxPoolingOp<3>(m); + declare_LSTMOp(m); py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperator_Op", py::multiple_inheritance()) .def("get_micro_graph", &MetaOperator_Op::getMicroGraph); diff --git a/python_binding/operator/pybind_Sigmoid.cpp b/python_binding/operator/pybind_Sigmoid.cpp new file mode 100644 index 000000000..2393e56c1 --- /dev/null +++ b/python_binding/operator/pybind_Sigmoid.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Sigmoid.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Sigmoid(py::module& m) { + py::class_<Sigmoid_Op, std::shared_ptr<Sigmoid_Op>, OperatorTensor>(m, "SigmoidOp", py::multiple_inheritance()) + .def("get_inputs_name", &Sigmoid_Op::getInputsName) + .def("get_outputs_name", &Sigmoid_Op::getOutputsName); + + m.def("Sigmoid", &Sigmoid, py::arg("name") = ""); +} +} // namespace Aidge diff --git a/python_binding/operator/pybind_Tanh.cpp b/python_binding/operator/pybind_Tanh.cpp new file mode 100644 index 000000000..2f3140039 --- /dev/null +++ b/python_binding/operator/pybind_Tanh.cpp @@ -0,0 +1,27 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/operator/Tanh.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Tanh(py::module& m) { + py::class_<Tanh_Op, std::shared_ptr<Tanh_Op>, OperatorTensor>(m, "TanhOp", py::multiple_inheritance()) + .def("get_inputs_name", &Tanh_Op::getInputsName) + .def("get_outputs_name", &Tanh_Op::getOutputsName); + + m.def("Tanh", &Tanh, py::arg("name") = ""); +} +} // namespace Aidge -- GitLab