Commit 0350b632 authored by Olivier BICHLER

Added operator type checking

parent 2f65f127
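
The change is the same in every touched file: before downcasting a Node's Operator to OperatorTensor with std::static_pointer_cast, assert that the Operator is actually tensor-based. Below is a minimal sketch of that pattern, written as a hypothetical free-standing helper for illustration only; the commit itself repeats the two lines inline at each call site, and the include paths are assumptions based on aidge_core's layout and may differ.

#include <memory>

#include "aidge/graph/Node.hpp"               // Aidge::Node (assumed path)
#include "aidge/operator/OperatorTensor.hpp"  // Aidge::OperatorTensor, Aidge::OperatorType (assumed path)
#include "aidge/utils/ErrorHandling.hpp"      // AIDGE_ASSERT (assumed path)

// Hypothetical helper, not part of the commit: fail fast with a clear message
// when the Operator is not tensor-based, then perform the (now safe) static cast.
static std::shared_ptr<Aidge::OperatorTensor> asTensorOp(const std::shared_ptr<Aidge::Node>& node) {
    AIDGE_ASSERT(node->getOperator()->operatorType() == Aidge::OperatorType::Tensor,
                 "Operator must be of Tensor type.");
    return std::static_pointer_cast<Aidge::OperatorTensor>(node->getOperator());
}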
@@ -20,6 +20,7 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
     for (auto node : nodes) {
         // TODO: currently, Operator data type is only reflected in its output tensor data type.
         // But an Operator might have multiple outputs of different data type(?)
+        AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
         const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0);
         if (output->getImpl() == nullptr) {
             continue;
@@ -32,6 +33,7 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
             const auto parent = node->inputs()[0];
             // Check parent is not nullptr, as this Operator may be an entry point of the graph without parent
             if (parent.first != nullptr) {
+                AIDGE_ASSERT(parent.first->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
                 const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
                 if ((node->type() == Cast_Op::Type && input->dataType() == output->dataType())
......
@@ -53,6 +53,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
     DimSize_t convNbOutChannels;
     DimSize_t channelsSize;
     std::array<DimSize_t, 2> kernelDims;
+    AIDGE_ASSERT(convNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
     std::shared_ptr<OperatorTensor> convOp = std::static_pointer_cast<OperatorTensor>(convNode->getOperator());
     if (convNode->type() == Conv_Op<2>::Type) {
         const std::shared_ptr<Conv_Op<2>> convOpPtr =
......
@@ -44,7 +44,8 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
     // TODO: find another way to get OutChannels for FC operator.
     // This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output
     DimSize_t outSize = 0;
-    const auto& op = std::dynamic_pointer_cast<OperatorTensor>(addNode->getOperator());
+    AIDGE_ASSERT(addNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
+    const auto& op = std::static_pointer_cast<OperatorTensor>(addNode->getOperator());
     for (size_t i = 0; i < op->nbInputs(); i++)
     {
         const auto& inTensor = op->getInput(i);
......
@@ -36,7 +36,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
     if (node->getOperator()->type() != "Conv") {
         AIDGE_INTERNAL_ASSERT("Operator should be a Convolution.");
     }
-    const auto& op = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
+    AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
+    const auto& op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
     if (op->nbOutputs() != 1 || op->nbData() > 1) {
         AIDGE_INTERNAL_ASSERT("Only slice Operators with one output and at most one input for now.");
     }
......
@@ -325,6 +325,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer
         }
        const auto childs = node->getChildren();
+        AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
         const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
         std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;
......
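
Note on the two call sites that previously used std::dynamic_pointer_cast (fuseMulAdd and getConvHorizontalTiling): a dynamic_pointer_cast on a non-tensor Operator silently returns nullptr and only fails later on dereference, whereas the new AIDGE_ASSERT plus static_pointer_cast reports the problem immediately with an explicit message and avoids the RTTI-based check on the success path. The other four sites simply gain the assertion in front of an already unchecked static_pointer_cast.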