From 0350b632e16fec8cfd63e51dfdbf172002226306 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Sat, 2 Mar 2024 19:12:01 +0100 Subject: [PATCH] Added operator type checking --- src/recipes/ExplicitCastMove.cpp | 2 ++ src/recipes/FuseBatchNorm.cpp | 1 + src/recipes/FuseMulAdd.cpp | 3 ++- src/recipes/HorizontalTiling.cpp | 3 ++- src/scheduler/Scheduler.cpp | 1 + 5 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/recipes/ExplicitCastMove.cpp b/src/recipes/ExplicitCastMove.cpp index e6427b847..7d836c3ac 100644 --- a/src/recipes/ExplicitCastMove.cpp +++ b/src/recipes/ExplicitCastMove.cpp @@ -20,6 +20,7 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) { for (auto node : nodes) { // TODO: currently, Operator data type is only reflected in its output tensor data type. // But an Operator might have multiple outputs of different data type(?) + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type."); const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0); if (output->getImpl() == nullptr) { continue; @@ -32,6 +33,7 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) { const auto parent = node->inputs()[0]; // Check parent is not nullptr, as this Operator may be an entry point of the graph without parent if (parent.first != nullptr) { + AIDGE_ASSERT(parent.first->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type."); const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); if ((node->type() == Cast_Op::Type && input->dataType() == output->dataType()) diff --git a/src/recipes/FuseBatchNorm.cpp b/src/recipes/FuseBatchNorm.cpp index 40a1b5952..ac1fc8d79 100644 --- a/src/recipes/FuseBatchNorm.cpp +++ b/src/recipes/FuseBatchNorm.cpp @@ -53,6 +53,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, DimSize_t convNbOutChannels; DimSize_t channelsSize; std::array<DimSize_t, 2> kernelDims; + AIDGE_ASSERT(convNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type."); std::shared_ptr<OperatorTensor> convOp = std::static_pointer_cast<OperatorTensor>(convNode->getOperator()); if (convNode->type() == Conv_Op<2>::Type) { const std::shared_ptr<Conv_Op<2>> convOpPtr = diff --git a/src/recipes/FuseMulAdd.cpp b/src/recipes/FuseMulAdd.cpp index 65837d023..a09c27c2b 100644 --- a/src/recipes/FuseMulAdd.cpp +++ b/src/recipes/FuseMulAdd.cpp @@ -44,7 +44,8 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< // TODO: find another way to get OutChannels for FC operator. // This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output DimSize_t outSize = 0; - const auto& op = std::dynamic_pointer_cast<OperatorTensor>(addNode->getOperator()); + AIDGE_ASSERT(addNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type."); + const auto& op = std::static_pointer_cast<OperatorTensor>(addNode->getOperator()); for (size_t i = 0; i < op->nbInputs(); i++) { const auto& inTensor = op->getInput(i); diff --git a/src/recipes/HorizontalTiling.cpp b/src/recipes/HorizontalTiling.cpp index 4d18e012e..8e27fea58 100644 --- a/src/recipes/HorizontalTiling.cpp +++ b/src/recipes/HorizontalTiling.cpp @@ -36,7 +36,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: if (node->getOperator()->type() != "Conv") { AIDGE_INTERNAL_ASSERT("Operator should be a Convolution."); } - const auto& op = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator()); + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type."); + const auto& op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); if (op->nbOutputs() != 1 || op->nbData() > 1) { AIDGE_INTERNAL_ASSERT("Only slice Operators with one output and at most one input for now."); } diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index d45bc5f8e..6c827f236 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -325,6 +325,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer } const auto childs = node->getChildren(); + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type."); const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane; -- GitLab