From b282d8a01af6b8d96fb31a6e1bf2875327f41595 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Thu, 14 Dec 2023 18:55:27 +0100 Subject: [PATCH] Robustified explicitCastMove following split in two operators --- src/recipies/ExplicitCastMove.cpp | 128 ++++++++++++++++++------------ 1 file changed, 77 insertions(+), 51 deletions(-) diff --git a/src/recipies/ExplicitCastMove.cpp b/src/recipies/ExplicitCastMove.cpp index f0cd20502..5651f2ba4 100644 --- a/src/recipies/ExplicitCastMove.cpp +++ b/src/recipies/ExplicitCastMove.cpp @@ -15,7 +15,8 @@ #include "aidge/operator/Move.hpp" void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) { - const auto nodes = graph->getNodes(); + // First, remove existing Cast and Move operators, if not needed anymore + auto nodes = graph->getNodes(); for (auto node : nodes) { // TODO: currently, Operator data type is only reflected in its output tensor data type. // But an Operator might have multiple outputs of different data type(?) @@ -27,71 +28,96 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) { if (node->type() == Cast_Op::Type || node->type() == Move_Op::Type) { // Remove existing Cast and Move operators, if not needed anymore + AIDGE_INTERNAL_ASSERT(node->inputs().size() == 1); const auto parent = node->inputs()[0]; - const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); + // Check parent is not nullptr, as this Operator may be an entry point of the graph without parent + if (parent.first != nullptr) { + const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); - if (input->dataType() == output->dataType() - && (input->getImpl()->device() == device)) - { - // Add direct connection bypassing Cast/Move node - const auto childs = node->outputs()[0]; - for (const auto& child : childs) { - parent.first->addChild(child.first, parent.second, child.second); - } + if ((node->type() == Cast_Op::Type && input->dataType() == output->dataType()) + || (node->type() == Move_Op::Type && input->getImpl() != nullptr && input->getImpl()->device() == device)) + { + // Add direct connection bypassing Cast/Move node + const auto childs = node->outputs()[0]; + for (const auto& child : childs) { + parent.first->addChild(child.first, parent.second, child.second); + } - // Remove all node connections - node->resetConnections(); + // Remove all node connections + node->resetConnections(); + // Remove node from view + graph->remove(node); + } } } - else { - // Insert Cast and/or Move operator between node inputs and parent output, if needed - IOIndex_t inputIdx = 0; - for (auto parent : node->inputs()) { - if (parent.first != nullptr) { - const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); + } - NodePtr moveOp = nullptr; - NodePtr castOp = nullptr; + // Note: why two steps and not merge the two node loops? + // User may have changed some data type/backends on top of existing Cast/Move operators + // This may lead to situation where a Cast should be removed but a Move should + // be inserted at the same place. In this case, some conversion may be missed + // depending on the order of iteration over the nodes (which are non ordered!). - if (input->getImpl()->device() != device) { - // Change of backend => a Move operator is required - moveOp = Move(); - moveOp->getOperator()->setDataType(input->dataType()); - castOp = moveOp; - } + // Second, insert Cast and/or Move operator between node inputs and parent output, if needed + nodes = graph->getNodes(); + for (auto node : nodes) { + // TODO: currently, Operator data type is only reflected in its output tensor data type. + // But an Operator might have multiple outputs of different data type(?) + const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0); + if (output->getImpl() == nullptr) { + continue; + } + const auto& device = output->getImpl()->device(); - if (input->dataType() != output->dataType()) { - // Change of date type => a Cast operator is required - castOp = Cast(); - castOp->getOperator()->setDataType(output->dataType()); - castOp->getOperator()->setBackend(device.first, device.second); + IOIndex_t inputIdx = 0; + for (auto parent : node->inputs()) { + // TODO: possible optimization: currently, a Cast/Move Operator may + // be added several time to the same output, if it has multiple childs, + // even if it is the same conversion each time. + if (parent.first != nullptr) { + const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second); - if (moveOp == nullptr) { - moveOp = castOp; - } - else { - moveOp->addChild(castOp, 0, 0); - } - } + NodePtr moveOp = nullptr; + NodePtr castOp = nullptr; + + if (node->type() != Move_Op::Type && input->getImpl()->device() != device) { + // Change of backend => a Move operator is required + moveOp = Move(); + moveOp->getOperator()->setDataType(input->dataType()); + castOp = moveOp; + } - if (moveOp != nullptr && castOp != nullptr) { - // Move and/or Cast Operator(s) are needed - castOp->addChild(node, 0, inputIdx); - parent.first->addChild(moveOp, parent.second, 0); - // Set backend AFTER connection in case a specific implementation - // of the operator exists for the input type. - moveOp->getOperator()->setBackend(device.first, device.second); + if (node->type() != Cast_Op::Type && input->dataType() != output->dataType()) { + // Change of date type => a Cast operator is required + castOp = Cast(); + castOp->getOperator()->setDataType(output->dataType()); + castOp->getOperator()->setBackend(device.first, device.second); - // Add/update nodes in the GraphView - graph->add(moveOp); - graph->add(castOp); - graph->add(parent.first); - graph->add(node); + if (moveOp == nullptr) { + moveOp = castOp; + } + else { + moveOp->addChild(castOp, 0, 0); } } - ++inputIdx; + if (moveOp != nullptr && castOp != nullptr) { + // Move and/or Cast Operator(s) are needed + castOp->addChild(node, 0, inputIdx); + parent.first->addChild(moveOp, parent.second, 0); + // Set backend AFTER connection in case a specific implementation + // of the operator exists for the input type. + moveOp->getOperator()->setBackend(device.first, device.second); + + // Add/update nodes in the GraphView + graph->add(moveOp); + graph->add(castOp); + graph->add(parent.first); + graph->add(node); + } } + + ++inputIdx; } } } -- GitLab