From 0350b632e16fec8cfd63e51dfdbf172002226306 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Sat, 2 Mar 2024 19:12:01 +0100
Subject: [PATCH] Add operator type checking
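
Guard the static_pointer_cast<OperatorTensor>() call sites in ExplicitCastMove,
FuseBatchNorm and the scheduler's generateMemory() with an AIDGE_ASSERT() on
Operator::operatorType(), and replace the dynamic_pointer_cast<OperatorTensor>()
call sites in FuseMulAdd and HorizontalTiling with the same asserted static
cast. The type check is now explicit and fails loudly instead of silently
yielding a null or invalid pointer.

The snippet below is a standalone sketch of the idiom (plain C++ with stand-in
types, not the actual Aidge API), assuming operatorType() reports the concrete
operator kind as it does at the call sites in this patch:

    #include <cassert>
    #include <memory>

    enum class OperatorType { Data, Tensor };

    struct Operator {
        virtual ~Operator() = default;
        virtual OperatorType operatorType() const = 0;
    };

    struct OperatorTensor : Operator {
        OperatorType operatorType() const override { return OperatorType::Tensor; }
    };

    int main() {
        std::shared_ptr<Operator> op = std::make_shared<OperatorTensor>();
        // Check the runtime type once, then use the cheap static cast,
        // instead of paying for dynamic_pointer_cast at every call site.
        assert(op->operatorType() == OperatorType::Tensor && "Operator must be of Tensor type.");
        const auto opTensor = std::static_pointer_cast<OperatorTensor>(op);
        (void)opTensor;  // use opTensor as an OperatorTensor from here on
        return 0;
    }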

---
 src/recipes/ExplicitCastMove.cpp | 2 ++
 src/recipes/FuseBatchNorm.cpp    | 1 +
 src/recipes/FuseMulAdd.cpp       | 3 ++-
 src/recipes/HorizontalTiling.cpp | 3 ++-
 src/scheduler/Scheduler.cpp      | 1 +
 5 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/src/recipes/ExplicitCastMove.cpp b/src/recipes/ExplicitCastMove.cpp
index e6427b847..7d836c3ac 100644
--- a/src/recipes/ExplicitCastMove.cpp
+++ b/src/recipes/ExplicitCastMove.cpp
@@ -20,6 +20,7 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
     for (auto node : nodes) {
         // TODO: currently, Operator data type is only reflected in its output tensor data type.
         // But an Operator might have multiple outputs of different data type(?)
+        AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
         const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0);
         if (output->getImpl() == nullptr) {
             continue;
@@ -32,6 +33,7 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
             const auto parent = node->inputs()[0];
             // Check parent is not nullptr, as this Operator may be an entry point of the graph without parent
             if (parent.first != nullptr) {
+                AIDGE_ASSERT(parent.first->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
                 const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
 
                 if ((node->type() == Cast_Op::Type && input->dataType() == output->dataType())
diff --git a/src/recipes/FuseBatchNorm.cpp b/src/recipes/FuseBatchNorm.cpp
index 40a1b5952..ac1fc8d79 100644
--- a/src/recipes/FuseBatchNorm.cpp
+++ b/src/recipes/FuseBatchNorm.cpp
@@ -53,6 +53,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
     DimSize_t convNbOutChannels;
     DimSize_t channelsSize;
     std::array<DimSize_t, 2> kernelDims;
+    AIDGE_ASSERT(convNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
     std::shared_ptr<OperatorTensor> convOp = std::static_pointer_cast<OperatorTensor>(convNode->getOperator());
     if (convNode->type() == Conv_Op<2>::Type) {
         const std::shared_ptr<Conv_Op<2>> convOpPtr =
diff --git a/src/recipes/FuseMulAdd.cpp b/src/recipes/FuseMulAdd.cpp
index 65837d023..a09c27c2b 100644
--- a/src/recipes/FuseMulAdd.cpp
+++ b/src/recipes/FuseMulAdd.cpp
@@ -44,7 +44,8 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
     // TODO: find another way to get OutChannels for FC operator.
     // This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output
     DimSize_t outSize = 0;
-    const auto& op = std::dynamic_pointer_cast<OperatorTensor>(addNode->getOperator());
+    AIDGE_ASSERT(addNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
+    const auto& op = std::static_pointer_cast<OperatorTensor>(addNode->getOperator());
     for (size_t i = 0; i < op->nbInputs(); i++)
     {
         const auto& inTensor = op->getInput(i);
diff --git a/src/recipes/HorizontalTiling.cpp b/src/recipes/HorizontalTiling.cpp
index 4d18e012e..8e27fea58 100644
--- a/src/recipes/HorizontalTiling.cpp
+++ b/src/recipes/HorizontalTiling.cpp
@@ -36,7 +36,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
     if (node->getOperator()->type() != "Conv") {
         AIDGE_INTERNAL_ASSERT("Operator should be a Convolution.");
     }
-    const auto& op = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
+    AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
+    const auto& op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
     if (op->nbOutputs() != 1 || op->nbData() > 1) {
         AIDGE_INTERNAL_ASSERT("Only slice Operators with one output and at most one input for now.");
     }
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index d45bc5f8e..6c827f236 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -325,6 +325,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer
             }
             
             const auto childs = node->getChildren();
+            AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
             const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
 
             std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;
-- 
GitLab