From e07747238016b6107d83adcc954718c8f8bf0474 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Thu, 7 Dec 2023 23:54:39 +0100
Subject: [PATCH] Added ExplicitConvert recipe

---
 include/aidge/operator/Conv.hpp     | 12 +++++
 include/aidge/operator/Convert.hpp  |  2 +-
 include/aidge/recipies/Recipies.hpp |  8 ++++
 src/operator/OperatorTensor.cpp     |  2 +
 src/recipies/ExplicitConvert.cpp    | 73 +++++++++++++++++++++++++++++
 5 files changed, 96 insertions(+), 1 deletion(-)
 create mode 100644 src/recipies/ExplicitConvert.cpp

diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp
index 8ba1dfbf7..fc16359aa 100644
--- a/include/aidge/operator/Conv.hpp
+++ b/include/aidge/operator/Conv.hpp
@@ -174,6 +174,18 @@ std::vector<std::pair<std::size_t, std::vector<DimSize_t>>> computeReceptiveFiel
     void setBackend(const std::string &name, int device = 0) override {
         mImpl = Registrar<Conv_Op<DIM>>::create(name)(*this);
         mOutputs[0]->setBackend(name, device);
+
+        // By default, automatically set backend for weight and bias inputs
+        getInput(1)->setBackend(name, device);
+        getInput(2)->setBackend(name, device);
+    }
+
+    void setDataType(const DataType& dt) const override {
+        mOutputs[0]->setDataType(dt);
+
+        // By default, automatically set data type for weight and bias inputs
+        getInput(1)->setDataType(dt);
+        getInput(2)->setDataType(dt);
     }
 
     static const std::vector<std::string> getInputsName(){
diff --git a/include/aidge/operator/Convert.hpp b/include/aidge/operator/Convert.hpp
index f115a243d..6a08fbf0d 100644
--- a/include/aidge/operator/Convert.hpp
+++ b/include/aidge/operator/Convert.hpp
@@ -51,7 +51,7 @@ public:
     }
 
     void setBackend(const std::string& name, int device = 0) override {
-        if (Registrar<Convert_Op>::exists({mInputs[0]->getImpl()->backend(), name})) {
+        if (mInputs[0]->getImpl() && Registrar<Convert_Op>::exists({mInputs[0]->getImpl()->backend(), name})) {
             mImpl = Registrar<Convert_Op>::create({mInputs[0]->getImpl()->backend(), name})(*this);
         }
         mOutputs[0]->setBackend(name, device);
diff --git a/include/aidge/recipies/Recipies.hpp b/include/aidge/recipies/Recipies.hpp
index 5ad08a658..f1d9a9327 100644
--- a/include/aidge/recipies/Recipies.hpp
+++ b/include/aidge/recipies/Recipies.hpp
@@ -89,6 +89,14 @@ std::set<std::shared_ptr<Node>> getConvHorizontalTiling(const std::shared_ptr<No
 // std::set<std::shared_ptr<Node>> getHorizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
 // void horizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
 
+
+/**
+ * Add Convert operators where needed to ensure no conversion needs to be done
+ * at the Operator level.
+*/
+void explicitConvert(std::shared_ptr<GraphView> graphView);
+
+
 } // namespace Aidge
 
 #endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
diff --git a/src/operator/OperatorTensor.cpp b/src/operator/OperatorTensor.cpp
index 1237fdc0b..ccee3865e 100644
--- a/src/operator/OperatorTensor.cpp
+++ b/src/operator/OperatorTensor.cpp
@@ -149,6 +149,7 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
     for (IOIndex_t i = 0; i < nbOutputs(); ++i) {
         getOutput(i)->setDataType(dataType);
     }
+    /*
     for (IOIndex_t i = 0; i < nbInputs(); ++i) {
         if (!getInput(i)) {
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not set");
@@ -157,4 +158,5 @@ void Aidge::OperatorTensor::setDataType(const DataType& dataType) const {
             getInput(i)->setDataType(dataType);
         }
     }
+    */
 }
\ No newline at end of file
diff --git a/src/recipies/ExplicitConvert.cpp b/src/recipies/ExplicitConvert.cpp
new file mode 100644
index 000000000..59d30d16a
--- /dev/null
+++ b/src/recipies/ExplicitConvert.cpp
@@ -0,0 +1,73 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/recipies/Recipies.hpp"
+#include "aidge/operator/OperatorTensor.hpp"
+#include "aidge/operator/Convert.hpp"
+
+void Aidge::explicitConvert(std::shared_ptr<GraphView> graph) {
+    const auto nodes = graph->getNodes();
+    for (auto node : nodes) {
+        // TODO: currently, Operator data type is only reflected in its output tensor data type.
+        // But an Operator might have multiple outputs of different data type(?)
+        const auto& output = std::static_pointer_cast<OperatorTensor>(node->getOperator())->getOutput(0);
+        if (output->getImpl() == nullptr) {
+            continue;
+        }
+        const auto& device = output->getImpl()->device();
+
+        if (node->type() == Convert_Op::Type) {
+            // Remove existing Convert operator, if not needed anymore
+            const auto& parent = node->inputs()[0];
+            const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
+
+            if (input->dataType() == output->dataType()
+                && (input->getImpl()->device() == device))
+            {
+                // Add direct connection bypassing Convert node
+                for (auto child : node->outputs()[0]) {
+                    parent.first->addChild(child.first, parent.second, child.second);
+                }
+
+                // Remove all Convert node connections
+                node->resetConnections();
+            }
+        }
+        else {
+            // Insert Convert operator between node inputs and parent output, if needed
+            IOIndex_t inputIdx = 0;
+            for (auto parent : node->inputs()) {
+                if (parent.first != nullptr) {
+                    const auto& input = std::static_pointer_cast<OperatorTensor>(parent.first->getOperator())->getOutput(parent.second);
+        
+                    if (input->dataType() != output->dataType()
+                        || (input->getImpl()->device() != device))
+                    {
+                        // A conversion Operator is needed
+                        auto convert = Convert();
+                        convert->addChild(node, 0, inputIdx);
+                        parent.first->addChild(convert, parent.second, 0);
+                        // Set backend AFTER connection in case a specific implementation
+                        // of the operator exists for the input type.
+                        convert->getOperator()->setBackend(device.first, device.second);
+
+                        // Add/update nodes in the GraphView
+                        graph->add(convert);
+                        graph->add(parent.first);
+                        graph->add(node);
+                    }
+                }
+
+                ++inputIdx;
+            }
+        }
+    }
+}
-- 
GitLab