Skip to content
Snippets Groups Projects
Commit 53bdfccc authored by Maxence Naud's avatar Maxence Naud
Browse files

Upd GraphViewHelper functions to return Tensors instead of Nodes

parent ea27e836
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!88Basic supervised learning
......@@ -9,14 +9,14 @@
*
********************************************************************************/
#ifndef AIDGE_CORE_UTILS_RECIPES_H_
#define AIDGE_CORE_UTILS_RECIPES_H_
#ifndef AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_
#define AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_
#include <memory>
#include <set>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
namespace Aidge {
......@@ -26,15 +26,17 @@ namespace Aidge {
* @param graphview GraphView instance where Producers should be searched.
* @return std::set<std::shared_ptr<Node>>
*/
std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview) {
std::set<std::shared_ptr<Node>> res;
const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
std::copy_if(nodes.cbegin(),
nodes.cend(),
std::inserter(res, res.begin()),
[](std::shared_ptr<Node> n){ return n->type() == "Producer"; });
return res;
}
} // namespace Aidge
\ No newline at end of file
std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview);
/**
* @brief Getter for every ``Tensor`` owned by an ``Operator`` inside the provided ``GraphView``.
* @note An ``Operator`` owns its output ``Tensor``s.
*
* @param graphview Pointer to the ``GraphView`` from which ``Tensor``s should be extracted.
* @return std::set<std::shared_ptr<Tensor>> Set of pointers to the ``Tensor``s.
*/
std::set<std::shared_ptr<Tensor>> parameters(std::shared_ptr<GraphView> graphview);
} // namespace Aidge
#endif /* AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/recipes/GraphViewHelper.hpp"
#include <memory>
#include <set>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) {
std::set<std::shared_ptr<Tensor>> res;
const auto& nodes = graphview->getNodes();
for (const auto& node : nodes) {
if (node->type() == "Producer") {
const auto& param = std::static_pointer_cast<OperatorTensor>(node->getOperator());
res.insert(param->getOutput(0));
}
}
return res;
}
std::set<std::shared_ptr<Aidge::Tensor>> Aidge::parameters(std::shared_ptr<Aidge::GraphView> graphview) {
std::set<std::shared_ptr<Tensor>> res;
const auto& nodes = graphview->getNodes();
for (const auto& node : nodes) {
const auto& param = std::static_pointer_cast<OperatorTensor>(node->getOperator());
for (std::size_t o = 0; o < param->nbOutputs(); ++o) {
res.insert(param->getOutput(o));
}
}
return res;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment