Skip to content
Snippets Groups Projects
Commit 0350b632 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added operator type checking

parent 2f65f127
No related branches found
No related tags found
No related merge requests found
...@@ -20,6 +20,7 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) { ...@@ -20,6 +20,7 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
for (auto node : nodes) { for (auto node : nodes) {
// TODO: currently, Operator data type is only reflected in its output tensor data type. // TODO: currently, Operator data type is only reflected in its output tensor data type.
// But an Operator might have multiple outputs of different data type(?) // But an Operator might have multiple outputs of different data type(?)
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0); const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0);
if (output->getImpl() == nullptr) { if (output->getImpl() == nullptr) {
continue; continue;
...@@ -32,6 +33,7 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) { ...@@ -32,6 +33,7 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
const auto parent = node->inputs()[0]; const auto parent = node->inputs()[0];
// Check parent is not nullptr, as this Operator may be an entry point of the graph without parent // Check parent is not nullptr, as this Operator may be an entry point of the graph without parent
if (parent.first != nullptr) { if (parent.first != nullptr) {
AIDGE_ASSERT(parent.first->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
if ((node->type() == Cast_Op::Type && input->dataType() == output->dataType()) if ((node->type() == Cast_Op::Type && input->dataType() == output->dataType())
......
...@@ -53,6 +53,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode, ...@@ -53,6 +53,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
DimSize_t convNbOutChannels; DimSize_t convNbOutChannels;
DimSize_t channelsSize; DimSize_t channelsSize;
std::array<DimSize_t, 2> kernelDims; std::array<DimSize_t, 2> kernelDims;
AIDGE_ASSERT(convNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
std::shared_ptr<OperatorTensor> convOp = std::static_pointer_cast<OperatorTensor>(convNode->getOperator()); std::shared_ptr<OperatorTensor> convOp = std::static_pointer_cast<OperatorTensor>(convNode->getOperator());
if (convNode->type() == Conv_Op<2>::Type) { if (convNode->type() == Conv_Op<2>::Type) {
const std::shared_ptr<Conv_Op<2>> convOpPtr = const std::shared_ptr<Conv_Op<2>> convOpPtr =
......
...@@ -44,7 +44,8 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< ...@@ -44,7 +44,8 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
// TODO: find another way to get OutChannels for FC operator. // TODO: find another way to get OutChannels for FC operator.
// This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output // This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output
DimSize_t outSize = 0; DimSize_t outSize = 0;
const auto& op = std::dynamic_pointer_cast<OperatorTensor>(addNode->getOperator()); AIDGE_ASSERT(addNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto& op = std::static_pointer_cast<OperatorTensor>(addNode->getOperator());
for (size_t i = 0; i < op->nbInputs(); i++) for (size_t i = 0; i < op->nbInputs(); i++)
{ {
const auto& inTensor = op->getInput(i); const auto& inTensor = op->getInput(i);
......
...@@ -36,7 +36,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std: ...@@ -36,7 +36,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
if (node->getOperator()->type() != "Conv") { if (node->getOperator()->type() != "Conv") {
AIDGE_INTERNAL_ASSERT("Operator should be a Convolution."); AIDGE_INTERNAL_ASSERT("Operator should be a Convolution.");
} }
const auto& op = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator()); AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto& op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
if (op->nbOutputs() != 1 || op->nbData() > 1) { if (op->nbOutputs() != 1 || op->nbData() > 1) {
AIDGE_INTERNAL_ASSERT("Only slice Operators with one output and at most one input for now."); AIDGE_INTERNAL_ASSERT("Only slice Operators with one output and at most one input for now.");
} }
......
...@@ -325,6 +325,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer ...@@ -325,6 +325,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer
} }
const auto childs = node->getChildren(); const auto childs = node->getChildren();
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane; std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment