diff --git a/python_binding/operator/pybind_OperatorTensor.cpp b/python_binding/operator/pybind_OperatorTensor.cpp index 2602e115d43d805451aa9f0836c8151b2cd4b109..350c0958a478ed699e393e815f01eeac177e92fc 100644 --- a/python_binding/operator/pybind_OperatorTensor.cpp +++ b/python_binding/operator/pybind_OperatorTensor.cpp @@ -33,6 +33,7 @@ void init_OperatorTensor(py::module& m){ .def("set_output", (void (OperatorTensor::*)(const IOIndex_t, const std::shared_ptr<Data>&) const) &OperatorTensor::setOutput, py::arg("outputIdx"), py::arg("data")) .def("set_input", (void (OperatorTensor::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &OperatorTensor::setInput, py::arg("outputIdx"), py::arg("data")) .def("forward_dims", &OperatorTensor::forwardDims, py::arg("allow_data_dependency") = false) + .def("forward_dtype", &OperatorTensor::forwardDType) .def("dims_forwarded", &OperatorTensor::dimsForwarded) ; } diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index e1b803c145fc13bb0caf6df1c839f874341ae72a..b7aa5e707cfbb548a1e26eca49ec6ffb9bac6c99 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -181,6 +181,12 @@ bool Aidge::OperatorTensor::forwardDType(){ // Param input can be different dtype than data input continue; } + if (inputCategory(i) == InputCategory::OptionalData + && !getInput(i)){ + // If OptionalData is not set, skip + continue; + } + if (expectedDType != getInput(i)->dataType()) { Log::info("{} operator's inputs should have the same datatype: expected {} (input #0), given {} (input #{})", type(), expectedDType, getInput(i)->dataType(), i);