From 3d892e42918f47c247df284b0c6120018f513c39 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Tue, 21 Jan 2025 14:43:18 +0100 Subject: [PATCH] Add python binding for Flatten --- python_binding/operator/pybind_Flatten.cpp | 50 ++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 python_binding/operator/pybind_Flatten.cpp diff --git a/python_binding/operator/pybind_Flatten.cpp b/python_binding/operator/pybind_Flatten.cpp new file mode 100644 index 000000000..899e5d775 --- /dev/null +++ b/python_binding/operator/pybind_Flatten.cpp @@ -0,0 +1,50 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> +#include <pybind11/pybind11.h> +#include <string> +#include <vector> + +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/operator/Flatten.hpp" +#include "aidge/utils/Attributes.hpp" +#include "aidge/utils/Types.h" + +namespace py = pybind11; +namespace Aidge { + +void init_Flatten(py::module &m) { + py::class_<Flatten_Op, std::shared_ptr<Flatten_Op>, OperatorTensor>( + m, "FlattenOp", py::multiple_inheritance(), + R"mydelimiter( + Initialize flatten operator + :param axis : up to which input dimensions (exclusive) should be flattened to the outer dimension of the output + between [-r;r-1] with r = input_tensor.nbDims() + :type axes : :py:class: List[Int] + )mydelimiter") + .def("get_inputs_name", &Flatten_Op::getInputsName) + .def("get_outputs_name", &Flatten_Op::getOutputsName) + .def("axis", &Flatten_Op::axis); + // Here we bind the constructor of the Flatten Node. We add an argument + // for each attribute of the operator (in here we only have 'axis') and + // the last argument is the node's name. + m.def("Flatten", &Flatten, py::arg("axis") = 1, + py::arg("name") = "", + R"mydelimiter( + Initialize a node containing a flatten operator. + :param axis : up to which input dimensions (exclusive) should be flattened to the outer dimension of the output + between [-r;r-1] with r = input_tensor.nbDims() + :type axes : :py:class: List[Int] + :param name : name of the node. +)mydelimiter"); +} +} // namespace Aidge -- GitLab