diff --git a/python_binding/operator/pybind_OperatorTensor.cpp b/python_binding/operator/pybind_OperatorTensor.cpp index ce34dea158e6df1466db415b2539962c2113d42b..386a3af6c7c6e9dfad34ec2e56189a53797b59d9 100644 --- a/python_binding/operator/pybind_OperatorTensor.cpp +++ b/python_binding/operator/pybind_OperatorTensor.cpp @@ -21,6 +21,9 @@ void init_OperatorTensor(py::module& m){ py::class_<OperatorTensor, std::shared_ptr<OperatorTensor>, Operator>(m, "OperatorTensor") .def("get_output", &OperatorTensor::getOutput, py::arg("outputIdx")) .def("get_input", &OperatorTensor::getInput, py::arg("inputIdx")) + + .def("set_output", (void (OperatorTensor::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &OperatorTensor::setOutput, py::arg("outputIdx"), py::arg("data")) + .def("set_input", (void (OperatorTensor::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &OperatorTensor::setInput, py::arg("outputIdx"), py::arg("data")) .def("output_dims_forwarded", &OperatorTensor::outputDimsForwarded) ; }