Skip to content
Snippets Groups Projects
Commit 689a11f5 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[GraphView.replaceWith] Fix issue when replacing a node with parameters.

parent ade2edaf
No related branches found
No related tags found
No related merge requests found
......@@ -197,7 +197,7 @@ void Aidge::GraphView::forwardDims() {
{
assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty());
}
}
}
// Compute dimensions of every node
......@@ -533,28 +533,24 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
assert(mNodes.size()>0 && "There must be at least one Node to replace");
bool replacable;
std::shared_ptr<Node> previousInputNode;
std::shared_ptr<Node> newInputNode;
std::shared_ptr<Node> previousOutputNode;
std::shared_ptr<Node> previousInputNode = (*inputNodes().begin());
std::shared_ptr<Node> previousOutputNode = (*outputNodes().begin());
std::shared_ptr<Node> newOutputNode;
auto gNew = std::make_shared<GraphView>();
gNew->add(newNodes, false);
if (newNodes.empty()) {
replacable = (outputNodes().size() == 1) &&
(inputNodes().size() == 1) &&
((*outputNodes().begin())->nbOutputs() == 1) &&
((*inputNodes().begin())->nbInputs() == 1);
previousOutputNode = (*outputNodes().begin());
previousInputNode = (*inputNodes().begin());
(inputNodes().size() == 1) &&
((*outputNodes().begin())->nbOutputs() == 1) &&
((*inputNodes().begin())->nbDataInputs() == 1);
newOutputNode = previousInputNode->input(0).first;
} else {
replacable = ((outputNodes().size() == gNew->outputNodes().size()) &&
(outputNodes().size() == 1));
previousOutputNode = (*outputNodes().begin());
newOutputNode = (*gNew->outputNodes().begin());
replacable = replacable && (previousOutputNode->nbOutputs() == newOutputNode->nbOutputs());
replacable = (outputNodes().size() == gNew->outputNodes().size()) &&
(outputNodes().size() == 1) &&
(previousOutputNode->nbOutputs() == newOutputNode->nbOutputs());
}
if (replacable) {
......@@ -673,4 +669,4 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
mOutputNodes.erase(val);
}
}
}
\ No newline at end of file
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment