From aeb8d162fb18099e74c2339474cd6257a511b8dc Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Tue, 9 Apr 2024 10:39:06 +0200 Subject: [PATCH] Fixed Identity to not require forwardDims() and removed associateInput() from forwardDims() --- include/aidge/operator/Identity.hpp | 21 +-------------------- src/graph/GraphView.cpp | 15 +++++---------- src/operator/Identity.cpp | 8 +++++++- 3 files changed, 13 insertions(+), 31 deletions(-) diff --git a/include/aidge/operator/Identity.hpp b/include/aidge/operator/Identity.hpp index 08634d9fa..51c70eae5 100644 --- a/include/aidge/operator/Identity.hpp +++ b/include/aidge/operator/Identity.hpp @@ -78,29 +78,10 @@ public: } - void forward() override final { runHooks(); } + void forward() override final; void backward() override final { } - void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final { - AIDGE_ASSERT(data->type() == "Tensor", "{} Operator only accepts Tensors as outputs", type()); - AIDGE_ASSERT(outputIdx < nbInputs(), "{} Operator has {} outputs", type(), nbInputs()); - *mInputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data); - } - - void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final { - AIDGE_ASSERT(data->type() == "Tensor", "{} Operator only accepts Tensors as inputs", type()); - AIDGE_ASSERT(outputIdx < nbInputs(), "{} Operator has {} outputs", type(), nbInputs()); - *mInputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data)); - } - - const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const override final { - AIDGE_ASSERT(outputIdx < nbInputs(), "{} Operator has {} outputs", type(), nbInputs()); - if (mInputs[outputIdx] == nullptr){ - return mOutputs[outputIdx]; // Input is not initialized with empty tensor - } - return mInputs[outputIdx]; // Identity, so Output is Input - } void setBackend(const std::string& /*name*/, DeviceIdx_t /*device*/ = 0) override final { // setBackend do nothing, Identity node has no backend it just pass the same Tensor } diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 9b53a9d82..88c7383a9 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -406,19 +406,14 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ // Ensure every node in the graph is correctly connected for (std::shared_ptr<Node> nodePtr : getNodes()) { for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) { - // assess if the input was not already set and is a Tensor then link it to parent output std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i); if (inputI.first) { - if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) { - if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) { - // assert provided Data is of "Tensor" type - nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second)); - } - else { - AIDGE_ASSERT(false, "Non-tensor entries not handled yet, for node {} (of type {}).", nodePtr->name(), nodePtr->type()); - } - } + // Check that tensors are properly connected... + AIDGE_ASSERT(nodePtr->getOperator()->getRawInput(i) == inputI.first->getOperator()->getRawOutput(inputI.second), + "Input#{} for node {} ({}) is not properly connected to output#{} of node {} ({}): Data or Tensor mismatch!", + i, nodePtr->name(), nodePtr->type(), inputI.second, inputI.first->name(), inputI.first->type()); } else { + // Input is missing AIDGE_ASSERT(nodePtr->getOperator()->getRawInput(i) && !std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty(), "Missing input#{} for node {} ({})", i, nodePtr->name(), nodePtr->type()); diff --git a/src/operator/Identity.cpp b/src/operator/Identity.cpp index f57906dd4..2b8107bfc 100644 --- a/src/operator/Identity.cpp +++ b/src/operator/Identity.cpp @@ -13,4 +13,10 @@ #include "aidge/operator/Identity.hpp" -const std::string Aidge::Identity_Op::Type = "Identity"; \ No newline at end of file +const std::string Aidge::Identity_Op::Type = "Identity"; + +void Aidge::Identity_Op::forward() { + // Perform a shallow copy + *(mOutputs[0]) = *(mInputs[0]); + runHooks(); +} -- GitLab