Commit b64f7286 authored by Olivier BICHLER

Added the same behavior as forwardDims() for connectInputs()

parent 9191c54e
2 merge requests: !212 Version 0.3.0, !175 Improve dims argument behavior
Pipeline #52774 passed
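
In short: connectInputs() now skips graph inputs for which no tensor is provided (a nullptr entry), reusing any tensor already set on the node and only warning when a mandatory input is left unset, mirroring how forwardDims() treats unspecified dims. A minimal usage sketch, assuming the SequentialScheduler::forward(forwardDims, data) overload from the Aidge API and a two-input graph (runPartialInputs and the header paths are illustrative assumptions, not part of this commit):

    #include <memory>
    #include <vector>

    #include "aidge/data/Tensor.hpp"                    // header paths assumed
    #include "aidge/graph/GraphView.hpp"
    #include "aidge/scheduler/SequentialScheduler.hpp"

    // Hypothetical helper (not part of Aidge): feed only the first input of
    // a two-input graph; the second entry is deliberately left unset.
    void runPartialInputs(std::shared_ptr<Aidge::GraphView> graph,
                          std::shared_ptr<Aidge::Tensor> firstInput) {
        Aidge::SequentialScheduler scheduler(graph);

        // With this commit, a nullptr entry no longer has to be filled in:
        // connectInputs() reuses the tensor already set on that node (debug
        // log) and only warns if the input is mandatory and still unset.
        scheduler.forward(/*forwardDims=*/true, {firstInput, nullptr});
    }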
@@ -526,23 +526,28 @@ void Aidge::Scheduler::connectInputs(const std::vector<std::shared_ptr<Aidge::Tensor>>& data)
     // This version of connectInputs() only connects tensor inputs to the input data producers.
     auto inputNodes = mGraphView->getOrderedInputs();
     // Assert that the number of provided inputs matches the number of graph data inputs
     if (data.size() != inputNodes.size()) {
         const std::map<std::shared_ptr<Node>, std::string> namePtrTable
             = mGraphView->getRankedNodesName("{0} ({1}#{3})");
         std::vector<std::pair<std::string, IOIndex_t>> inputNodesName;
         std::transform(inputNodes.begin(), inputNodes.end(),
             std::back_inserter(inputNodesName),
             [&namePtrTable](auto val){ return std::make_pair(namePtrTable.at(val.first), val.second); });
         AIDGE_THROW_OR_ABORT(std::runtime_error, "Provided {} inputs to the scheduler, but graph has {} inputs (required inputs in order: {})",
             data.size(), inputNodes.size(), inputNodesName);
     }
-    for (std::size_t i = 0; i < data.size(); ++i){
-        // TODO : maybe shallow copy instead of deepcopy
-        inputNodes[i].first->getOperator()->setInput(inputNodes[i].second, data[i]);
+    std::size_t i = 0;
+    for (auto& input : inputNodes) {
+        if (i < data.size() && data[i]) {
+            // TODO : maybe shallow copy instead of deepcopy
+            input.first->getOperator()->setInput(input.second, data[i]);
+        }
+        else {
+            const auto& currentTensorPtr =
+                std::dynamic_pointer_cast<OperatorTensor>(input.first->getOperator())->getInput(input.second);
+            const bool optional = (input.first->inputCategory(input.second) == InputCategory::OptionalData
+                || input.first->inputCategory(input.second) == InputCategory::OptionalParam);
+            if (currentTensorPtr) {
+                Log::debug("connectInputs(): existing tensor dims are {} for graph input#{} (input#{} of node {}, of type {})",
+                    currentTensorPtr->dims(), i, input.second, input.first->name(), input.first->type());
+            }
+            else if (!optional) {
+                Log::warn("connectInputs(): no tensor specified for mandatory graph input#{} (input#{} of node {}, of type {})",
+                    i, input.second, input.first->name(), input.first->type());
+            }
+        }
+        ++i;
     }
 }
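
For comparison, a sketch of the forwardDims() behavior this commit mirrors (assumed from the commit message and the companion MR !175; the exact signature may differ between Aidge versions): an empty dims entry keeps whatever is already set on that graph input, just as a nullptr entry in connectInputs()'s data vector now does.

    // Assumed: GraphView::forwardDims() takes one dims entry per graph
    // input, and an empty entry means "keep the existing dims".
    graph->forwardDims({{1, 3, 224, 224}, {}});  // second input left as-is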