diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index d63c93deb1ba2d7974ffc6e5b8ccd1e9c57dc76c..4585e08d5ca3a2c37e9d8911cea9b1f25c3720b6 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -526,23 +526,28 @@ void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Te
     // This version of connectInputs() only connects tensor inputs to the input data producers.
     auto inputNodes = mGraphView->getOrderedInputs();
 
-    // Assert that the number of input data producers corresponds to the number of data inputs
-    if (data.size() != inputNodes.size()) {
-        const std::map<std::shared_ptr<Node>, std::string> namePtrTable
-            = mGraphView->getRankedNodesName("{0} ({1}#{3})");
-
-        std::vector<std::pair<std::string, IOIndex_t>> inputNodesName;
-        std::transform(inputNodes.begin(), inputNodes.end(),
-            std::back_inserter(inputNodesName),
-            [&namePtrTable](auto val){ return std::make_pair(namePtrTable.at(val.first), val.second); });
-
-        AIDGE_THROW_OR_ABORT(std::runtime_error, "Provided {} inputs to the scheduler, but graph has {} inputs (required inputs in order: )",
-            data.size(), inputNodes.size(), inputNodesName);
-    }
-
-    for (std::size_t i = 0; i < data.size(); ++i){
-        // TODO : maybe shallow copy instead of deep copy
-        inputNodes[i].first->getOperator()->setInput(inputNodes[i].second, data[i]);
+    std::size_t i = 0;
+    for (auto& input : inputNodes) {
+        if (i < data.size() && data[i]) {
+            // TODO : maybe shallow copy instead of deep copy
+            input.first->getOperator()->setInput(input.second, data[i]);
+        }
+        else {
+            const auto& currentTensorPtr =
+                std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator())->getInput(input.second);
+            const bool optional = (input.first->inputCategory(input.second) == InputCategory::OptionalData
+                || input.first->inputCategory(input.second) == InputCategory::OptionalParam);
+
+            if (currentTensorPtr) {
+                Log::debug("connectInputs(): existing tensor dims are {} for graph input#{} for input#{} of node {} (of type {})",
+                    currentTensorPtr->dims(), i, input.second, input.first->name(), input.first->type());
+            }
+            else if (!optional) {
+                Log::warn("connectInputs(): no tensor specified for mandatory graph input#{} for input#{} of node {} (of type {})",
+                    i, input.second, input.first->name(), input.first->type());
+            }
+        }
+        ++i;
     }
 }
 
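
For context, a minimal sketch of what this change means for callers. This is not part of the patch: it assumes `connectInputs()` is reached through `SequentialScheduler::forward()` as elsewhere in the codebase, and `buildGraph()` is a hypothetical helper producing a graph whose second ordered input is optional (`InputCategory::OptionalData`).

```cpp
#include <memory>
#include <vector>

#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"

// Hypothetical helper (not part of this patch): builds a graph whose ordered
// input #1 is declared as InputCategory::OptionalData.
std::shared_ptr<Aidge::GraphView> buildGraph();

void runWithPartialInputs() {
    const auto graph = buildGraph();
    const auto input0 = std::make_shared<Aidge::Tensor>();

    Aidge::SequentialScheduler scheduler(graph);

    // Before this patch, providing fewer tensors than graph inputs aborted
    // with "Provided 1 inputs to the scheduler, but graph has 2 inputs".
    // With the patch, input #0 is connected; the missing optional input #1
    // either keeps its pre-existing tensor (reported via Log::debug) or, if
    // it were mandatory and unset, would only trigger Log::warn.
    scheduler.forward(/*forwardDims=*/true, {input0});

    // A null entry falls into the same fallback branch as a missing one, so
    // an input could also be skipped explicitly while providing later ones,
    // e.g. scheduler.forward(true, {nullptr, input1});
}
```

Under this reading of the patch, tensors already connected to the graph are left in place rather than overwritten, so repeated `forward()` calls can re-feed only a subset of the inputs.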