Skip to content
Snippets Groups Projects
Commit ec2c2e43 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added Python binding for new operators

parent 4dd01672
No related branches found
No related tags found
No related merge requests found
......@@ -108,6 +108,19 @@ template <DimIdx_t DIM> void declare_PaddedMaxPoolingOp(py::module &m) {
}
void declare_LSTMOp(py::module &m) {
m.def("LSTM", [](DimSize_t in_channels,
DimSize_t hidden_channels,
DimSize_t seq_length,
const std::string& name)
{
return LSTM(in_channels, hidden_channels, seq_length, name);
}, py::arg("in_channels"),
py::arg("hidden_channels"),
py::arg("seq_length"),
py::arg("name") = "");
}
void init_MetaOperatorDefs(py::module &m) {
declare_PaddedConvOp<1>(m);
declare_PaddedConvOp<2>(m);
......@@ -121,6 +134,7 @@ void init_MetaOperatorDefs(py::module &m) {
declare_PaddedMaxPoolingOp<1>(m);
declare_PaddedMaxPoolingOp<2>(m);
declare_PaddedMaxPoolingOp<3>(m);
declare_LSTMOp(m);
py::class_<MetaOperator_Op, std::shared_ptr<MetaOperator_Op>, OperatorTensor>(m, "MetaOperator_Op", py::multiple_inheritance())
.def("get_micro_graph", &MetaOperator_Op::getMicroGraph);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/operator/Sigmoid.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Sigmoid(py::module& m) {
py::class_<Sigmoid_Op, std::shared_ptr<Sigmoid_Op>, OperatorTensor>(m, "SigmoidOp", py::multiple_inheritance())
.def("get_inputs_name", &Sigmoid_Op::getInputsName)
.def("get_outputs_name", &Sigmoid_Op::getOutputsName);
m.def("Sigmoid", &Sigmoid, py::arg("name") = "");
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/operator/Tanh.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Tanh(py::module& m) {
py::class_<Tanh_Op, std::shared_ptr<Tanh_Op>, OperatorTensor>(m, "TanhOp", py::multiple_inheritance())
.def("get_inputs_name", &Tanh_Op::getInputsName)
.def("get_outputs_name", &Tanh_Op::getOutputsName);
m.def("Tanh", &Tanh, py::arg("name") = "");
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment