diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 8a0c3b275c09ac4c685078201a8d3f1e7d833db7..d7a6e27fb1a739bd8b27411cf21b30bf58e2a3ad 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -427,16 +427,36 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ // Link every tensor to the right pointer // following parent - children informations if (!dims.empty()){ - AIDGE_ASSERT(dims.size() == mInputNodes.size(), "GraphView forwardDims error - Inconsistent number of given dimensions ({}) and graph inputs ({})", dims.size(), mInputNodes.size()); - for (std::size_t i = 0; i < dims.size(); ++i){ + Log::debug("forwardDims(): setting graph input dims ({} dims provided).", dims.size()); + + std::size_t i = 0; + for (auto& input : mInputNodes) { const auto& currentTensorPtr = - std::dynamic_pointer_cast<OperatorTensor>(mInputNodes[i].first->getOperator())->getInput(mInputNodes[i].second); - if (currentTensorPtr) { // tensor detected - AIDGE_ASSERT(currentTensorPtr->dims() == dims[i], "Tensor of unexpected size provided.") - } else { - auto tensor = std::make_shared<Tensor>(dims[i]); - mInputNodes[i].first->getOperator()->setInput(mInputNodes[i].second, tensor); + std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator())->getInput(input.second); + if (i < dims.size() && !dims[i].empty()) { + if (currentTensorPtr) { // tensor detected + AIDGE_ASSERT(currentTensorPtr->dims() == dims[i], + "forwardDims(): mismatch between existing and provided size for graph input#{} (existing size: {}, provided size: {})", + i, currentTensorPtr->dims(), dims[i]) + } else { + auto tensor = std::make_shared<Tensor>(dims[i]); + input.first->getOperator()->setInput(input.second, tensor); + } + } + else { + const bool optional = (input.first->inputCategory(input.second) == InputCategory::OptionalData + || input.first->inputCategory(input.second) == InputCategory::OptionalParam); + + if (currentTensorPtr) { + Log::debug("forwardDims(): existing dims are {} for graph input#{} for input#{} of node {} (of type {})", + i, input.second, input.first->name(), input.first->type(), currentTensorPtr->dims()); + } + else if (!optional) { + Log::warn("forwardDims(): did not specify dims for mandatory graph input#{} for input#{} of node {} (of type {})", + i, input.second, input.first->name(), input.first->type()); + } } + ++i; } } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index d63c93deb1ba2d7974ffc6e5b8ccd1e9c57dc76c..4585e08d5ca3a2c37e9d8911cea9b1f25c3720b6 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -526,23 +526,28 @@ void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Te // This version of connect inputs only connects tensor inputs in input data producers. auto inputNodes = mGraphView->getOrderedInputs(); - // Assert that the number of input data producers corresponds to the number of data input - if (data.size() != inputNodes.size()) { - const std::map<std::shared_ptr<Node>, std::string> namePtrTable - = mGraphView->getRankedNodesName("{0} ({1}#{3})"); - - std::vector<std::pair<std::string, IOIndex_t>> inputNodesName; - std::transform(inputNodes.begin(), inputNodes.end(), - std::back_inserter(inputNodesName), - [&namePtrTable](auto val){ return std::make_pair(namePtrTable.at(val.first), val.second); }); - - AIDGE_THROW_OR_ABORT(std::runtime_error, "Provided {} inputs to the scheduler, but graph has {} inputs (required inputs in order: )", - data.size(), inputNodes.size(), inputNodesName); - } - - for (std::size_t i = 0; i < data.size(); ++i){ - // TODO : maybe shallow copy instead of deepcopy - inputNodes[i].first->getOperator()->setInput(inputNodes[i].second, data[i]); + std::size_t i = 0; + for (auto& input : inputNodes) { + if (i < data.size() && data[i]) { + // TODO : maybe shallow copy instead of deepcopy + input.first->getOperator()->setInput(input.second, data[i]); + } + else { + const auto& currentTensorPtr = + std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator())->getInput(input.second); + const bool optional = (input.first->inputCategory(input.second) == InputCategory::OptionalData + || input.first->inputCategory(input.second) == InputCategory::OptionalParam); + + if (currentTensorPtr) { + Log::debug("connectInputs(): existing tensor dims are {} for graph input#{} for input#{} of node {} (of type {})", + i, input.second, input.first->name(), input.first->type(), currentTensorPtr->dims()); + } + else if (!optional) { + Log::warn("connectInputs(): did not specify tensor for mandatory graph input#{} for input#{} of node {} (of type {})", + i, input.second, input.first->name(), input.first->type()); + } + } + ++i; } }