From 2b71cded25b356360527a9b78867f0a5a87f98c3 Mon Sep 17 00:00:00 2001
From: NAUD Maxence <maxence.naud@cea.fr>
Date: Tue, 25 Jun 2024 21:58:12 +0000
Subject: [PATCH] Fix GraphView::replace() if the remove Operator had no input

The previous behaviour was to keep the Tensor of the removed layer. Now it is also removed
---
 include/aidge/operator/Operator.hpp       |  1 +
 include/aidge/operator/OperatorTensor.hpp |  3 ++-
 src/graph/GraphView.cpp                   |  4 ++++
 src/operator/OperatorTensor.cpp           | 10 +++++++---
 4 files changed, 14 insertions(+), 4 deletions(-)

diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index 31aa0f0eb..d09c440d9 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -90,6 +90,7 @@ public:
      * @param data Data to copy.
      */
     virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
+    virtual void resetInput(const IOIndex_t inputIdx) = 0;
 
     /**
      * @brief Set the specified input value by performing a deep copy of the given data.
diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp
index 1197adb9c..2737a6d93 100644
--- a/include/aidge/operator/OperatorTensor.hpp
+++ b/include/aidge/operator/OperatorTensor.hpp
@@ -51,6 +51,7 @@ public:
     ///////////////////////////////////////////////////
     virtual void associateInput(const IOIndex_t inputIdx,
                                 const std::shared_ptr<Data>& data) override;
+    void resetInput(const IOIndex_t inputIdx) override final;
     ///////////////////////////////////////////////////
 
     ///////////////////////////////////////////////////
@@ -84,7 +85,7 @@ public:
 
     virtual void setDataType(const DataType& dataType) const override;
     virtual void setDataFormat(const DataFormat& dataFormat) const override;
-    
+
     virtual void forward() override;
 };
 }  // namespace Aidge
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index 1581ac843..33d31636c 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -1052,6 +1052,10 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
                       for (const auto& child : outputChildren[i]) {
                         inputParents[i].first -> addChild(child.first, inputParents[i].second, child.second);
                       }
+                    } else {
+                      for (const auto& child : outputChildren[i]) {
+                        child.first->getOperator()->resetInput(child.second);
+                      }
                     }
                 }
             }
diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp
index a05155085..5da450311 100644
--- a/src/operator/OperatorTensor.cpp
+++ b/src/operator/OperatorTensor.cpp
@@ -9,7 +9,6 @@
  *
  ********************************************************************************/
 
-#include <cassert>
 #include <memory>
 
 #include "aidge/operator/OperatorTensor.hpp"
@@ -51,6 +50,11 @@ void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, cons
     mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
 }
 
+void Aidge::OperatorTensor::resetInput(const Aidge::IOIndex_t inputIdx) {
+    AIDGE_ASSERT(inputIdx < nbInputs(), "Input idx out of range.");
+    mInputs[inputIdx] = nullptr;
+}
+
 void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) {
     AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type());
     if (getInput(inputIdx)) {
@@ -160,8 +164,8 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
     // TODO: Fix -> if there is no parameter input connected (e.g optional bias) then this function will fail.
     // This behaviour should be decided in its own dedicated issue.
     for (IOIndex_t i = nbData(); i < nbInputs(); ++i) {
-        AIDGE_ASSERT(getInput(i) != nullptr, "Missing input#{} for operator {}", i, type());
-        getInput(i)->setDataType(dataType);
+        if (getInput(i))
+            getInput(i)->setDataType(dataType);
     }
 }
 
-- 
GitLab