From eb11c28421659cea472a5bc8072ec4cffc46f490 Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Mon, 15 Jan 2024 14:55:42 +0000 Subject: [PATCH] Update forwardDims function to take dimensions of input tensors --- include/aidge/graph/GraphView.hpp | 2 +- src/graph/GraphView.cpp | 10 +++++++++- 2 files changed, 10 insertions(+), 2 deletions(-) diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 813301a14..392fb59e6 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -209,7 +209,7 @@ public: * @brief Compute dimensions of input/output Tensors for each Operator of the * GraphView object's Nodes. */ - void forwardDims(); + void forwardDims(const std::vector<std::vector<DimSize_t>> dims = {}); /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ void setBackend(const std::string &backend, DeviceIdx_t device = 0); diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 968e98e75..a93d9af8a 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -265,10 +265,18 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType forwardDims(); } -void Aidge::GraphView::forwardDims() { +void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>> dims) { // setInputs // Link every tensor to the right pointer // following parent - children informations + if (!dims.empty()){ + AIDGE_ASSERT(dims.size() == mInputNodes.size(), "GraphView forwardDims error - Inconsistent number of dimensions and graph inputs"); + for (std::size_t i = 0; i < dims.size(); ++i){ + auto tensor = std::make_shared<Tensor>(dims[i]); + mInputNodes[i].first->getOperator()->setInput(mInputNodes[i].second, tensor); + } + } + for (std::shared_ptr<Node> nodePtr : getNodes()) { for (IOIndex_t i = 0; i < nodePtr->nbInputs(); ++i) { // assess if the input was not already set and is a Tensor then link it to parent output -- GitLab