From d84a69c7caa176dc0c351314ec01691e92648035 Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Mon, 2 Oct 2023 15:40:16 +0000
Subject: [PATCH] [GenericOperator] Add set_compute_output_dims &
 compute_output_dims method + associated unittest.

---
 aidge_core/unit_tests/test_operator_binding.py     | 12 +++++++++++-
 python_binding/operator/pybind_GenericOperator.cpp |  7 ++++++-
 2 files changed, 17 insertions(+), 2 deletions(-)

diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py
index b326e0748..8898bc5a7 100644
--- a/aidge_core/unit_tests/test_operator_binding.py
+++ b/aidge_core/unit_tests/test_operator_binding.py
@@ -61,5 +61,15 @@ class test_operator_binding(unittest.TestCase):
         self.generic_operator.add_parameter("l_str", ["ok"])
         self.assertEqual(self.generic_operator.get_parameter("l_str"), ["ok"])
 
+    def test_compute_output_dims(self):
+        in_dims=[25, 25]
+        input = aidge_core.Producer(in_dims, name="In")
+        genOp = aidge_core.GenericOperator("genOp", 1, 1, 1, name="genOp")
+        _ = aidge_core.sequential([input, genOp])
+        self.assertListEqual(genOp.get_operator().output(0).dims(), [])
+        genOp.get_operator().set_compute_output_dims(lambda x:x)
+        genOp.get_operator().compute_output_dims()
+        self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims)
+
 if __name__ == '__main__':
-    unittest.main()
\ No newline at end of file
+    unittest.main()
diff --git a/python_binding/operator/pybind_GenericOperator.cpp b/python_binding/operator/pybind_GenericOperator.cpp
index bec59eaf2..dfd2cfede 100644
--- a/python_binding/operator/pybind_GenericOperator.cpp
+++ b/python_binding/operator/pybind_GenericOperator.cpp
@@ -11,6 +11,7 @@
 
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
+#include <pybind11/functional.h>
 #include <stdio.h>
 
 #include "aidge/backend/OperatorImpl.hpp"
@@ -59,7 +60,11 @@ void init_GenericOperator(py::module& m) {
             throw py::key_error("Failed to convert parameter type " + key + ", this issue may come from typeid function which gave an unknown key : [" + paramType + "]. Please open an issue asking to add the support for this key.");
         }
         return res;
-    });
+    })
+    .def_readonly_static("identity", &GenericOperator_Op::Identity)
+    .def("compute_output_dims", &GenericOperator_Op::computeOutputDims)
+    .def("set_compute_output_dims", &GenericOperator_Op::setComputeOutputDims, py::arg("computation_function"))
+    ;
 
     m.def("GenericOperator", &GenericOperator, py::arg("type"), py::arg("nbDataIn"), py::arg("nbIn"), py::arg("nbOut"),
           py::arg("name") = "");
-- 
GitLab