Skip to content
Snippets Groups Projects
Commit 59fb7f6e authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fix issues with forwardDims()

parent dd09301f
No related branches found
No related tags found
No related merge requests found
......@@ -523,7 +523,6 @@ private:
// TOPOLOGY
///////////////////////////////////////////////////////
void _forwardDims(std::set<NodePtr> listNodes);
};
/**
......
......@@ -328,8 +328,6 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType
}
void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) {
std::set<NodePtr> startNodes = inputNodes();
// setInputs
// Link every tensor to the right pointer
// following parent - children informations
......@@ -340,7 +338,8 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
mInputNodes[i].first->getOperator()->setInput(mInputNodes[i].second, tensor);
}
}
// Ensure every node in the graph is correctly connected
for (std::shared_ptr<Node> nodePtr : getNodes()) {
for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
// assess if the input was not already set and is a Tensor then link it to parent output
......@@ -362,60 +361,37 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
}
}
if (nodePtr->type() == Producer_Op::Type) {
startNodes.insert(nodePtr);
}
}
// Compute dimensions of every node
_forwardDims(startNodes);
}
void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) {
// TODO: support multi-inputs/outputs
std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>();
for (std::shared_ptr<Node> nodePtr : listNodes) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
if (!op->outputDimsForwarded()) {
op->computeOutputDims();
}
if (!op->outputDimsForwarded()) { // try to compute output dimensions again later
nextList.insert(nodePtr);
} else { // compute output dimensions of children
std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
for (auto child : children) {
const auto childOp = std::static_pointer_cast<OperatorTensor>(child->getOperator());
if (!childOp->outputDimsForwarded()) {
nextList.insert(child);
}
}
}
}
}
if (nextList.empty()) {
for (std::shared_ptr<Node> nodePtr : getNodes()) {
// Compute dimensions of every node
std::set<std::shared_ptr<Node>> listNodes = getNodes();
do {
std::set<std::shared_ptr<Node>> nextList;
for (std::shared_ptr<Node> nodePtr : listNodes) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
if (!std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator())->outputDimsForwarded()) {
nextList.insert(nodePtr);
}
const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
// Recompute everytime, even if it was already computed in a
// previous call of forwardDims(), as the graph may have changed!
op->computeOutputDims();
if (!op->outputDimsForwarded()) {
nextList.insert(nodePtr);
}
}
}
}
// Internal check to make sure we won't enter in an infinite loop!
if (nextList == listNodes) {
std::vector<std::string> nodesName;
std::transform(nextList.begin(), nextList.end(),
std::back_inserter(nodesName),
[](auto val){ return val->name() + " (" + val->type() + ")"; });
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unable to forward dimensions (circular dependency and/or wrong dimensions?). Unable to compute output dims for nodes {}.", nodesName);
}
// Internal check to make sure we won't enter in an infinite loop!
if (nextList == listNodes) {
// We are stuck!
std::vector<std::string> nodesName;
std::transform(nextList.begin(), nextList.end(),
std::back_inserter(nodesName),
[](auto val){ return val->name() + " (" + val->type() + ")"; });
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unable to forward dimensions (circular dependency and/or wrong dimensions?). Unable to compute output dims for nodes {}.", nodesName);
}
if (!nextList.empty()) {
_forwardDims(nextList);
listNodes.swap(nextList);
}
while (!listNodes.empty());
}
void Aidge::GraphView::setBackend(const std::string &backend, DeviceIdx_t device) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment