diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index c9a4c11d780a41a1620518047d66a7de2d7b55fa..627e78790020c04d50f839f01de2130ba8d8d774 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -208,7 +208,12 @@ public: /** * @brief Compute dimensions of input/output Tensors for each Operator of the - * GraphView object's Nodes. + * GraphView object's Nodes, by calling Node::forwardDims(). + * This function verifies the following conditions: + * - Every node will forwardDims() regardless of if dims were previously forwarded or not; + * - forwadDims() calls are made in node dependencies order, because if dims have changed + * at any point in the graph, it must de propagated correctly to all succeeding nodes; + * - It handles cyclic dependencies correctly (currently only induced by the Memorize_Op). */ bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false); diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 75c4d886ff761b93dc7b58b98a6ab72453ed804b..fb8c73af33dd081664c82427ea8aa6876117d695 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -72,7 +72,6 @@ public: void associateInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override final; - void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override final; bool forwardDims(bool allowDataDependency = false) override final { // Check first that all required inputs are available, otherwise diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 3ee2342297208f6f4e4b061409bc5071c811d2ac..09172f9d59d417132da7577fdec148e882e3d613 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -87,7 +87,6 @@ public: * @param data Data to copy. */ virtual void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) = 0; - virtual void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) = 0; virtual std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const = 0; /** * @brief Set the specified output value by performing a deep copy of the given data. @@ -95,7 +94,6 @@ public: * @param inputIdx Index of the input to set. */ virtual void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) = 0; - virtual void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) = 0; virtual std::shared_ptr<Data> getRawOutput(const IOIndex_t outputIdx) const = 0; std::shared_ptr<Hook> getHook(const std::string& hookName) { diff --git a/include/aidge/operator/OperatorTensor.hpp b/include/aidge/operator/OperatorTensor.hpp index a493793278d42904d8a62e31571720f94ff1655d..f2a59dda743af52647ad650aae516ef07ba89ac4 100644 --- a/include/aidge/operator/OperatorTensor.hpp +++ b/include/aidge/operator/OperatorTensor.hpp @@ -57,13 +57,11 @@ public: // Tensor access // input management void setInput(const IOIndex_t inputIdx, const std::shared_ptr<Data>& data) override; - void setInput(const IOIndex_t inputIdx, std::shared_ptr<Data>&& data) override; const std::shared_ptr<Tensor>& getInput(const IOIndex_t inputIdx) const; std::shared_ptr<Data> getRawInput(const IOIndex_t inputIdx) const override final; // output management void setOutput(const IOIndex_t outputIdx, const std::shared_ptr<Data>& data) override; - void setOutput(const IOIndex_t outputIdx, std::shared_ptr<Data>&& data) override; virtual const std::shared_ptr<Tensor>& getOutput(const IOIndex_t outputIdx) const; std::shared_ptr<Aidge::Data> getRawOutput(const Aidge::IOIndex_t outputIdx) const override final; /////////////////////////////////////////////////// diff --git a/include/aidge/operator/Producer.hpp b/include/aidge/operator/Producer.hpp index 23825079673129ea08aa7da40b21a8cc921d6ba0..c376bab3db22b6710a0915f7fcf2f749a60b7b61 100644 --- a/include/aidge/operator/Producer.hpp +++ b/include/aidge/operator/Producer.hpp @@ -107,12 +107,6 @@ public: void backward() override final { // fmt::print("Basic Producer backward() function.\n"); } - void setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) override { - if (getAttr<ProdAttr::Constant>()) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Producer is constant, cannot update output."); - } - OperatorTensor::setOutput(outputIdx, std::move(data)); - } void setOutput(const Aidge::IOIndex_t outputIdx, const std::shared_ptr<Aidge::Data>& data) override { if (getAttr<ProdAttr::Constant>()) { diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 163ea35c716cd6948c998f1b08f9f07d28fe1940..77ca0b00c40e578f45834a16da65ae37ac4b7d3c 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -31,6 +31,7 @@ #include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Producer.hpp" +#include "aidge/operator/Memorize.hpp" #include "aidge/utils/Directories.hpp" #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" @@ -425,22 +426,68 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ } } - // Compute dimensions of every node - std::set<std::shared_ptr<Node>> listNodes = getNodes(); + // List of nodes that are already dims forwarded + std::set<std::shared_ptr<Node>> dimsForwarded; + // Establish initial list of dims forwardable nodes: + // input nodes and childs from Producers + std::set<std::shared_ptr<Node>> listNodes = inputNodes(); + for (const auto& nodePtr : getNodes()) { + if (nodePtr->type() == Producer_Op::Type) { + // Producers are already dims forwarded! + dimsForwarded.insert(nodePtr); + // Producers childs are dims forwardable + for (const auto& child : nodePtr->getChildren()) { + if (inView(child)) { + listNodes.insert(child); + } + } + } + } + do { std::set<std::shared_ptr<Node>> nextList; - for (std::shared_ptr<Node> nodePtr : listNodes) { + for (const auto& nodePtr : listNodes) { if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) { - const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator()); - // Recompute everytime, even if it was already computed in a - // previous call of forwardDims(), as the graph may have changed! - op->forwardDims(allowDataDependency); - if (!op->dimsForwarded()) { - nextList.insert(nodePtr); - } + const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator()); + + bool anyParent = false; + bool parentsForwarded = true; + for (const auto& parent : nodePtr->getParents()) { + if (parent != nullptr && inView(parent) && dimsForwarded.find(parent) == dimsForwarded.end()) { + Log::debug("Dimensions not forwarded for parent (node {} (of type {})) of node {} (of type {})", + parent->name(), parent->type(), nodePtr->name(), nodePtr->type()); + parentsForwarded = false; + } + else { + anyParent = true; + } + } + + // Special rule for Memorize_Op, which only requires one parent + // to have its dims forwarded. This avoids circular dependency. + if (nodePtr->type() == Memorize_Op::Type && anyParent) { + parentsForwarded = true; + } + + if (parentsForwarded && op->forwardDims(allowDataDependency)) { + // Recompute everytime, even if it was already computed in a + // previous call of forwardDims(), as the graph may have changed! + dimsForwarded.insert(nodePtr); + for (const auto& child : nodePtr->getChildren()) { + if (inView(child) && dimsForwarded.find(child) == dimsForwarded.end()) { + nextList.insert(child); + } + } + } + else { + Log::debug("Unable to forward dimensions for node {} (of type {}) yet", nodePtr->name(), nodePtr->type()); + nextList.insert(nodePtr); + } } } + Log::debug("********************"); + // Internal check to make sure we won't enter in an infinite loop! if (nextList == listNodes) { // We are stuck! @@ -452,7 +499,6 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ Log::warn("Unable to forward dimensions (circular dependency and/or wrong dimensions and/or data dependent dimension?). Unable to compute output dims for nodes {}.", nodesName); return false; } - listNodes.swap(nextList); } while (!listNodes.empty()); diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index 36ff1854703d015980a1943390eb87d0863d877f..1397b69b9c126c0e2d0ec84bf900a320b95f0d80 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -58,16 +58,6 @@ void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, const std mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second)); } -void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Data>&& data) { - AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type()); - - const auto& inputOp = mGraph->getOrderedInputs()[inputIdx]; - inputOp.first->getOperator()->setInput(inputOp.second, std::forward<std::shared_ptr<Data>>(data)); - - // Associate inputs for custom implementation - mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second)); -} - Aidge::Elts_t Aidge::MetaOperator_Op::getNbRequiredData(const IOIndex_t inputIdx) const { if (mImpl) { return mImpl->getNbRequiredData(inputIdx); diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp index 25c9deb2adaca65748d7f6981de574d0a674af5d..af20c1ff4ddd71479fcc899f7fe87be1d0000c72 100644 --- a/src/operator/OperatorTensor.cpp +++ b/src/operator/OperatorTensor.cpp @@ -62,15 +62,6 @@ void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, const std: Aidge::OperatorTensor::~OperatorTensor() = default; -void Aidge::OperatorTensor::setInput(const Aidge::IOIndex_t inputIdx, std::shared_ptr<Aidge::Data>&& data) { - AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type()); - if (getInput(inputIdx)) { - *mInputs[inputIdx] = std::move(*std::dynamic_pointer_cast<Tensor>(data)); - } else { - mInputs[inputIdx] = std::make_shared<Tensor>(std::move(*std::dynamic_pointer_cast<Tensor>(data))); - } -} - std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawInput(const Aidge::IOIndex_t inputIdx) const { return std::static_pointer_cast<Data>(getInput(inputIdx)); } @@ -88,15 +79,6 @@ void Aidge::OperatorTensor::setOutput(const Aidge::IOIndex_t outputIdx, const st *mOutputs[outputIdx] = *data_tensor; } -void Aidge::OperatorTensor::setOutput(const Aidge::IOIndex_t outputIdx, std::shared_ptr<Aidge::Data>&& data) { - AIDGE_ASSERT(data->type() == Tensor::Type, "{} Operator only accepts Tensors as inputs", type()); - AIDGE_ASSERT(outputIdx < nbOutputs(), "{} Operator has {} outputs", type(), nbOutputs()); - auto&& data_tensor = std::dynamic_pointer_cast<Tensor>(data); - // if (mImpl) - // AIDGE_ASSERT(data_tensor->getImpl()->backend() == backend(), "Data parameter and Operator have different backends: {} and {}", data_tensor->getImpl()->backend(), backend()); - *mOutputs[outputIdx] = std::move(*data_tensor); -} - std::shared_ptr<Aidge::Data> Aidge::OperatorTensor::getRawOutput(const Aidge::IOIndex_t outputIdx) const { return std::static_pointer_cast<Data>(getOutput(outputIdx)); }