From 689a11f5a8b764f466ad6637b10897e1e19503fc Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Fri, 22 Sep 2023 11:37:58 +0000 Subject: [PATCH] [GraphView.replaceWith] Fix issue when replacing a node with parameters. --- src/graph/GraphView.cpp | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index a06410322..9798dfe63 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -197,7 +197,7 @@ void Aidge::GraphView::forwardDims() { { assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty()); } - + } } // Compute dimensions of every node @@ -533,28 +533,24 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) { assert(mNodes.size()>0 && "There must be at least one Node to replace"); bool replacable; - std::shared_ptr<Node> previousInputNode; - std::shared_ptr<Node> newInputNode; - std::shared_ptr<Node> previousOutputNode; + std::shared_ptr<Node> previousInputNode = (*inputNodes().begin()); + std::shared_ptr<Node> previousOutputNode = (*outputNodes().begin()); std::shared_ptr<Node> newOutputNode; - + auto gNew = std::make_shared<GraphView>(); gNew->add(newNodes, false); if (newNodes.empty()) { replacable = (outputNodes().size() == 1) && - (inputNodes().size() == 1) && - ((*outputNodes().begin())->nbOutputs() == 1) && - ((*inputNodes().begin())->nbInputs() == 1); - previousOutputNode = (*outputNodes().begin()); - previousInputNode = (*inputNodes().begin()); + (inputNodes().size() == 1) && + ((*outputNodes().begin())->nbOutputs() == 1) && + ((*inputNodes().begin())->nbDataInputs() == 1); newOutputNode = previousInputNode->input(0).first; } else { - replacable = ((outputNodes().size() == gNew->outputNodes().size()) && - (outputNodes().size() == 1)); - previousOutputNode = (*outputNodes().begin()); newOutputNode = (*gNew->outputNodes().begin()); - replacable = replacable && (previousOutputNode->nbOutputs() == newOutputNode->nbOutputs()); + replacable = (outputNodes().size() == gNew->outputNodes().size()) && + (outputNodes().size() == 1) && + (previousOutputNode->nbOutputs() == newOutputNode->nbOutputs()); } if (replacable) { @@ -673,4 +669,4 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) { mOutputNodes.erase(val); } } -} \ No newline at end of file +} -- GitLab