diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 69f2120d90beb727bd661628c362410066ae3cff..bd85346f7211701e20a443685f24d37a76ae631b 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -79,6 +79,7 @@ public: return false; } + std::string backend() const noexcept override; void setBackend(const std::string &name, DeviceIdx_t device = 0) override; diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index c938fc362aa1f747f5f31bea3fdb08fa851e2333..282270736fb7d8ca1fbd3f2b9d1c12bf144e6d34 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -124,7 +124,7 @@ public: /////////////////////////////////////////////////////// // IMPLEMENTATION /////////////////////////////////////////////////////// - std::string backend() const noexcept { + virtual std::string backend() const noexcept { return mImpl ? mImpl->backend() : ""; } diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index a7997bc1a07e633feaf0873078ddb1ebb9bc71d4..fc094464bdce9473c40c9056e0f384400c4af72a 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -69,6 +69,12 @@ void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, const std mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second)); } +std::string Aidge::MetaOperator_Op::backend() const noexcept { + return (mImpl) + ? mImpl->backend() + : mGraph->rootNode()->getOperator()->backend(); +} + void Aidge::MetaOperator_Op::setBackend(const std::string &name, Aidge::DeviceIdx_t device) { if (Registrar<MetaOperator_Op>::exists({name, type()})) { // A custom implementation exists for this meta operator diff --git a/src/recipes/FuseToMetaOps.cpp b/src/recipes/FuseToMetaOps.cpp index e7748936c00a20ec235ea7853f4d17e2c10261fb..0ad5e5a1da0e6aef74f7e47751dd2d4e8648980b 100644 --- a/src/recipes/FuseToMetaOps.cpp +++ b/src/recipes/FuseToMetaOps.cpp @@ -24,6 +24,15 @@ size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::str size_t nbReplaced = 0; for (const auto& match : matches) { auto metaOp = MetaOperator(metaType.c_str(), match.graph->clone()); + // Clone does not clone implementation, which is therefore empty. + // Use the root node backend for the meta op backend, even though some + // matching nodes might be on a different backend, as nodes in the meta + // op are required to share the same backend! + const auto backend = match.graph->rootNode()->getOperator()->backend(); + if (!backend.empty()) { + metaOp->getOperator()->setBackend(backend); + } + auto metaOpGraph = std::make_shared<GraphView>(); metaOpGraph->add(metaOp, false); const auto success = GraphView::replace(match.graph, metaOpGraph); diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 851f1895c3862ed3deedc73f2ee70f6835b4a8a3..1613450508ea84a230f36ba6526a1322c6a70559 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -612,6 +612,9 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) } std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step) const { + AIDGE_ASSERT(!mStaticSchedule.empty(), "Scheduler::getStaticScheduling(): static scheduling is empty, did you generate scheduling first?"); + AIDGE_ASSERT(step < mStaticSchedule.size(), "Scheduler::getStaticScheduling(): no static scheduling at step {} (available steps: {})", mStaticSchedule.size(), step); + const auto& staticSchedule = mStaticSchedule.at(step); std::vector<std::shared_ptr<Node>> schedule; std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; });