Skip to content
Snippets Groups Projects
Commit d00e9a7f authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'forwarddims' into 'dev'

Fix issues with forwardDims()

See merge request eclipse/aidge/aidge_core!92
parents 70729d7e f998fb57
No related branches found
No related tags found
No related merge requests found
...@@ -523,7 +523,6 @@ private: ...@@ -523,7 +523,6 @@ private:
// TOPOLOGY // TOPOLOGY
/////////////////////////////////////////////////////// ///////////////////////////////////////////////////////
void _forwardDims(std::set<NodePtr> listNodes);
}; };
/** /**
......
...@@ -328,8 +328,6 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType ...@@ -328,8 +328,6 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType
} }
void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) { void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) {
std::set<NodePtr> startNodes = inputNodes();
// setInputs // setInputs
// Link every tensor to the right pointer // Link every tensor to the right pointer
// following parent - children informations // following parent - children informations
...@@ -340,7 +338,8 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ ...@@ -340,7 +338,8 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
mInputNodes[i].first->getOperator()->setInput(mInputNodes[i].second, tensor); mInputNodes[i].first->getOperator()->setInput(mInputNodes[i].second, tensor);
} }
} }
// Ensure every node in the graph is correctly connected
for (std::shared_ptr<Node> nodePtr : getNodes()) { for (std::shared_ptr<Node> nodePtr : getNodes()) {
for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) { for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
// assess if the input was not already set and is a Tensor then link it to parent output // assess if the input was not already set and is a Tensor then link it to parent output
...@@ -362,60 +361,37 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ ...@@ -362,60 +361,37 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
} }
} }
if (nodePtr->type() == Producer_Op::Type) {
startNodes.insert(nodePtr);
}
} }
// Compute dimensions of every node
_forwardDims(startNodes);
} // Compute dimensions of every node
std::set<std::shared_ptr<Node>> listNodes = getNodes();
void Aidge::GraphView::_forwardDims(std::set<std::shared_ptr<Node>> listNodes) { do {
// TODO: support multi-inputs/outputs std::set<std::shared_ptr<Node>> nextList;
std::set<std::shared_ptr<Node>> nextList = std::set<std::shared_ptr<Node>>(); for (std::shared_ptr<Node> nodePtr : listNodes) {
for (std::shared_ptr<Node> nodePtr : listNodes) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
if (!op->outputDimsForwarded()) {
op->computeOutputDims();
}
if (!op->outputDimsForwarded()) { // try to compute output dimensions again later
nextList.insert(nodePtr);
} else { // compute output dimensions of children
std::set<std::shared_ptr<Node>> children = nodePtr->getChildren();
for (auto child : children) {
const auto childOp = std::static_pointer_cast<OperatorTensor>(child->getOperator());
if (!childOp->outputDimsForwarded()) {
nextList.insert(child);
}
}
}
}
}
if (nextList.empty()) {
for (std::shared_ptr<Node> nodePtr : getNodes()) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) { if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
if (!std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator())->outputDimsForwarded()) { const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
nextList.insert(nodePtr); // Recompute everytime, even if it was already computed in a
} // previous call of forwardDims(), as the graph may have changed!
op->computeOutputDims();
if (!op->outputDimsForwarded()) {
nextList.insert(nodePtr);
}
} }
} }
}
// Internal check to make sure we won't enter in an infinite loop! // Internal check to make sure we won't enter in an infinite loop!
if (nextList == listNodes) { if (nextList == listNodes) {
std::vector<std::string> nodesName; // We are stuck!
std::transform(nextList.begin(), nextList.end(), std::vector<std::string> nodesName;
std::back_inserter(nodesName), std::transform(nextList.begin(), nextList.end(),
[](auto val){ return val->name() + " (" + val->type() + ")"; }); std::back_inserter(nodesName),
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unable to forward dimensions (circular dependency and/or wrong dimensions?). Unable to compute output dims for nodes {}.", nodesName); [](auto val){ return val->name() + " (" + val->type() + ")"; });
} AIDGE_THROW_OR_ABORT(std::runtime_error, "Unable to forward dimensions (circular dependency and/or wrong dimensions?). Unable to compute output dims for nodes {}.", nodesName);
}
if (!nextList.empty()) { listNodes.swap(nextList);
_forwardDims(nextList);
} }
while (!listNodes.empty());
} }
void Aidge::GraphView::setBackend(const std::string &backend, DeviceIdx_t device) { void Aidge::GraphView::setBackend(const std::string &backend, DeviceIdx_t device) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment