Skip to content
Snippets Groups Projects
Commit ef3acd29 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd][WIP][NF] 'backward()' for SequentialScheduler

parent 764c6ad8
No related branches found
No related tags found
No related merge requests found
...@@ -15,10 +15,7 @@ ...@@ -15,10 +15,7 @@
#include <memory> #include <memory>
#include <set> #include <set>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
namespace Aidge { namespace Aidge {
...@@ -28,51 +25,16 @@ namespace Aidge { ...@@ -28,51 +25,16 @@ namespace Aidge {
* @param graphview GraphView instance where Producers should be searched. * @param graphview GraphView instance where Producers should be searched.
* @return std::set<std::shared_ptr<Node>> * @return std::set<std::shared_ptr<Node>>
*/ */
std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview) { std::set<std::shared_ptr<Aidge::Node>> producers(std::shared_ptr<Aidge::GraphView> graphview);
std::set<std::shared_ptr<Node>> res;
const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
std::copy_if(nodes.cbegin(),
nodes.cend(),
std::inserter(res, res.begin()),
[](std::shared_ptr<Node> n){ return n->type() == "Producer"; });
return res;
}
/** /**
* @brief Getter for every Producer operator in a GraphView that is a parameter. * @brief Getter for every Producer operator in a GraphView that is a parameter.
* @param graphview GraphView instance where Producers should be searched. * @param graphview GraphView instance where Producers should be searched.
* @return std::set<std::shared_ptr<Node>> * @return std::set<std::shared_ptr<Node>>
*/ */
std::set<std::shared_ptr<Aidge::Node>> parameters(std::shared_ptr<Aidge::GraphView> graphview) { std::set<std::shared_ptr<Aidge::Node>> parameters(std::shared_ptr<Aidge::GraphView> graphview);
std::set<std::shared_ptr<Node>> res;
const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
for (auto it = nodes.cbegin(); it != nodes.cend(); ++it) {
for (std::size_t inID = (*it)->nbData(); inID < (*it)->nbInputs(); ++inID) {
const std::shared_ptr<Node>& parent = (*it)->getParent(inID);
if (parent && parent->type() == "Producer") {
res.insert(parent);
}
}
}
return res;
}
void instanciateGradient(std::shared_ptr<Aidge::GraphView> gv) { void compile_gradient(std::shared_ptr<Aidge::GraphView> gv);
for (const auto& node : gv->getNodes()) {
// TODO: check that each node is an OperatorTensor
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor.");
const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator());
for (std::size_t o = 0; o < node -> nbOutputs(); ++o) {
const auto& t = op->getOutput(o);
t -> grad() -> setDataType(t -> dataType());
t -> grad() -> setBackend(t -> getImpl() -> backend());
}
}
}
} // namespace Aidge } // namespace Aidge
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include <set>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/recipies/GraphViewHelper.hpp"
std::set<std::shared_ptr<Aidge::Node>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) {
std::set<std::shared_ptr<Node>> res;
const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
std::copy_if(nodes.cbegin(),
nodes.cend(),
std::inserter(res, res.begin()),
[](std::shared_ptr<Node> n){ return n->type() == "Producer"; });
return res;
}
std::set<std::shared_ptr<Aidge::Node>> Aidge::parameters(std::shared_ptr<Aidge::GraphView> graphview) {
std::set<std::shared_ptr<Node>> res;
const std::set<std::shared_ptr<Node>> nodes = graphview->getNodes();
for (auto it = nodes.cbegin(); it != nodes.cend(); ++it) {
for (std::size_t inID = (*it)->nbData(); inID < (*it)->nbInputs(); ++inID) {
const std::shared_ptr<Node>& parent = (*it)->getParent(inID);
if (parent && parent->type() == "Producer") {
res.insert(parent);
}
}
}
return res;
}
void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) {
for (const auto& node : gv->getNodes()) {
// TODO: check that each node is an OperatorTensor
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Cannot instanciate gradient of an Operator that doesn't use Tensor.");
const std::shared_ptr<OperatorTensor> op = std::static_pointer_cast<OperatorTensor>(node -> getOperator());
for (std::size_t o = 0; o < node -> nbOutputs(); ++o) {
const auto& t = op->getOutput(o);
t -> grad() -> setDataType(t -> dataType());
t -> grad() -> setBackend(t -> getImpl() -> backend());
}
}
}
\ No newline at end of file
...@@ -211,7 +211,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { ...@@ -211,7 +211,7 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) { void Aidge::SequentialScheduler::backward(bool instanciateGrad, bool verbose) {
// Forward dims (if allowed) // Forward dims (if allowed)
if (instanciateGrad) {instanciateGradient(mGraphView); } if (instanciateGrad) {compile_gradient(mGraphView); }
// Generate scheduling *only if empty* // Generate scheduling *only if empty*
// If scheduling was already generated (in one or several steps, i.e. one or // If scheduling was already generated (in one or several steps, i.e. one or
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment