From 3d892e42918f47c247df284b0c6120018f513c39 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Tue, 21 Jan 2025 14:43:18 +0100
Subject: [PATCH] Add python binding for Flatten

---
 python_binding/operator/pybind_Flatten.cpp | 50 ++++++++++++++++++++++
 1 file changed, 50 insertions(+)
 create mode 100644 python_binding/operator/pybind_Flatten.cpp

diff --git a/python_binding/operator/pybind_Flatten.cpp b/python_binding/operator/pybind_Flatten.cpp
new file mode 100644
index 000000000..899e5d775
--- /dev/null
+++ b/python_binding/operator/pybind_Flatten.cpp
@@ -0,0 +1,50 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <memory>
+#include <pybind11/pybind11.h>
+#include <string>
+#include <vector>
+
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/operator/Flatten.hpp"
+#include "aidge/utils/Attributes.hpp"
+#include "aidge/utils/Types.h"
+
+namespace py = pybind11;
+namespace Aidge {
+
+void init_Flatten(py::module &m) {
+  py::class_<Flatten_Op, std::shared_ptr<Flatten_Op>, OperatorTensor>(
+      m, "FlattenOp", py::multiple_inheritance(),
+		R"mydelimiter(
+		Initialize flatten operator
+		:param axis :   up to which input dimensions (exclusive) should be flattened to the outer dimension of the output
+                        between [-r;r-1] with r = input_tensor.nbDims()
+		:type axes : :py:class: List[Int]
+		)mydelimiter")
+      .def("get_inputs_name", &Flatten_Op::getInputsName)
+      .def("get_outputs_name", &Flatten_Op::getOutputsName)
+      .def("axis", &Flatten_Op::axis);
+  // Here we bind the constructor of the Flatten Node. We add an argument
+  // for each attribute of the operator (in here we only have 'axis') and
+  // the last argument is the node's name.
+  m.def("Flatten", &Flatten, py::arg("axis") = 1,
+        py::arg("name") = "",
+        R"mydelimiter(
+    Initialize a node containing a flatten operator.
+	:param axis :   up to which input dimensions (exclusive) should be flattened to the outer dimension of the output
+                    between [-r;r-1] with r = input_tensor.nbDims()
+	:type axes : :py:class: List[Int]
+    :param name : name of the node.
+)mydelimiter");
+}
+} // namespace Aidge
-- 
GitLab