Skip to content
Snippets Groups Projects
Commit 95a2f4da authored by Maxence Naud's avatar Maxence Naud
Browse files

[Add] features

[function] 'parameters()' to extract parameters of type Producer from a GraphView
[function] 'instanciateGraphView()' to initialize Tensors gradient with the same datatype/backend
parent 0ea59366
No related branches found
No related tags found
3 merge requests!105version 0.2.0,!88Basic supervised learning,!79Scheduler backward
...@@ -17,6 +17,8 @@ ...@@ -17,6 +17,8 @@
#include "aidge/graph/Node.hpp" #include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
namespace Aidge { namespace Aidge {
...@@ -37,4 +39,41 @@ std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphVie ...@@ -37,4 +39,41 @@ std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphVie
return res; return res;
} }
} // namespace Aidge
\ No newline at end of file /**
* @brief Getter for every Producer operator in a GraphView that is a parameter.
* @param graphview GraphView instance where Producers should be searched.
* @return std::set<std::shared_ptr<Node>>
*/
std::set<std::shared_ptr<Aidge::Node>> parameters(std::shared_ptr<Aidge::GraphView> graphview) {
std::set<std::shared_ptr<Node>> res;
const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
for (auto it = nodes.cbegin(); it != nodes.cend(); ++it) {
for (std::size_t inID = (*it)->nbData(); inID < (*it)->nbInputs(); ++inID) {
const std::shared_ptr<Node>& parent = (*it)->getParent(inID);
if (parent && parent->type() == "Producer") {
res.insert(parent);
}
}
}
return res;
}
void instanciateGradient(std::shared_ptr<Aidge::GraphView> gv) {
for (const auto& node : gv->getNodes()) {
// TODO: check that each node is an OperatorTensor
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor.");
const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator());
for (std::size_t o = 0; o < node -> nbOutputs(); ++o) {
const auto& t = op->getOutput(o);
t -> grad() -> setDataType(t -> dataType());
t -> grad() -> setBackend(t -> getImpl() -> backend());
}
}
}
} // namespace Aidge
#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment