From ef3acd297eeeb2b8ba129dbac2147b400cbec0b5 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Wed, 7 Feb 2024 14:27:18 +0000 Subject: [PATCH] [Upd][WIP][NF] 'backward()' for SequentialScheduler --- include/aidge/recipies/GraphViewHelper.hpp | 44 ++------------- src/recipies/GraphViewHelper.cpp | 62 ++++++++++++++++++++++ src/scheduler/Scheduler.cpp | 2 +- 3 files changed, 66 insertions(+), 42 deletions(-) create mode 100644 src/recipies/GraphViewHelper.cpp diff --git a/include/aidge/recipies/GraphViewHelper.hpp b/include/aidge/recipies/GraphViewHelper.hpp index 14f59db9f..7cd5d662f 100644 --- a/include/aidge/recipies/GraphViewHelper.hpp +++ b/include/aidge/recipies/GraphViewHelper.hpp @@ -15,10 +15,7 @@ #include <memory> #include <set> -#include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" -#include "aidge/operator/OperatorTensor.hpp" -#include "aidge/utils/ErrorHandling.hpp" namespace Aidge { @@ -28,51 +25,16 @@ namespace Aidge { * @param graphview GraphView instance where Producers should be searched. * @return std::set<std::shared_ptr<Node>> */ -std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview) { - std::set<std::shared_ptr<Node>> res; - const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes(); - - std::copy_if(nodes.cbegin(), - nodes.cend(), - std::inserter(res, res.begin()), - [](std::shared_ptr<Node> n){ return n->type() == "Producer"; }); - - return res; -} +std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview); /** * @brief Getter for every Producer operator in a GraphView that is a parameter. * @param graphview GraphView instance where Producers should be searched. * @return std::set<std::shared_ptr<Node>> */ -std::set<std::shared_ptr<Aidge::Node>> parameters(std::shared_ptr<Aidge::GraphView> graphview) { - std::set<std::shared_ptr<Node>> res; - const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes(); - - for (auto it = nodes.cbegin(); it != nodes.cend(); ++it) { - for (std::size_t inID = (*it)->nbData(); inID < (*it)->nbInputs(); ++inID) { - const std::shared_ptr<Node>& parent = (*it)->getParent(inID); - if (parent && parent->type() == "Producer") { - res.insert(parent); - } - } - } - - return res; -} +std::set<std::shared_ptr<Aidge::Node>> parameters(std::shared_ptr<Aidge::GraphView> graphview); -void instanciateGradient(std::shared_ptr<Aidge::GraphView> gv) { - for (const auto& node : gv->getNodes()) { - // TODO: check that each node is an OperatorTensor - AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor."); - const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator()); - for (std::size_t o = 0; o < node -> nbOutputs(); ++o) { - const auto& t = op->getOutput(o); - t -> grad() -> setDataType(t -> dataType()); - t -> grad() -> setBackend(t -> getImpl() -> backend()); - } - } -} +void compile_gradient(std::shared_ptr<Aidge::GraphView> gv); } // namespace Aidge diff --git a/src/recipies/GraphViewHelper.cpp b/src/recipies/GraphViewHelper.cpp new file mode 100644 index 000000000..ac2cb1fdf --- /dev/null +++ b/src/recipies/GraphViewHelper.cpp @@ -0,0 +1,62 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> +#include <set> + +#include "aidge/graph/Node.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/recipies/GraphViewHelper.hpp" + + +std::set<std::shared_ptr<Aidge::Node>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) { + std::set<std::shared_ptr<Node>> res; + const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes(); + + std::copy_if(nodes.cbegin(), + nodes.cend(), + std::inserter(res, res.begin()), + [](std::shared_ptr<Node> n){ return n->type() == "Producer"; }); + + return res; +} + + +std::set<std::shared_ptr<Aidge::Node>> Aidge::parameters(std::shared_ptr<Aidge::GraphView> graphview) { + std::set<std::shared_ptr<Node>> res; + const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes(); + + for (auto it = nodes.cbegin(); it != nodes.cend(); ++it) { + for (std::size_t inID = (*it)->nbData(); inID < (*it)->nbInputs(); ++inID) { + const std::shared_ptr<Node>& parent = (*it)->getParent(inID); + if (parent && parent->type() == "Producer") { + res.insert(parent); + } + } + } + + return res; +} + +void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) { + for (const auto& node : gv->getNodes()) { + // TODO: check that each node is an OperatorTensor + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor."); + const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator()); + for (std::size_t o = 0; o < node -> nbOutputs(); ++o) { + const auto& t = op->getOutput(o); + t -> grad() -> setDataType(t -> dataType()); + t -> grad() -> setBackend(t -> getImpl() -> backend()); + } + } +} \ No newline at end of file diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 074f3a98e..d5a3d2764 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -211,7 +211,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) { // Forward dims (if allowed) - if (instanciateGrad) {instanciateGradient(mGraphView); } + if (instanciateGrad) {compile_gradient(mGraphView); } // Generate scheduling *only if empty* // If scheduling was already generated (in one or several steps, i.e. one or -- GitLab