From 2bba4f964f28f45b45db5cea9c09e880940d5b54 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Wed, 8 Jan 2025 18:13:34 +0100
Subject: [PATCH] Multiple fixes related to adaptToBackend()

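adaptToBackend() had several issues: the Cast and Transpose nodes
inserted by OperatorImpl::getAdaptation() had no backend set (and thus
no implementation), the adapted node's inputs were set to the available
spec instead of the required one, and the result was always wrapped in
a meta operator, even when no adaptation was needed.

This patch fixes these issues and also:
- handles atomic (non-meta) adaptations in getBestAdaptation(), since
  getAdaptation() can now return the node itself;
- asserts in Tensor::copyTranspose() that an implementation is set
  before using it;
- allows Transpose_Op::forwardDims() to accept a permutation vector
  longer than the input rank, provided it is the identity above that
  rank;
- makes adaptToBackend() expand the adapted meta operator in place,
  using the new expandMetaOp() recipe;
- binds the vector-of-IOSpec ImplSpec constructor in Python.

The new recipe can also be called directly on a single node, e.g.
(sketch, not part of the patch sources; assumes node belongs to a
GraphView):

    if (expandMetaOp(node)) {
        // node was a meta operator; its micro graph now replaces it
        // in the surrounding GraphView
    }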
---
 include/aidge/recipes/Recipes.hpp             |  6 +++++
 .../backend/pybind_OperatorImpl.cpp           |  1 +
 src/backend/OperatorImpl.cpp                  | 27 ++++++++++++++-----
 src/data/Tensor.cpp                           |  1 +
 src/operator/Transpose.cpp                    | 11 +++++---
 src/recipes/AdaptToBackend.cpp                |  1 +
 src/recipes/ExpandMetaOps.cpp                 | 15 +++++++++++
 7 files changed, 53 insertions(+), 9 deletions(-)

diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp
index 0fb405bfe..aa4d3ae1b 100644
--- a/include/aidge/recipes/Recipes.hpp
+++ b/include/aidge/recipes/Recipes.hpp
@@ -124,6 +124,12 @@ void explicitCastMove(std::shared_ptr<GraphView> graphView);
 */
 void explicitTranspose(std::shared_ptr<GraphView> graphView);
 
+/**
+ * Replace a single meta operator by its micro graph.
+ * @return true if the node is indeed a meta operator and could be expanded, false otherwise.
+*/
+bool expandMetaOp(std::shared_ptr<Node> node);
+
 /**
  * Flatten the graph by replacing the meta operators by their micro graph.
  * @param recursive If true, recursively replace meta operators until there is
diff --git a/python_binding/backend/pybind_OperatorImpl.cpp b/python_binding/backend/pybind_OperatorImpl.cpp
index 49e45ed7e..cd94997cf 100644
--- a/python_binding/backend/pybind_OperatorImpl.cpp
+++ b/python_binding/backend/pybind_OperatorImpl.cpp
@@ -81,6 +81,7 @@ void init_OperatorImpl(py::module& m){
     .def(py::init<const DynamicAttributes&>(), py::arg("attr") = DynamicAttributes())
     .def(py::init<const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("io"), py::arg("attr") = DynamicAttributes())
     .def(py::init<const ImplSpec::IOSpec&, const ImplSpec::IOSpec&, const DynamicAttributes&>(), py::arg("i"), py::arg("o"), py::arg("attr") = DynamicAttributes())
+    .def(py::init<const std::vector<ImplSpec::IOSpec>&, const std::vector<ImplSpec::IOSpec>&, const DynamicAttributes&>(), py::arg("i"), py::arg("o"), py::arg("attr") = DynamicAttributes())
     .def("__eq__", static_cast<bool(*)(const ImplSpec&, const ImplSpec&)>(&operator==))
     .def("__repr__", [](ImplSpec self){
         return fmt::format("{}\n", self);
diff --git a/src/backend/OperatorImpl.cpp b/src/backend/OperatorImpl.cpp
index c74b538a4..8a4924c0e 100644
--- a/src/backend/OperatorImpl.cpp
+++ b/src/backend/OperatorImpl.cpp
@@ -250,9 +250,10 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
             && requiredIOSpec.type != IOSpec.type)
         {
             const auto cast = Cast(IOSpec.type);
+            cast->getOperator()->setBackend(node->getOperator()->backend()); // give the inserted Cast a backend, so that it gets an implementation
             cast->addChild(parent, 0, i);
 
-            op->getInput(i)->setDataType(IOSpec.type);
+            op->getInput(i)->setDataType(requiredIOSpec.type);
         }
 
         // Input format
@@ -263,10 +264,11 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
             const auto transpose = getDataFormatTranspose(requiredIOSpec.format, IOSpec.format);
             auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
             transposeOp->getOperator()->setDataFormat(IOSpec.format);
-            transposeOp->getOperator()->setDataType(IOSpec.type);
+            transposeOp->getOperator()->setDataType(requiredIOSpec.type);
+            transposeOp->getOperator()->setBackend(node->getOperator()->backend());
             transposeOp->addChild(parent, 0, i);
 
-            op->getInput(i)->setDataFormat(IOSpec.format);
+            op->getInput(i)->setDataFormat(requiredIOSpec.format);
         }
 
         // Input dims
@@ -301,6 +303,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
             && requiredIOSpec.type != IOSpec.type)
         {
             const auto cast = Cast(requiredIOSpec.type);
+            cast->getOperator()->setBackend(node->getOperator()->backend());
             parent->addChild(cast, i, 0);
 
             op->getOutput(i)->setDataType(IOSpec.type);
@@ -315,6 +318,7 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
             auto transposeOp = Transpose(std::vector<DimSize_t>(transpose.begin(), transpose.end()));
             transposeOp->getOperator()->setDataFormat(requiredIOSpec.format);
             transposeOp->getOperator()->setDataType(requiredIOSpec.type);
+            transposeOp->getOperator()->setBackend(node->getOperator()->backend());
             parent->addChild(transposeOp, i, 0);
 
             op->getOutput(i)->setDataFormat(IOSpec.format);
@@ -340,7 +344,13 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getAdaptation(const ImplSpec&
         }
     }
 
-    return MetaOperator(std::string("Adapted_" + op->type()).c_str(), getConnectedGraphView(node));
+    auto adaptedGraph = getConnectedGraphView(node); // includes any inserted Cast/Transpose node
+    if (adaptedGraph->getNodes().size() > 1) {
+        return MetaOperator(std::string("Adapted_" + op->type()).c_str(), adaptedGraph);
+    }
+    else {
+        return node; // no adaptation node was inserted
+    }
 }
 
 std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSpec& requiredSpecs) const {
@@ -354,8 +364,13 @@ std::shared_ptr<Aidge::Node> Aidge::OperatorImpl::getBestAdaptation(const ImplSp
         auto adaptation = getAdaptation(availableSpec, requiredSpecs);
 
         if (adaptation) {
-            auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation->getOperator())->getMicroGraph();
-            adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size()));
+            if (adaptation->getOperator()->isAtomic()) {
+                adaptations.insert(std::make_pair(adaptation, 1)); // atomic adaptation: a single node
+            }
+            else {
+                auto microGraph = std::dynamic_pointer_cast<MetaOperator_Op>(adaptation->getOperator())->getMicroGraph();
+                adaptations.insert(std::make_pair(adaptation, microGraph->getNodes().size()));
+            }
         }
     }
 
diff --git a/src/data/Tensor.cpp b/src/data/Tensor.cpp
index c834167ab..e8a0e9ede 100644
--- a/src/data/Tensor.cpp
+++ b/src/data/Tensor.cpp
@@ -538,6 +538,7 @@ void Tensor::copyTranspose(const Tensor& src, const std::vector<DimSize_t>& tran
         }
     }
 
+    AIDGE_ASSERT(mImpl, "Tensor::copyTranspose(): an implementation is required; use setBackend() first!");
     std::shared_ptr<TensorImpl> newImpl = Registrar<Tensor>::create({mImpl->backend(), mDataType})(mImpl->device().second, newDims);
 
     std::vector<size_t> indices(newDims.size(), 0);
diff --git a/src/operator/Transpose.cpp b/src/operator/Transpose.cpp
index d24b9c909..b550db16d 100644
--- a/src/operator/Transpose.cpp
+++ b/src/operator/Transpose.cpp
@@ -66,12 +66,17 @@ bool Aidge::Transpose_Op::forwardDims(bool /*allowDataDependency*/) {
             std::iota(this->outputDimsOrder().rbegin(), this->outputDimsOrder().rend(), 0);
         }
 
-        AIDGE_ASSERT(outputDimsOrder().size() == getInput(0)->nbDims(),
-                     "Permutation vector must have the same rank as input tensor.");
+        AIDGE_ASSERT(outputDimsOrder().size() >= getInput(0)->nbDims(),
+            "Permutation vector ({}) must have at least the same rank as the input tensor (dims: {}).", outputDimsOrder(), getInput(0)->dims());
         std::vector<DimSize_t> outputDims;
-        for (std::size_t i = 0; i < outputDimsOrder().size(); ++i) {
+        std::size_t i = 0;
+        for (; i < getInput(0)->nbDims(); ++i) {
             outputDims.push_back(getInput(0)->dims()[outputDimsOrder()[i]]);
         }
+        for (; i < outputDimsOrder().size(); ++i) {
+            AIDGE_ASSERT(i == outputDimsOrder()[i],
+                "Permutation vector ({}) must be the identity above the input tensor rank (dims: {}).", outputDimsOrder(), getInput(0)->dims());
+        }
         mOutputs[0]->resize(outputDims);
         return true;
     }
diff --git a/src/recipes/AdaptToBackend.cpp b/src/recipes/AdaptToBackend.cpp
index e625a52f6..bb4222c49 100644
--- a/src/recipes/AdaptToBackend.cpp
+++ b/src/recipes/AdaptToBackend.cpp
@@ -33,6 +33,7 @@ void Aidge::adaptToBackend(std::shared_ptr<GraphView> graphView) {
             Log::info("Adapted node {} (of type {}) to backend {}",
                 node->name(), node->type(), impl->backend());
             AIDGE_ASSERT(GraphView::replace({node}, {adaptedNode}), "Unable to replace adapted node!");
+            expandMetaOp(adaptedNode);
         }
     }
 }
diff --git a/src/recipes/ExpandMetaOps.cpp b/src/recipes/ExpandMetaOps.cpp
index 16f0b4c52..459a1ca85 100644
--- a/src/recipes/ExpandMetaOps.cpp
+++ b/src/recipes/ExpandMetaOps.cpp
@@ -14,6 +14,21 @@
 #include "aidge/recipes/Recipes.hpp"
 #include "aidge/operator/MetaOperator.hpp"
 
+bool Aidge::expandMetaOp(std::shared_ptr<Node> node) {
+    auto metaOp = std::dynamic_pointer_cast<MetaOperator_Op>(node->getOperator());
+
+    if (metaOp != nullptr) {
+        // Replace the meta op by its micro graph; the containing
+        // graph will be updated accordingly by GraphView::replace()
+        auto g = std::make_shared<GraphView>();
+        g->add(node, false);
+        GraphView::replace(g, metaOp->getMicroGraph());
+        return true;
+    }
+
+    return false;
+}
+
 void Aidge::expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive) {
     bool found = false;
     const auto nodes = graph->getNodes();
-- 
GitLab