From b282d8a01af6b8d96fb31a6e1bf2875327f41595 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 14 Dec 2023 18:55:27 +0100
Subject: [PATCH] Robustified explicitCastMove following split in two operators

---
 src/recipies/ExplicitCastMove.cpp | 128 ++++++++++++++++++------------
 1 file changed, 77 insertions(+), 51 deletions(-)

diff --git a/src/recipies/ExplicitCastMove.cpp b/src/recipies/ExplicitCastMove.cpp
index f0cd20502..5651f2ba4 100644
--- a/src/recipies/ExplicitCastMove.cpp
+++ b/src/recipies/ExplicitCastMove.cpp
@@ -15,7 +15,8 @@
 #include "aidge/operator/Move.hpp"
 
 void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
-    const auto nodes = graph->getNodes();
+    // First, remove existing Cast and Move operators, if not needed anymore
+    auto nodes = graph->getNodes();
     for (auto node : nodes) {
         // TODO: currently, Operator data type is only reflected in its output tensor data type.
         // But an Operator might have multiple outputs of different data type(?)
@@ -27,71 +28,96 @@ void Aidge::explicitCastMove(std::shared_ptr<GraphView> graph) {
 
         if (node->type() == Cast_Op::Type || node->type() == Move_Op::Type) {
             // Remove existing Cast and Move operators, if not needed anymore
+            AIDGE_INTERNAL_ASSERT(node->inputs().size() == 1);
             const auto parent = node->inputs()[0];
-            const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
+            // Check parent is not nullptr, as this Operator may be an entry point of the graph without parent
+            if (parent.first != nullptr) {
+                const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
 
-            if (input->dataType() == output->dataType()
-                && (input->getImpl()->device() == device))
-            {
-                // Add direct connection bypassing Cast/Move node
-                const auto childs = node->outputs()[0];
-                for (const auto& child : childs) {
-                    parent.first->addChild(child.first, parent.second, child.second);
-                }
+                if ((node->type() == Cast_Op::Type && input->dataType() == output->dataType())
+                    || (node->type() == Move_Op::Type && input->getImpl() != nullptr && input->getImpl()->device() == device))
+                {
+                    // Add direct connection bypassing Cast/Move node
+                    const auto childs = node->outputs()[0];
+                    for (const auto& child : childs) {
+                        parent.first->addChild(child.first, parent.second, child.second);
+                    }
 
-                // Remove all node connections
-                node->resetConnections();
+                    // Remove all node connections
+                    node->resetConnections();
+                    // Remove node from view
+                    graph->remove(node);
+                }
             }
         }
-        else {
-            // Insert Cast and/or Move operator between node inputs and parent output, if needed
-            IOIndex_t inputIdx = 0;
-            for (auto parent : node->inputs()) {
-                if (parent.first != nullptr) {
-                    const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
+    }
 
-                    NodePtr moveOp = nullptr;
-                    NodePtr castOp = nullptr;
+    // Note: why two steps and not merge the two node loops?
+    // User may have changed some data type/backends on top of existing Cast/Move operators
+    // This may lead to situation where a Cast should be removed but a Move should
+    // be inserted at the same place. In this case, some conversion may be missed
+    // depending on the order of iteration over the nodes (which are non ordered!).
 
-                    if (input->getImpl()->device() != device) {
-                        // Change of backend => a Move operator is required
-                        moveOp = Move();
-                        moveOp->getOperator()->setDataType(input->dataType());
-                        castOp = moveOp;
-                    }
+    // Second, insert Cast and/or Move operator between node inputs and parent output, if needed
+    nodes = graph->getNodes();
+    for (auto node : nodes) {
+        // TODO: currently, Operator data type is only reflected in its output tensor data type.
+        // But an Operator might have multiple outputs of different data type(?)
+        const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0);
+        if (output->getImpl() == nullptr) {
+            continue;
+        }
+        const auto& device = output->getImpl()->device();
 
-                    if (input->dataType() != output->dataType()) {
-                        // Change of date type => a Cast operator is required
-                        castOp = Cast();
-                        castOp->getOperator()->setDataType(output->dataType());
-                        castOp->getOperator()->setBackend(device.first, device.second);
+        IOIndex_t inputIdx = 0;
+        for (auto parent : node->inputs()) {
+            // TODO: possible optimization: currently, a Cast/Move Operator may 
+            // be added several time to the same output, if it has multiple childs,
+            // even if it is the same conversion each time.
+            if (parent.first != nullptr) {
+                const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
 
-                        if (moveOp == nullptr) {
-                            moveOp = castOp;
-                        }
-                        else {
-                            moveOp->addChild(castOp, 0, 0);
-                        }
-                    }
+                NodePtr moveOp = nullptr;
+                NodePtr castOp = nullptr;
+
+                if (node->type() != Move_Op::Type && input->getImpl()->device() != device) {
+                    // Change of backend => a Move operator is required
+                    moveOp = Move();
+                    moveOp->getOperator()->setDataType(input->dataType());
+                    castOp = moveOp;
+                }
 
-                    if (moveOp != nullptr && castOp != nullptr) {
-                        // Move and/or Cast Operator(s) are needed
-                        castOp->addChild(node, 0, inputIdx);
-                        parent.first->addChild(moveOp, parent.second, 0);
-                        // Set backend AFTER connection in case a specific implementation
-                        // of the operator exists for the input type.
-                        moveOp->getOperator()->setBackend(device.first, device.second);
+                if (node->type() != Cast_Op::Type && input->dataType() != output->dataType()) {
+                    // Change of date type => a Cast operator is required
+                    castOp = Cast();
+                    castOp->getOperator()->setDataType(output->dataType());
+                    castOp->getOperator()->setBackend(device.first, device.second);
 
-                        // Add/update nodes in the GraphView
-                        graph->add(moveOp);
-                        graph->add(castOp);
-                        graph->add(parent.first);
-                        graph->add(node);
+                    if (moveOp == nullptr) {
+                        moveOp = castOp;
+                    }
+                    else {
+                        moveOp->addChild(castOp, 0, 0);
                     }
                 }
 
-                ++inputIdx;
+                if (moveOp != nullptr && castOp != nullptr) {
+                    // Move and/or Cast Operator(s) are needed
+                    castOp->addChild(node, 0, inputIdx);
+                    parent.first->addChild(moveOp, parent.second, 0);
+                    // Set backend AFTER connection in case a specific implementation
+                    // of the operator exists for the input type.
+                    moveOp->getOperator()->setBackend(device.first, device.second);
+
+                    // Add/update nodes in the GraphView
+                    graph->add(moveOp);
+                    graph->add(castOp);
+                    graph->add(parent.first);
+                    graph->add(node);
+                }
             }
+
+            ++inputIdx;
         }
     }
 }
-- 
GitLab