Skip to content
Snippets Groups Projects
Commit 0350b632 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added operator type checking

parent 2f65f127
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!77Support for recurrent networks
......@@ -20,6 +20,7 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
for (auto node : nodes) {
// TODO: currently, Operator data type is only reflected in its output tensor data type.
// But an Operator might have multiple outputs of different data type(?)
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0);
if (output->getImpl() == nullptr) {
continue;
......@@ -32,6 +33,7 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
const auto parent = node->inputs()[0];
// Check parent is not nullptr, as this Operator may be an entry point of the graph without parent
if (parent.first != nullptr) {
AIDGE_ASSERT(parent.first->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
if ((node->type() == Cast_Op::Type && input->dataType() == output->dataType())
......
......@@ -53,6 +53,7 @@ void Aidge::fuseBatchNorm(std::shared_ptr<Aidge::Node> convNode,
DimSize_t convNbOutChannels;
DimSize_t channelsSize;
std::array<DimSize_t, 2> kernelDims;
AIDGE_ASSERT(convNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
std::shared_ptr<OperatorTensor> convOp = std::static_pointer_cast<OperatorTensor>(convNode->getOperator());
if (convNode->type() == Conv_Op<2>::Type) {
const std::shared_ptr<Conv_Op<2>> convOpPtr =
......
......@@ -44,7 +44,8 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
// TODO: find another way to get OutChannels for FC operator.
// This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output
DimSize_t outSize = 0;
const auto& op = std::dynamic_pointer_cast<OperatorTensor>(addNode->getOperator());
AIDGE_ASSERT(addNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto& op = std::static_pointer_cast<OperatorTensor>(addNode->getOperator());
for (size_t i = 0; i < op->nbInputs(); i++)
{
const auto& inTensor = op->getInput(i);
......
......@@ -36,7 +36,8 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::getConvHorizontalTiling(const std:
if (node->getOperator()->type() != "Conv") {
AIDGE_INTERNAL_ASSERT("Operator should be a Convolution.");
}
const auto& op = std::dynamic_pointer_cast<OperatorTensor>(node->getOperator());
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto& op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
if (op->nbOutputs() != 1 || op->nbData() > 1) {
AIDGE_INTERNAL_ASSERT("Only slice Operators with one output and at most one input for now.");
}
......
......@@ -325,6 +325,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer
}
const auto childs = node->getChildren();
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
std::vector<const MemoryManager::MemoryPlane*> wrapAroundMemPlane;
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment