Skip to content
Snippets Groups Projects
Commit aeb8d162 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed Identity to not require forwardDims() and removed associateInput() from forwardDims()

parent b3a28bdf
No related branches found
No related tags found
No related merge requests found
......@@ -78,29 +78,10 @@ public:
}
void forward() override final { runHooks(); }
void forward() override final;
void backward() override final { }
void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override final {
AIDGE_ASSERT(data->type() == "Tensor", "{} Operator only accepts Tensors as outputs", type());
AIDGE_ASSERT(outputIdx < nbInputs(), "{} Operator has {} outputs", type(), nbInputs());
*mInputs[outputIdx] = *std::dynamic_pointer_cast<Tensor>(data);
}
void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override final {
AIDGE_ASSERT(data->type() == "Tensor", "{} Operator only accepts Tensors as inputs", type());
AIDGE_ASSERT(outputIdx < nbInputs(), "{} Operator has {} outputs", type(), nbInputs());
*mInputs[outputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
}
const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const override final {
AIDGE_ASSERT(outputIdx < nbInputs(), "{} Operator has {} outputs", type(), nbInputs());
if (mInputs[outputIdx] == nullptr){
return mOutputs[outputIdx]; // Input is not initialized with empty tensor
}
return mInputs[outputIdx]; // Identity, so Output is Input
}
void setBackend(const std::string& /*name*/, DeviceIdx_t /*device*/ = 0) override final {
// setBackend do nothing, Identity node has no backend it just pass the same Tensor
}
......
......@@ -406,19 +406,14 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
// Ensure every node in the graph is correctly connected
for (std::shared_ptr<Node> nodePtr : getNodes()) {
for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) {
// assess if the input was not already set and is a Tensor then link it to parent output
std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i);
if (inputI.first) {
if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
// assert provided Data is of "Tensor" type
nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second));
}
else {
AIDGE_ASSERT(false, "Non-tensor entries not handled yet, for node {} (of type {}).", nodePtr->name(), nodePtr->type());
}
}
// Check that tensors are properly connected...
AIDGE_ASSERT(nodePtr->getOperator()->getRawInput(i) == inputI.first->getOperator()->getRawOutput(inputI.second),
"Input#{} for node {} ({}) is not properly connected to output#{} of node {} ({}): Data or Tensor mismatch!",
i, nodePtr->name(), nodePtr->type(), inputI.second, inputI.first->name(), inputI.first->type());
} else {
// Input is missing
AIDGE_ASSERT(nodePtr->getOperator()->getRawInput(i)
&& !std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty(),
"Missing input#{} for node {} ({})", i, nodePtr->name(), nodePtr->type());
......
......@@ -13,4 +13,10 @@
#include "aidge/operator/Identity.hpp"
const std::string Aidge::Identity_Op::Type = "Identity";
\ No newline at end of file
const std::string Aidge::Identity_Op::Type = "Identity";
void Aidge::Identity_Op::forward() {
// Perform a shallow copy
*(mOutputs[0]) = *(mInputs[0]);
runHooks();
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment