From 649254e519c05f87f24ed7f49fe0c6ba60300a03 Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 11 Sep 2024 15:41:32 +0200 Subject: [PATCH] Improved fuseToMetaOps to set backend of fused meta op --- include/aidge/operator/MetaOperator.hpp | 1 + include/aidge/operator/Operator.hpp | 2 +- src/operator/MetaOperator.cpp | 6 ++++++ src/recipes/FuseToMetaOps.cpp | 6 ++++++ src/scheduler/Scheduler.cpp | 3 +++ 5 files changed, 17 insertions(+), 1 deletion(-) diff --git a/include/aidge/operator/MetaOperator.hpp b/include/aidge/operator/MetaOperator.hpp index 69f2120d9..bd85346f7 100644 --- a/include/aidge/operator/MetaOperator.hpp +++ b/include/aidge/operator/MetaOperator.hpp @@ -79,6 +79,7 @@ public: return false; } + std::string backend() const noexcept override; void setBackend(const std::string &name, DeviceIdx_t device = 0) override; diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index c938fc362..282270736 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -124,7 +124,7 @@ public: /////////////////////////////////////////////////////// // IMPLEMENTATION /////////////////////////////////////////////////////// - std::string backend() const noexcept { + virtual std::string backend() const noexcept { return mImpl ? mImpl->backend() : ""; } diff --git a/src/operator/MetaOperator.cpp b/src/operator/MetaOperator.cpp index a7997bc1a..fc094464b 100644 --- a/src/operator/MetaOperator.cpp +++ b/src/operator/MetaOperator.cpp @@ -69,6 +69,12 @@ void Aidge::MetaOperator_Op::setInput(const Aidge::IOIndex_t inputIdx, const std mInputs[inputIdx] = std::dynamic_pointer_cast<Tensor>(inputOp.first->getOperator()->getRawInput(inputOp.second)); } +std::string Aidge::MetaOperator_Op::backend() const noexcept { + return (mImpl) + ? mImpl->backend() + : mGraph->rootNode()->getOperator()->backend(); +} + void Aidge::MetaOperator_Op::setBackend(const std::string &name, Aidge::DeviceIdx_t device) { if (Registrar<MetaOperator_Op>::exists({name, type()})) { // A custom implementation exists for this meta operator diff --git a/src/recipes/FuseToMetaOps.cpp b/src/recipes/FuseToMetaOps.cpp index e7748936c..b4aa477a2 100644 --- a/src/recipes/FuseToMetaOps.cpp +++ b/src/recipes/FuseToMetaOps.cpp @@ -24,6 +24,12 @@ size_t Aidge::fuseToMetaOps(std::shared_ptr<GraphView> graphView, const std::str size_t nbReplaced = 0; for (const auto& match : matches) { auto metaOp = MetaOperator(metaType.c_str(), match.graph->clone()); + // Clone does not clone implementation, which is therefore empty. + // Use the root node backend for the meta op backend, even though some + // matching nodes might be on a different backend, as nodes in the meta + // op are required to share the same backend! + metaOp->getOperator()->setBackend(match.graph->rootNode()->getOperator()->backend()); + auto metaOpGraph = std::make_shared<GraphView>(); metaOpGraph->add(metaOp, false); const auto success = GraphView::replace(match.graph, metaOpGraph); diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 851f1895c..161345050 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -612,6 +612,9 @@ void Aidge::Scheduler::saveStaticSchedulingDiagram(const std::string& fileName) } std::vector<std::shared_ptr<Aidge::Node>> Aidge::Scheduler::getStaticScheduling(std::size_t step) const { + AIDGE_ASSERT(!mStaticSchedule.empty(), "Scheduler::getStaticScheduling(): static scheduling is empty, did you generate scheduling first?"); + AIDGE_ASSERT(step < mStaticSchedule.size(), "Scheduler::getStaticScheduling(): no static scheduling at step {} (available steps: {})", mStaticSchedule.size(), step); + const auto& staticSchedule = mStaticSchedule.at(step); std::vector<std::shared_ptr<Node>> schedule; std::transform(staticSchedule.begin(), staticSchedule.end(), std::back_inserter(schedule), [](const auto& v) { return v->node; }); -- GitLab