From 2b71cded25b356360527a9b78867f0a5a87f98c3 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 25 Jun 2024 21:58:12 +0000 Subject: [PATCH] Fix GraphView::replace() if the remove Operator had no input The previous behaviour was to keep the Tensor of the removed layer. Now it is also removed --- include/aidge/operator/Operator.hpp | 1 + include/aidge/operator/OperatorTensor.hpp | 3 ++- src/graph/GraphView.cpp | 4 ++++ src/operator/OperatorTensor.cpp | 10 +++++++--- 4 files changed, 14 insertions(+), 4 deletions(-) diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 31aa0f0eb..d09c440d9 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -90,6 +90,7 @@ public: * @param data Data to copy. */ virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0; + virtual void resetInput(const IOIndex_t inputIdx) = 0; /** * @brief Set the specified input value by performing a deep copy of the given data. diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index 1197adb9c..2737a6d93 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -51,6 +51,7 @@ public: /////////////////////////////////////////////////// virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override; + void resetInput(const IOIndex_t inputIdx) override final; /////////////////////////////////////////////////// /////////////////////////////////////////////////// @@ -84,7 +85,7 @@ public: virtual void setDataType(const DataType& dataType) const override; virtual void setDataFormat(const DataFormat& dataFormat) const override; - + virtual void forward() override; }; } // namespace Aidge diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 1581ac843..33d31636c 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -1052,6 +1052,10 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const for (const auto& child : outputChildren[i]) { inputParents[i].first -> addChild(child.first, inputParents[i].second, child.second); } + } else { + for (const auto& child : outputChildren[i]) { + child.first->getOperator()->resetInput(child.second); + } } } } diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index a05155085..5da450311 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -9,7 +9,6 @@ * ********************************************************************************/ -#include <cassert> #include <memory> #include "aidge/operator/OperatorTensor.hpp" @@ -51,6 +50,11 @@ void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, cons mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data); } +void Aidge::OperatorTensor::resetInput(const Aidge::IOIndex_t inputIdx) { + AIDGE_ASSERT(inputIdx < nbInputs(), "Input idx out of range."); + mInputs[inputIdx] = nullptr; +} + void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) { AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type()); if (getInput(inputIdx)) { @@ -160,8 +164,8 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const { // TODO: Fix -> if there is no parameter input connected (e.g optional bias) then this function will fail. // This behaviour should be decided in its own dedicated issue. for (IOIndex_t i = nbData(); i < nbInputs(); ++i) { - AIDGE_ASSERT(getInput(i) != nullptr, "Missing input#{} for operator {}", i, type()); - getInput(i)->setDataType(dataType); + if (getInput(i)) + getInput(i)->setDataType(dataType); } } -- GitLab