Skip to content
Snippets Groups Projects
Commit a291765a authored by Maxence Naud's avatar Maxence Naud Committed by Maxence Naud
Browse files

Upd GraphViewHelper functions to return Tensors instead of Nodes

parent 8fab299e
No related branches found
No related tags found
3 merge requests!105version 0.2.0,!88Basic supervised learning,!82Resolve "Optimizer to update gradients"
...@@ -9,14 +9,14 @@ ...@@ -9,14 +9,14 @@
* *
********************************************************************************/ ********************************************************************************/
#ifndef AIDGE_CORE_UTILS_RECIPES_H_ #ifndef AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_
#define AIDGE_CORE_UTILS_RECIPES_H_ #define AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_
#include <memory> #include <memory>
#include <set> #include <set>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
namespace Aidge { namespace Aidge {
...@@ -26,15 +26,17 @@ namespace Aidge { ...@@ -26,15 +26,17 @@ namespace Aidge {
* @param graphview GraphView instance where Producers should be searched. * @param graphview GraphView instance where Producers should be searched.
* @return std::set<std::shared_ptr<Node>> * @return std::set<std::shared_ptr<Node>>
*/ */
std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview) { std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview);
std::set<std::shared_ptr<Node>> res;
const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes(); /**
* @brief Getter for every ``Tensor`` owned by an ``Operator`` inside the provided ``GraphView``.
std::copy_if(nodes.cbegin(), * @note An ``Operator`` owns its output ``Tensor``s.
nodes.cend(), *
std::inserter(res, res.begin()), * @param graphview Pointer to the ``GraphView`` from which ``Tensor``s should be extracted.
[](std::shared_ptr<Node> n){ return n->type() == "Producer"; }); * @return std::set<std::shared_ptr<Tensor>> Set of pointers to the ``Tensor``s.
*/
return res; std::set<std::shared_ptr<Tensor>> parameters(std::shared_ptr<GraphView> graphview);
}
} // namespace Aidge } // namespace Aidge
\ No newline at end of file
#endif /* AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_ */
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/recipes/GraphViewHelper.hpp"
#include <memory>
#include <set>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/OperatorTensor.hpp"
std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) {
std::set<std::shared_ptr<Tensor>> res;
const auto& nodes = graphview->getNodes();
for (const auto& node : nodes) {
if (node->type() == "Producer") {
const auto& param = std::static_pointer_cast<OperatorTensor>(node->getOperator());
res.insert(param->getOutput(0));
}
}
return res;
}
std::set<std::shared_ptr<Aidge::Tensor>> Aidge::parameters(std::shared_ptr<Aidge::GraphView> graphview) {
std::set<std::shared_ptr<Tensor>> res;
const auto& nodes = graphview->getNodes();
for (const auto& node : nodes) {
const auto& param = std::static_pointer_cast<OperatorTensor>(node->getOperator());
for (std::size_t o = 0; o < param->nbOutputs(); ++o) {
res.insert(param->getOutput(o));
}
}
return res;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment