Skip to content
Snippets Groups Projects
Commit 689a11f5 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[GraphView.replaceWith] Fix issue when replacing a node with parameters.

parent ade2edaf
No related branches found
No related tags found
1 merge request!9Fuse bn
This commit is part of merge request !9. Comments created here will be created in the context of that merge request.
...@@ -197,7 +197,7 @@ void Aidge::GraphView::forwardDims() { ...@@ -197,7 +197,7 @@ void Aidge::GraphView::forwardDims() {
{ {
assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty()); assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty());
} }
} }
} }
// Compute dimensions of every node // Compute dimensions of every node
...@@ -533,28 +533,24 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) { ...@@ -533,28 +533,24 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
assert(mNodes.size()>0 && "There must be at least one Node to replace"); assert(mNodes.size()>0 && "There must be at least one Node to replace");
bool replacable; bool replacable;
std::shared_ptr<Node> previousInputNode; std::shared_ptr<Node> previousInputNode = (*inputNodes().begin());
std::shared_ptr<Node> newInputNode; std::shared_ptr<Node> previousOutputNode = (*outputNodes().begin());
std::shared_ptr<Node> previousOutputNode;
std::shared_ptr<Node> newOutputNode; std::shared_ptr<Node> newOutputNode;
auto gNew = std::make_shared<GraphView>(); auto gNew = std::make_shared<GraphView>();
gNew->add(newNodes, false); gNew->add(newNodes, false);
if (newNodes.empty()) { if (newNodes.empty()) {
replacable = (outputNodes().size() == 1) && replacable = (outputNodes().size() == 1) &&
(inputNodes().size() == 1) && (inputNodes().size() == 1) &&
((*outputNodes().begin())->nbOutputs() == 1) && ((*outputNodes().begin())->nbOutputs() == 1) &&
((*inputNodes().begin())->nbInputs() == 1); ((*inputNodes().begin())->nbDataInputs() == 1);
previousOutputNode = (*outputNodes().begin());
previousInputNode = (*inputNodes().begin());
newOutputNode = previousInputNode->input(0).first; newOutputNode = previousInputNode->input(0).first;
} else { } else {
replacable = ((outputNodes().size() == gNew->outputNodes().size()) &&
(outputNodes().size() == 1));
previousOutputNode = (*outputNodes().begin());
newOutputNode = (*gNew->outputNodes().begin()); newOutputNode = (*gNew->outputNodes().begin());
replacable = replacable && (previousOutputNode->nbOutputs() == newOutputNode->nbOutputs()); replacable = (outputNodes().size() == gNew->outputNodes().size()) &&
(outputNodes().size() == 1) &&
(previousOutputNode->nbOutputs() == newOutputNode->nbOutputs());
} }
if (replacable) { if (replacable) {
...@@ -673,4 +669,4 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) { ...@@ -673,4 +669,4 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
mOutputNodes.erase(val); mOutputNodes.erase(val);
} }
} }
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment