diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 813301a144682ba3e99de31ae324ffaedcc5209f..392fb59e65b8b844a091aaa89e7d623986dda85b 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -209,7 +209,7 @@ public: * @brief Compute dimensions of input/output Tensors for each Operator of the * GraphView object's Nodes. */ - void forwardDims(); + void forwardDims(const std::vector<std::vector<DimSize_t>> dims = {}); /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ void setBackend(const std::string &backend, DeviceIdx_t device = 0); diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 968e98e75cc587977eb3033fe7f25936880755a4..a93d9af8a972605b1519e9974971ff9e7ad3ef2f 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -265,10 +265,18 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType forwardDims(); } -void Aidge::GraphView::forwardDims() { +void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) { // setInputs // Link every tensor to the right pointer // following parent - children informations + if (!dims.empty()){ + AIDGE_ASSERT(dims.size() == mInputNodes.size(), "GraphView forwardDims error - Inconsistent number of dimensions and graph inputs"); + for (std::size_t i = 0; i < dims.size(); ++i){ + auto tensor = std::make_shared<Tensor>(dims[i]); + mInputNodes[i].first->getOperator()->setInput(mInputNodes[i].second, tensor); + } + } + for (std::shared_ptr<Node> nodePtr : getNodes()) { for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) { // assess if the input was not already set and is a Tensor then link it to parent output