From b64f7286e21106b8b2c16abd92cfcad898b6c6da Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Mon, 29 Jul 2024 17:50:19 +0200
Subject: [PATCH] Added the same behavior as forwardDims() for connectInputs()

---
 src/scheduler/Scheduler.cpp | 39 +++++++++++++++++++++----------------
 1 file changed, 22 insertions(+), 17 deletions(-)

diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index d63c93deb..4585e08d5 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -526,23 +526,28 @@ void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Te
     // This version of connect inputs only connects tensor inputs in input data producers.
     auto inputNodes = mGraphView->getOrderedInputs();
 
-    // Assert that the number of input data producers corresponds to the number of data input
-    if (data.size() != inputNodes.size()) {
-        const std::map<std::shared_ptr<Node>, std::string> namePtrTable
-            = mGraphView->getRankedNodesName("{0} ({1}#{3})");
-
-        std::vector<std::pair<std::string, IOIndex_t>> inputNodesName;
-        std::transform(inputNodes.begin(), inputNodes.end(),
-            std::back_inserter(inputNodesName),
-            [&namePtrTable](auto val){ return std::make_pair(namePtrTable.at(val.first), val.second); });
-
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "Provided {} inputs to the scheduler, but graph has {} inputs (required inputs in order: )",
-            data.size(), inputNodes.size(), inputNodesName);
-    }
-
-    for (std::size_t i = 0; i < data.size(); ++i){
-        // TODO : maybe shallow copy instead of deepcopy
-        inputNodes[i].first->getOperator()->setInput(inputNodes[i].second, data[i]);
+    std::size_t i = 0;
+    for (auto& input : inputNodes) {
+        if (i < data.size() && data[i]) {
+            // TODO : maybe shallow copy instead of deepcopy
+            input.first->getOperator()->setInput(input.second, data[i]);
+        }
+        else {
+            const auto& currentTensorPtr =
+                std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator())->getInput(input.second);
+            const bool optional = (input.first->inputCategory(input.second) == InputCategory::OptionalData
+                || input.first->inputCategory(input.second) == InputCategory::OptionalParam);
+
+            if (currentTensorPtr) {
+                Log::debug("connectInputs(): existing tensor dims are {} for graph input#{} for input#{} of node {} (of type {})",
+                    currentTensorPtr->dims(), i, input.second, input.first->name(), input.first->type());
+            }
+            else if (!optional) {
+                Log::warn("connectInputs(): did not specify tensor for mandatory graph input#{} for input#{} of node {} (of type {})",
+                    i, input.second, input.first->name(), input.first->type());
+            }
+        }
+        ++i;
     }
 }
 
-- 
GitLab