From 689a11f5a8b764f466ad6637b10897e1e19503fc Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Fri, 22 Sep 2023 11:37:58 +0000
Subject: [PATCH] [GraphView.replaceWith] Fix issue when replacing a node with
 parameters.

---
 src/graph/GraphView.cpp | 26 +++++++++++---------------
 1 file changed, 11 insertions(+), 15 deletions(-)

diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index a06410322..9798dfe63 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -197,7 +197,7 @@ void Aidge::GraphView::forwardDims() {
             {
               assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty());
             }
-            
+
         }
     }
     // Compute dimensions of every node
@@ -533,28 +533,24 @@ bool Aidge::GraphView::replaceWith(std::set<std::shared_ptr<Node>> newNodes) {
   assert(mNodes.size()>0 && "There must be at least one Node to replace");
 
   bool replacable;
-  std::shared_ptr<Node> previousInputNode;
-  std::shared_ptr<Node> newInputNode;
-  std::shared_ptr<Node> previousOutputNode;
+  std::shared_ptr<Node> previousInputNode = (*inputNodes().begin());
+  std::shared_ptr<Node> previousOutputNode = (*outputNodes().begin());
   std::shared_ptr<Node> newOutputNode;
-  
+
   auto gNew = std::make_shared<GraphView>();
   gNew->add(newNodes, false);
 
   if (newNodes.empty()) {
     replacable = (outputNodes().size() == 1) &&
-                      (inputNodes().size() == 1) &&
-                      ((*outputNodes().begin())->nbOutputs() == 1) &&
-                      ((*inputNodes().begin())->nbInputs() == 1);
-    previousOutputNode = (*outputNodes().begin());
-    previousInputNode = (*inputNodes().begin());
+                 (inputNodes().size() == 1) &&
+                 ((*outputNodes().begin())->nbOutputs() == 1) &&
+                 ((*inputNodes().begin())->nbDataInputs() == 1);
     newOutputNode = previousInputNode->input(0).first;
   } else {
-    replacable = ((outputNodes().size() == gNew->outputNodes().size()) &&
-                     (outputNodes().size() == 1));
-    previousOutputNode = (*outputNodes().begin());
     newOutputNode = (*gNew->outputNodes().begin());
-    replacable = replacable && (previousOutputNode->nbOutputs() == newOutputNode->nbOutputs());
+    replacable = (outputNodes().size() == gNew->outputNodes().size()) &&
+                 (outputNodes().size() == 1) &&
+                 (previousOutputNode->nbOutputs() == newOutputNode->nbOutputs());
   }
 
   if (replacable) {
@@ -673,4 +669,4 @@ void Aidge::GraphView::removeOutputNode(const std::string nodeName) {
       mOutputNodes.erase(val);
     }
   }
-}
\ No newline at end of file
+}
-- 
GitLab