Commit 2a2046a7 authored by Olivier BICHLER

Merge branch 'Fix-forwardDims-processing-order' into 'dev'

Fix forwardDims() processing order

See merge request eclipse/aidge/aidge_core!132
parents de1a11b9 261d904b
@@ -208,7 +208,12 @@ public:
/**
* @brief Compute dimensions of input/output Tensors for each Operator of the
* GraphView object's Nodes.
* GraphView object's Nodes, by calling OperatorTensor::forwardDims() on each of them.
* This function ensures the following:
* - Every node's forwardDims() is called, regardless of whether its dims were previously forwarded;
* - forwardDims() calls are made in node dependency order, because if dims have changed
* at any point in the graph, the change must be propagated correctly to all succeeding nodes;
* - Cyclic dependencies (currently only induced by the Memorize_Op) are handled correctly.
*/
bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false);
......
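For context, here is how this entry point might be called; a minimal sketch assuming an already-built GraphView (the variable names and input dimensions are illustrative, only forwardDims() itself comes from this diff):

    // graph is assumed to be an already-built Aidge::GraphView.
    std::shared_ptr<Aidge::GraphView> graph = /* ... */;
    // Seed the graph inputs with explicit dimensions (one vector per input),
    // then let forwardDims() propagate them in dependency order.
    const bool ok = graph->forwardDims({{1, 3, 224, 224}}, /*allowDataDependency=*/false);
    if (!ok) {
        // Circular dependency, wrong dimensions, or a data-dependent shape
        // that would require allowDataDependency = true.
    }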
@@ -72,7 +72,6 @@ public:
void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final;
bool forwardDims(bool allowDataDependency = false) override final {
// Check first that all required inputs are available, otherwise
......
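For illustration, a concrete override of this method for a hypothetical element-wise operator could look like the sketch below; it only demonstrates the contract (return false until the inputs are ready), and the empty()/dims()/resize() accessors are assumed from Aidge's Tensor API rather than taken from this diff:

    bool forwardDims(bool allowDataDependency = false) override final {
        // Check first that all required inputs are available; report failure
        // so that GraphView::forwardDims() can retry on a later iteration.
        if (!getInput(0) || getInput(0)->empty()) {
            return false;
        }
        // An element-wise operator's output simply adopts the input dims.
        getOutput(0)->resize(getInput(0)->dims());
        return true;
    }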
@@ -87,7 +87,6 @@ public:
* @param data Data to copy.
*/
virtual void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0;
virtual void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) = 0;
virtual std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const = 0;
/**
* @brief Set the specified output value by performing a deep copy of the given data.
@@ -95,7 +94,6 @@ public:
* @param outputIdx Index of the output to set.
*/
virtual void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) = 0;
virtual void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) = 0;
virtual std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const = 0;
std::shared_ptr<Hook> getHook(const std::string& hookName) {
......
@@ -57,13 +57,11 @@ public:
// Tensor access
// input management
void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override;
void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override;
const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const;
std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final;
// output management
void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override;
void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override;
virtual const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const;
std::shared_ptr<Aidge::Data> getRawOutput(const Aidge::IOIndex_t outputIdx) const override final;
///////////////////////////////////////////////////
......
@@ -107,12 +107,6 @@ public:
void backward() override final {
// fmt::print("Basic Producer backward() function.\n");
}
void setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) override {
if (getAttr<ProdAttr::Constant>()) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output.");
}
OperatorTensor::setOutput(outputIdx, std::move(data));
}
void setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) override {
if (getAttr<ProdAttr::Constant>()) {
......
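The const-reference overload kept above performs the same guard, so a constant Producer still rejects output updates. A usage sketch, assuming the usual Producer factory helper (its exact signature does not appear in this diff):

    // Build a Producer flagged as constant (assumed factory signature).
    auto weights = Aidge::Producer(std::make_shared<Aidge::Tensor>(), "weights", /*constant=*/true);
    // Throws or aborts with "Producer is constant, cannot update output."
    weights->getOperator()->setOutput(0, std::make_shared<Aidge::Tensor>());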
@@ -31,6 +31,7 @@
#include "aidge/operator/MetaOperator.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/utils/Directories.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
@@ -425,22 +426,68 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
}
}
// Compute dimensions of every node
std::set<std::shared_ptr<Node>> listNodes = getNodes();
// List of nodes that are already dims forwarded
std::set<std::shared_ptr<Node>> dimsForwarded;
// Establish the initial list of dims-forwardable nodes:
// input nodes and children of Producers
std::set<std::shared_ptr<Node>> listNodes = inputNodes();
for (const auto& nodePtr : getNodes()) {
if (nodePtr->type() == Producer_Op::Type) {
// Producers are already dims forwarded!
dimsForwarded.insert(nodePtr);
// Producer children are dims forwardable
for (const auto& child : nodePtr->getChildren()) {
if (inView(child)) {
listNodes.insert(child);
}
}
}
}
do {
std::set<std::shared_ptr<Node>> nextList;
for (std::shared_ptr<Node> nodePtr : listNodes) {
for (const auto& nodePtr : listNodes) {
if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
// Recompute every time, even if it was already computed in a
// previous call to forwardDims(), as the graph may have changed!
op->forwardDims(allowDataDependency);
if (!op->dimsForwarded()) {
nextList.insert(nodePtr);
}
const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
bool anyParent = false;
bool parentsForwarded = true;
for (const auto& parent : nodePtr->getParents()) {
if (parent != nullptr && inView(parent) && dimsForwarded.find(parent) == dimsForwarded.end()) {
Log::debug("Dimensions not forwarded for parent (node {} (of type {})) of node {} (of type {})",
parent->name(), parent->type(), nodePtr->name(), nodePtr->type());
parentsForwarded = false;
}
else {
anyParent = true;
}
}
// Special rule for Memorize_Op, which only requires one parent
// to have its dims forwarded. This avoids a circular dependency.
if (nodePtr->type() == Memorize_Op::Type && anyParent) {
parentsForwarded = true;
}
if (parentsForwarded && op->forwardDims(allowDataDependency)) {
// Recompute every time, even if it was already computed in a
// previous call to forwardDims(), as the graph may have changed!
dimsForwarded.insert(nodePtr);
for (const auto& child : nodePtr->getChildren()) {
if (inView(child) && dimsForwarded.find(child) == dimsForwarded.end()) {
nextList.insert(child);
}
}
}
else {
Log::debug("Unable to forward dimensions for node {} (of type {}) yet", nodePtr->name(), nodePtr->type());
nextList.insert(nodePtr);
}
}
}
Log::debug("********************");
// Internal check to make sure we won't enter an infinite loop!
if (nextList == listNodes) {
// We are stuck!
@@ -452,7 +499,6 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
Log::warn("Unable to forward dimensions (circular dependency and/or wrong dimensions and/or data-dependent dimension?). Unable to compute output dims for nodes {}.", nodesName);
return false;
}
listNodes.swap(nextList);
}
while (!listNodes.empty());
......
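To see why the Memorize_Op rule above is needed, consider the recurrent pattern it targets; a hedged sketch, assuming the usual Aidge factory helpers and Node::addChild() (none of which appear in this diff):

    // Hypothetical loop: Add feeds Memorize, whose output feeds back into Add.
    auto init = Aidge::Producer(std::make_shared<Aidge::Tensor>(), "init");
    auto add  = Aidge::Add(2, "add");                  // assumed binary-op factory
    auto mem  = Aidge::Memorize(/*endStep=*/3, "mem"); // assumed factory

    add->addChild(mem, 0, 0);   // Add -> Memorize (data input)
    init->addChild(mem, 0, 1);  // Producer -> Memorize (initial state)
    mem->addChild(add, 0, 1);   // Memorize -> Add: the cycle

    // Under a strict all-parents rule, add would wait on mem and mem on add,
    // forever. The single-parent exception lets mem forward its dims as soon
    // as init is done, which unblocks add on the next iteration of the loop.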
@@ -58,16 +58,6 @@ void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, const std
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second));
}
void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Data>&& data) {
AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type());
const auto& inputOp = mGraph->getOrderedInputs()[inputIdx];
inputOp.first->getOperator()->setInput(inputOp.second, std::forward<std::shared_ptr<Data>>(data));
// Associate inputs for custom implementation
mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second));
}
Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const {
if (mImpl) {
return mImpl->getNbRequiredData(inputIdx);
......
@@ -62,15 +62,6 @@ void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, const std:
Aidge::OperatorTensor::~OperatorTensor() = default;
void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Aidge::Data>&& data) {
AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type());
if (getInput(inputIdx)) {
*mInputs[inputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data));
} else {
mInputs[inputIdx] = std::make_shared<Tensor>(std::move(*std::dynamic_pointer_cast<Tensor>(data)));
}
}
std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawInput(const Aidge::IOIndex_t inputIdx) const {
return std::static_pointer_cast<Data>(getInput(inputIdx));
}
@@ -88,15 +79,6 @@ void Aidge::OperatorTensor::setOutput(const Aidge::IOIndex_t outputIdx, const st
*mOutputs[outputIdx] = *data_tensor;
}
void Aidge::OperatorTensor::setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) {
AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type());
AIDGE_ASSERT(outputIdx < nbOutputs(), "{} Operator has {} outputs", type(), nbOutputs());
auto&& data_tensor = std::dynamic_pointer_cast<Tensor>(data);
// if (mImpl)
// AIDGE_ASSERT(data_tensor->getImpl()->backend() == backend(), "Data parameter and Operator have different backends: {} and {}", data_tensor->getImpl()->backend(), backend());
*mOutputs[outputIdx] = std::move(*data_tensor);
}
std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawOutput(const Aidge::IOIndex_t outputIdx) const {
return std::static_pointer_cast<Data>(getOutput(outputIdx));
}
......
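With the move overloads gone, only the copy (const-reference) versions of setInput()/setOutput() remain: the given tensor is deep-copied and stays valid at the call site. A minimal sketch, assuming op is a shared_ptr to any OperatorTensor:

    auto t = std::make_shared<Aidge::Tensor>();
    op->setInput(0, t);   // deep copy into the operator's input slot
    op->setOutput(0, t);  // deep copy into the operator's output slot
    // t is still valid and unchanged after both calls.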