diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 96466cd1a4b81dae3eec120360055bdf0f8c5844..ce956d115e282c43751619070dd8a10ac5c9cfae 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -188,18 +188,17 @@ void Aidge::GraphView::forwardDims() { // assess if the input was not already set and is a Tensor then link it to parent output std::pair<std::shared_ptr<Node>, IOIndex_t> inputI = nodePtr->input(i); if (inputI.first) { - if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) { - if ((strcmp(nodePtr->getOperator()->getRawInput(i)->type(), Tensor::Type) == 0) && (strcmp(inputI.first->getOperator()->getRawOutput(inputI.second)->type(), Tensor::Type)==0)) { - // assert provided Data is of "Tensor" type - nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second)); - } - else { - assert(false && "Non-tensor entries not handled yet.\n"); - } - } - } else - { - assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty()); + if ( std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i)) != inputI.first->getOperator()->getRawOutput(inputI.second)) { + if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) { + // assert provided Data is of "Tensor" type + nodePtr->getOperator()->associateInput(i, inputI.first->getOperator()->getRawOutput(inputI.second)); + } + else { + assert(false && "Non-tensor entries not handled yet.\n"); + } + } + } else { + assert(!std::static_pointer_cast<Tensor>(nodePtr->getOperator()->getRawInput(i))->empty()); } } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 1f34091e54c0f83dae6b60589c20fb8fdf1d5064..3afbcd0442fd40214687751d50bfc98809bba840 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -19,6 +19,7 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" #include "aidge/utils/Types.h" +#include "aidge/operator/OperatorTensor.hpp" void drawProgressBar(double progress, int barWidth, const std::string& additionalInfo = "") { putchar('[');