diff --git a/aidge_core/unit_tests/test_operator_binding.py b/aidge_core/unit_tests/test_operator_binding.py index b326e0748c2c77612dd79122fe891a6207d945dc..8898bc5a7ac6ce771cab8402933d464c1f04316f 100644 --- a/aidge_core/unit_tests/test_operator_binding.py +++ b/aidge_core/unit_tests/test_operator_binding.py @@ -61,5 +61,15 @@ class test_operator_binding(unittest.TestCase): self.generic_operator.add_parameter("l_str", ["ok"]) self.assertEqual(self.generic_operator.get_parameter("l_str"), ["ok"]) + def test_compute_output_dims(self): + in_dims=[25, 25] + input = aidge_core.Producer(in_dims, name="In") + genOp = aidge_core.GenericOperator("genOp", 1, 1, 1, name="genOp") + _ = aidge_core.sequential([input, genOp]) + self.assertListEqual(genOp.get_operator().output(0).dims(), []) + genOp.get_operator().set_compute_output_dims(lambda x:x) + genOp.get_operator().compute_output_dims() + self.assertListEqual(genOp.get_operator().output(0).dims(), in_dims) + if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() diff --git a/python_binding/operator/pybind_GenericOperator.cpp b/python_binding/operator/pybind_GenericOperator.cpp index bec59eaf2cecdc7f64d1da07580116c4b3334992..dfd2cfedec5aa291f11cf7c2a93d750c3d91145f 100644 --- a/python_binding/operator/pybind_GenericOperator.cpp +++ b/python_binding/operator/pybind_GenericOperator.cpp @@ -11,6 +11,7 @@ #include <pybind11/pybind11.h> #include <pybind11/stl.h> +#include <pybind11/functional.h> #include <stdio.h> #include "aidge/backend/OperatorImpl.hpp" @@ -59,7 +60,11 @@ void init_GenericOperator(py::module& m) { throw py::key_error("Failed to convert parameter type " + key + ", this issue may come from typeid function which gave an unknown key : [" + paramType + "]. Please open an issue asking to add the support for this key."); } return res; - }); + }) + .def_readonly_static("identity", &GenericOperator_Op::Identity) + .def("compute_output_dims", &GenericOperator_Op::computeOutputDims) + .def("set_compute_output_dims", &GenericOperator_Op::setComputeOutputDims, py::arg("computation_function")) + ; m.def("GenericOperator", &GenericOperator, py::arg("type"), py::arg("nbDataIn"), py::arg("nbIn"), py::arg("nbOut"), py::arg("name") = "");