Skip to content
Snippets Groups Projects
Commit aeb8d162 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed Identity to not require forwardDims() and removed associateInput() from forwardDims()

parent b3a28bdf
No related branches found
No related tags found
3 merge requests!1190.2.1,!113Draft: Fix slice,!104Make forwardDims() optional and handle data dependency
Pipeline #43274 failed
...@@ -78,29 +78,10 @@ public: ...@@ -78,29 +78,10 @@ public:
} }
void forward() override final { runHooks(); } void forward() override final;
void backward() override final { } void backward() override final { }
void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final {
AIDGE_ASSERT(data->type() == "Tensor", "{} Operator only accepts Tensors as outputs", type());
AIDGE_ASSERT(outputIdx < nbInputs(), "{} Operator has {} outputs", type(), nbInputs());
*mInputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
}
void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final {
AIDGE_ASSERT(data->type() == "Tensor", "{} Operator only accepts Tensors as inputs", type());
AIDGE_ASSERT(outputIdx < nbInputs(), "{} Operator has {} outputs", type(), nbInputs());
*mInputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
}
const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const override final {
AIDGE_ASSERT(outputIdx < nbInputs(), "{} Operator has {} outputs", type(), nbInputs());
if (mInputs[outputIdx] == nullptr){
return mOutputs[outputIdx]; // Input is not initialized with empty tensor
}
return mInputs[outputIdx]; // Identity, so Output is Input
}
void setBackend(const std::string& /*name*/, DeviceIdx_t /*device*/ = 0) override final { void setBackend(const std::string& /*name*/, DeviceIdx_t /*device*/ = 0) override final {
// setBackend do nothing, Identity node has no backend it just pass the same Tensor // setBackend do nothing, Identity node has no backend it just pass the same Tensor
} }
......
...@@ -406,19 +406,14 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ ...@@ -406,19 +406,14 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
// Ensure every node in the graph is correctly connected // Ensure every node in the graph is correctly connected
for (std::shared_ptr<Node> nodePtr : getNodes()) { for (std::shared_ptr<Node> nodePtr : getNodes()) {
for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) { for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
// assess if the input was not already set and is a Tensor then link it to parent output
std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i); std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i);
if (inputI.first) { if (inputI.first) {
if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) { // Check that tensors are properly connected...
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) { AIDGE_ASSERT(nodePtr->getOperator()->getRawInput(i) == inputI.first->getOperator()->getRawOutput(inputI.second),
// assert provided Data is of "Tensor" type "Input#{} for node {} ({}) is not properly connected to output#{} of node {} ({}): Data or Tensor mismatch!",
nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second)); i, nodePtr->name(), nodePtr->type(), inputI.second, inputI.first->name(), inputI.first->type());
}
else {
AIDGE_ASSERT(false, "Non-tensor entries not handled yet, for node {} (of type {}).", nodePtr->name(), nodePtr->type());
}
}
} else { } else {
// Input is missing
AIDGE_ASSERT(nodePtr->getOperator()->getRawInput(i) AIDGE_ASSERT(nodePtr->getOperator()->getRawInput(i)
&& !std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty(), && !std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty(),
"Missing input#{} for node {} ({})", i, nodePtr->name(), nodePtr->type()); "Missing input#{} for node {} ({})", i, nodePtr->name(), nodePtr->type());
......
...@@ -13,4 +13,10 @@ ...@@ -13,4 +13,10 @@
#include "aidge/operator/Identity.hpp" #include "aidge/operator/Identity.hpp"
const std::string Aidge::Identity_Op::Type = "Identity"; const std::string Aidge::Identity_Op::Type = "Identity";
\ No newline at end of file
void Aidge::Identity_Op::forward() {
// Perform a shallow copy
*(mOutputs[0]) = *(mInputs[0]);
runHooks();
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment