From 649254e519c05f87f24ed7f49fe0c6ba60300a03 Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Wed, 11 Sep 2024 15:41:32 +0200
Subject: [PATCH] Improved fuseToMetaOps to set backend of fused meta op

---
 include/aidge/operator/MetaOperator.hpp | 1 +
 include/aidge/operator/Operator.hpp     | 2 +-
 src/operator/MetaOperator.cpp           | 6 ++++++
 src/recipes/FuseToMetaOps.cpp           | 6 ++++++
 src/scheduler/Scheduler.cpp             | 3 +++
 5 files changed, 17 insertions(+), 1 deletion(-)

diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp
index 69f2120d9..bd85346f7 100644
--- a/include/aidge/operator/MetaOperator.hpp
+++ b/include/aidge/operator/MetaOperator.hpp
@@ -79,6 +79,7 @@ public:
         return false;
     }
 
+    std::string backend() const noexcept override;
 
     void setBackend(const std::string &name, DeviceIdx_t device = 0) override;
 
diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index c938fc362..282270736 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -124,7 +124,7 @@ public:
 ///////////////////////////////////////////////////////
 //        IMPLEMENTATION
 ///////////////////////////////////////////////////////
-    std::string backend() const noexcept {
+    virtual std::string backend() const noexcept {
         return mImpl ? mImpl->backend() : "";
     }
 
diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp
index a7997bc1a..fc094464b 100644
--- a/src/operator/MetaOperator.cpp
+++ b/src/operator/MetaOperator.cpp
@@ -69,6 +69,12 @@ void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, const std
     mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second));
 }
 
+std::string Aidge::MetaOperator_Op::backend() const noexcept {
+    return (mImpl)
+        ? mImpl->backend()
+        : mGraph->rootNode()->getOperator()->backend();
+}
+
 void Aidge::MetaOperator_Op::setBackend(const std::string &name, Aidge::DeviceIdx_t device) {
     if (Registrar<MetaOperator_Op>::exists({name, type()})) {
         // A custom implementation exists for this meta operator
diff --git a/src/recipes/FuseToMetaOps.cpp b/src/recipes/FuseToMetaOps.cpp
index e7748936c..b4aa477a2 100644
--- a/src/recipes/FuseToMetaOps.cpp
+++ b/src/recipes/FuseToMetaOps.cpp
@@ -24,6 +24,12 @@ size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::str
     size_t nbReplaced = 0;
     for (const auto& match : matches) {
         auto metaOp = MetaOperator(metaType.c_str(), match.graph->clone());
+        // Clone does not clone implementation, which is therefore empty.
+        // Use the root node backend for the meta op backend, even though some
+        // matching nodes might be on a different backend, as nodes in the meta
+        // op are required to share the same backend!
+        metaOp->getOperator()->setBackend(match.graph->rootNode()->getOperator()->backend());
+
         auto metaOpGraph = std::make_shared<GraphView>();
         metaOpGraph->add(metaOp, false);
         const auto success = GraphView::replace(match.graph, metaOpGraph);
diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp
index 851f1895c..161345050 100644
--- a/src/scheduler/Scheduler.cpp
+++ b/src/scheduler/Scheduler.cpp
@@ -612,6 +612,9 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName)
 }
 
 std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step) const {
+    AIDGE_ASSERT(!mStaticSchedule.empty(), "Scheduler::getStaticScheduling(): static scheduling is empty, did you generate scheduling first?");
+    AIDGE_ASSERT(step < mStaticSchedule.size(), "Scheduler::getStaticScheduling(): no static scheduling at step {} (available steps: {})", mStaticSchedule.size(), step);
+
     const auto& staticSchedule = mStaticSchedule.at(step);
     std::vector<std::shared_ptr<Node>> schedule;
     std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });
-- 
GitLab