From a9e44123c1db68396a620cb3895403d76940af44 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Thu, 21 Mar 2024 14:12:47 +0000 Subject: [PATCH] Upd GraphViewHelper functions to return Tensors instead of Nodes --- include/aidge/recipes/GraphViewHelper.hpp | 32 +++++++-------- src/recipes/GraphViewHelper.cpp | 47 +++++++++++++++++++++++ 2 files changed, 64 insertions(+), 15 deletions(-) create mode 100644 src/recipes/GraphViewHelper.cpp diff --git a/include/aidge/recipes/GraphViewHelper.hpp b/include/aidge/recipes/GraphViewHelper.hpp index c6204cdff..8fdf1e1d7 100644 --- a/include/aidge/recipes/GraphViewHelper.hpp +++ b/include/aidge/recipes/GraphViewHelper.hpp @@ -9,14 +9,14 @@ * ********************************************************************************/ -#ifndef AIDGE_CORE_UTILS_RECIPES_H_ -#define AIDGE_CORE_UTILS_RECIPES_H_ +#ifndef AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_ +#define AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_ #include <memory> #include <set> -#include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/data/Tensor.hpp" namespace Aidge { @@ -26,15 +26,17 @@ namespace Aidge { * @param graphview GraphView instance where Producers should be searched. * @return std::set<std::shared_ptr<Node>> */ -std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview) { - std::set<std::shared_ptr<Node>> res; - const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes(); - - std::copy_if(nodes.cbegin(), - nodes.cend(), - std::inserter(res, res.begin()), - [](std::shared_ptr<Node> n){ return n->type() == "Producer"; }); - - return res; -} -} // namespace Aidge \ No newline at end of file +std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview); + +/** + * @brief Getter for every ``Tensor`` owned by an ``Operator`` inside the provided ``GraphView``. + * @note An ``Operator`` owns its output ``Tensor``s. + * + * @param graphview Pointer to the ``GraphView`` from which ``Tensor``s should be extracted. + * @return std::set<std::shared_ptr<Tensor>> Set of pointers to the ``Tensor``s. + */ +std::set<std::shared_ptr<Tensor>> parameters(std::shared_ptr<GraphView> graphview); + +} // namespace Aidge + +#endif /* AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_ */ diff --git a/src/recipes/GraphViewHelper.cpp b/src/recipes/GraphViewHelper.cpp new file mode 100644 index 000000000..ec58871b2 --- /dev/null +++ b/src/recipes/GraphViewHelper.cpp @@ -0,0 +1,47 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/recipes/GraphViewHelper.hpp" + +#include <memory> +#include <set> +#include <vector> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/OperatorTensor.hpp" + + +std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) { + std::set<std::shared_ptr<Tensor>> res; + const auto& nodes = graphview->getNodes(); + for (const auto& node : nodes) { + if (node->type() == "Producer") { + const auto& param = std::static_pointer_cast<OperatorTensor>(node->getOperator()); + res.insert(param->getOutput(0)); + } + } + return res; +} + + +std::set<std::shared_ptr<Aidge::Tensor>> Aidge::parameters(std::shared_ptr<Aidge::GraphView> graphview) { + std::set<std::shared_ptr<Tensor>> res; + const auto& nodes = graphview->getNodes(); + for (const auto& node : nodes) { + const auto& param = std::static_pointer_cast<OperatorTensor>(node->getOperator()); + for (std::size_t o = 0; o < param->nbOutputs(); ++o) { + res.insert(param->getOutput(o)); + } + } + return res; +} -- GitLab