From d84a69c7caa176dc0c351314ec01691e92648035 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Mon, 2 Oct 2023 15:40:16 +0000 Subject: [PATCH] [GenericOperator] Add set_compute_output_dims & compute_output_dims method + associated unittest. --- aidge_core/unit_tests/test_operator_binding.py | 12 +++++++++++- python_binding/operator/pybind_GenericOperator.cpp | 7 ++++++- 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py index b326e0748..8898bc5a7 100644 --- a/aidge_core/unit_tests/test_operator_binding.py +++ b/aidge_core/unit_tests/test_operator_binding.py @@ -61,5 +61,15 @@ class test_operator_binding(unittest.TestCase): self.generic_operator.add_parameter("l_str", ["ok"]) self.assertEqual(self.generic_operator.get_parameter("l_str"), ["ok"]) + def test_compute_output_dims(self): + in_dims=[25, 25] + input = aidge_core.Producer(in_dims, name="In") + genOp = aidge_core.GenericOperator("genOp", 1, 1, 1, name="genOp") + _ = aidge_core.sequential([input, genOp]) + self.assertListEqual(genOp.get_operator().output(0).dims(), []) + genOp.get_operator().set_compute_output_dims(lambda x:x) + genOp.get_operator().compute_output_dims() + self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims) + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/python_binding/operator/pybind_GenericOperator.cpp b/python_binding/operator/pybind_GenericOperator.cpp index bec59eaf2..dfd2cfede 100644 --- a/python_binding/operator/pybind_GenericOperator.cpp +++ b/python_binding/operator/pybind_GenericOperator.cpp @@ -11,6 +11,7 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> +#include <pybind11/functional.h> #include <stdio.h> #include "aidge/backend/OperatorImpl.hpp" @@ -59,7 +60,11 @@ void init_GenericOperator(py::module& m) { throw py::key_error("Failed to convert parameter type " + key + ", this issue may come from typeid function which gave an unknown key : [" + paramType + "]. Please open an issue asking to add the support for this key."); } return res; - }); + }) + .def_readonly_static("identity", &GenericOperator_Op::Identity) + .def("compute_output_dims", &GenericOperator_Op::computeOutputDims) + .def("set_compute_output_dims", &GenericOperator_Op::setComputeOutputDims, py::arg("computation_function")) + ; m.def("GenericOperator", &GenericOperator, py::arg("type"), py::arg("nbDataIn"), py::arg("nbIn"), py::arg("nbOut"), py::arg("name") = ""); -- GitLab