From 95a2f4da366c6453800f78422e5b25d8c709befe Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Tue, 6 Feb 2024 13:47:45 +0000 Subject: [PATCH] [Add] features [function] 'parameters()' to extract parameters of type Producer from a GraphView [function] 'instanciateGraphView()' to initialize Tensors gradient with the same datatype/backend --- include/aidge/recipies/GraphViewHelper.hpp | 41 +++++++++++++++++++++- 1 file changed, 40 insertions(+), 1 deletion(-) diff --git a/include/aidge/recipies/GraphViewHelper.hpp b/include/aidge/recipies/GraphViewHelper.hpp index d7bcec713..14f59db9f 100644 --- a/include/aidge/recipies/GraphViewHelper.hpp +++ b/include/aidge/recipies/GraphViewHelper.hpp @@ -17,6 +17,8 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/ErrorHandling.hpp" namespace Aidge { @@ -37,4 +39,41 @@ std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphVie return res; } -} // namespace Aidge \ No newline at end of file + +/** + * @brief Getter for every Producer operator in a GraphView that is a parameter. + * @param graphview GraphView instance where Producers should be searched. + * @return std::set<std::shared_ptr<Node>> + */ +std::set<std::shared_ptr<Aidge::Node>> parameters(std::shared_ptr<Aidge::GraphView> graphview) { + std::set<std::shared_ptr<Node>> res; + const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes(); + + for (auto it = nodes.cbegin(); it != nodes.cend(); ++it) { + for (std::size_t inID = (*it)->nbData(); inID < (*it)->nbInputs(); ++inID) { + const std::shared_ptr<Node>& parent = (*it)->getParent(inID); + if (parent && parent->type() == "Producer") { + res.insert(parent); + } + } + } + + return res; +} + +void instanciateGradient(std::shared_ptr<Aidge::GraphView> gv) { + for (const auto& node : gv->getNodes()) { + // TODO: check that each node is an OperatorTensor + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor."); + const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator()); + for (std::size_t o = 0; o < node -> nbOutputs(); ++o) { + const auto& t = op->getOutput(o); + t -> grad() -> setDataType(t -> dataType()); + t -> grad() -> setBackend(t -> getImpl() -> backend()); + } + } +} + +} // namespace Aidge + +#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */ \ No newline at end of file -- GitLab