From da06710b62623ed9bf48888140937f451b68f77d Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Thu, 27 Mar 2025 13:43:51 +0000
Subject: [PATCH] Fix OperatorTensor::forwardDims() if OptionalData is not
 connected.

---
 python_binding/operator/pybind_OperatorTensor.cpp | 1 +
 src/operator/OperatorTensor.cpp                   | 6 ++++++
 2 files changed, 7 insertions(+)

diff --git a/python_binding/operator/pybind_OperatorTensor.cpp b/python_binding/operator/pybind_OperatorTensor.cpp
index 2602e115d..350c0958a 100644
--- a/python_binding/operator/pybind_OperatorTensor.cpp
+++ b/python_binding/operator/pybind_OperatorTensor.cpp
@@ -33,6 +33,7 @@ void init_OperatorTensor(py::module& m){
     .def("set_output", (void (OperatorTensor::*)(const IOIndex_t, const std::shared_ptr<Data>&) const) &OperatorTensor::setOutput, py::arg("outputIdx"), py::arg("data"))
     .def("set_input", (void (OperatorTensor::*)(const IOIndex_t, const std::shared_ptr<Data>&)) &OperatorTensor::setInput, py::arg("outputIdx"), py::arg("data"))
     .def("forward_dims", &OperatorTensor::forwardDims, py::arg("allow_data_dependency") = false)
+    .def("forward_dtype", &OperatorTensor::forwardDType)
     .def("dims_forwarded", &OperatorTensor::dimsForwarded)
     ;
 }
diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp
index e1b803c14..b7aa5e707 100644
--- a/src/operator/OperatorTensor.cpp
+++ b/src/operator/OperatorTensor.cpp
@@ -181,6 +181,12 @@ bool Aidge::OperatorTensor::forwardDType(){
                 // Param input can be different dtype than data input
                 continue;
             }
+            if (inputCategory(i) == InputCategory::OptionalData
+            && !getInput(i)){
+                // If OptionalData is not set, skip
+                continue;
+            }
+
             if (expectedDType != getInput(i)->dataType()) {
                 Log::info("{} operator's inputs should have the same datatype: expected {} (input #0), given {} (input #{})",
                     type(), expectedDType, getInput(i)->dataType(), i);
-- 
GitLab