From 114b4ec1200fe5a4b04eef4c91471e42fecdc59a Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Thu, 30 Nov 2023 13:46:29 +0000 Subject: [PATCH] Minor changes --- src/graph/GraphView.cpp | 23 +++++++++++------------ src/scheduler/Scheduler.cpp | 1 + 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 96466cd1a..ce956d115 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -188,18 +188,17 @@ void Aidge::GraphView::forwardDims() { // assess if the input was not already set and is a Tensor then link it to parent output std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i); if (inputI.first) { - if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) { - if ((strcmp(nodePtr->getOperator()->getRawInput(i)->type(), Tensor::Type) == 0) && (strcmp(inputI.first->getOperator()->getRawOutput(inputI.second)->type(), Tensor::Type)==0)) { - // assert provided Data is of "Tensor" type - nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second)); - } - else { - assert(false && "Non-tensor entries not handled yet.\n"); - } - } - } else - { - assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty()); + if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) { + if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) { + // assert provided Data is of "Tensor" type + nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second)); + } + else { + assert(false && "Non-tensor entries not handled yet.\n"); + } + } + } else { + assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty()); } } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 1f34091e5..3afbcd044 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -19,6 +19,7 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" #include "aidge/utils/Types.h" +#include "aidge/operator/OperatorTensor.hpp" void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") { putchar('['); -- GitLab