Skip to content
Snippets Groups Projects
Commit 2b71cded authored by Maxence Naud's avatar Maxence Naud
Browse files

Fix GraphView::replace() if the remove Operator had no input

The previous behaviour was to keep the Tensor of the removed layer. Now it is also removed
parent 58d698bb
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!145Improve UI for Operator/Node/GraphView/Tensor
......@@ -90,6 +90,7 @@ public:
* @param data Data to copy.
*/
virtual void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
virtual void resetInput(const IOIndex_t inputIdx) = 0;
/**
* @brief Set the specified input value by performing a deep copy of the given data.
......
......@@ -51,6 +51,7 @@ public:
///////////////////////////////////////////////////
virtual void associateInput(const IOIndex_t inputIdx,
const std::shared_ptr<Data>& data) override;
void resetInput(const IOIndex_t inputIdx) override final;
///////////////////////////////////////////////////
///////////////////////////////////////////////////
......@@ -84,7 +85,7 @@ public:
virtual void setDataType(const DataType& dataType) const override;
virtual void setDataFormat(const DataFormat& dataFormat) const override;
virtual void forward() override;
};
} // namespace Aidge
......
......@@ -1052,6 +1052,10 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
for (const auto& child : outputChildren[i]) {
inputParents[i].first -> addChild(child.first, inputParents[i].second, child.second);
}
} else {
for (const auto& child : outputChildren[i]) {
child.first->getOperator()->resetInput(child.second);
}
}
}
}
......
......@@ -9,7 +9,6 @@
*
********************************************************************************/
#include <cassert>
#include <memory>
#include "aidge/operator/OperatorTensor.hpp"
......@@ -51,6 +50,11 @@ void Aidge::OperatorTensor::associateInput(const Aidge::IOIndex_t inputIdx, cons
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(data);
}
void Aidge::OperatorTensor::resetInput(const Aidge::IOIndex_t inputIdx) {
AIDGE_ASSERT(inputIdx < nbInputs(), "Input idx out of range.");
mInputs[inputIdx] = nullptr;
}
void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, const std::shared_ptr<Aidge::Data>& data) {
AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type());
if (getInput(inputIdx)) {
......@@ -160,8 +164,8 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
// TODO: Fix -> if there is no parameter input connected (e.g optional bias) then this function will fail.
// This behaviour should be decided in its own dedicated issue.
for (IOIndex_t i = nbData(); i < nbInputs(); ++i) {
AIDGE_ASSERT(getInput(i) != nullptr, "Missing input#{} for operator {}", i, type());
getInput(i)->setDataType(dataType);
if (getInput(i))
getInput(i)->setDataType(dataType);
}
}
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment