diff --git a/include/aidge/recipies/GraphViewHelper.hpp b/include/aidge/recipies/GraphViewHelper.hpp index d7bcec713087054640c87c6fd229fee53d1ed4a6..14f59db9f8a2faad3910209501924a034094441c 100644 --- a/include/aidge/recipies/GraphViewHelper.hpp +++ b/include/aidge/recipies/GraphViewHelper.hpp @@ -17,6 +17,8 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/operator/OperatorTensor.hpp" +#include "aidge/utils/ErrorHandling.hpp" namespace Aidge { @@ -37,4 +39,41 @@ std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphVie return res; } -} // namespace Aidge \ No newline at end of file + +/** + * @brief Getter for every Producer operator in a GraphView that is a parameter. + * @param graphview GraphView instance where Producers should be searched. + * @return std::set<std::shared_ptr<Node>> + */ +std::set<std::shared_ptr<Aidge::Node>> parameters(std::shared_ptr<Aidge::GraphView> graphview) { + std::set<std::shared_ptr<Node>> res; + const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes(); + + for (auto it = nodes.cbegin(); it != nodes.cend(); ++it) { + for (std::size_t inID = (*it)->nbData(); inID < (*it)->nbInputs(); ++inID) { + const std::shared_ptr<Node>& parent = (*it)->getParent(inID); + if (parent && parent->type() == "Producer") { + res.insert(parent); + } + } + } + + return res; +} + +void instanciateGradient(std::shared_ptr<Aidge::GraphView> gv) { + for (const auto& node : gv->getNodes()) { + // TODO: check that each node is an OperatorTensor + AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor."); + const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator()); + for (std::size_t o = 0; o < node -> nbOutputs(); ++o) { + const auto& t = op->getOutput(o); + t -> grad() -> setDataType(t -> dataType()); + t -> grad() -> setBackend(t -> getImpl() -> backend()); + } + } +} + +} // namespace Aidge + +#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */ \ No newline at end of file