Skip to content
Snippets Groups Projects
Commit cc40ecc3 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'scheduler_backprop' into 'learning'

Scheduler backward

See merge request eclipse/aidge/aidge_core!79
parents 75968752 7f7b5c3a
No related branches found
No related tags found
No related merge requests found
......@@ -62,11 +62,7 @@ public:
return mNodes == gv.mNodes;
}
NodePtr operator[](const std::string& nodeName)
{
AIDGE_ASSERT(mNodeRegistry.find(nodeName) != mNodeRegistry.end(), "No node named {} in graph {}.", nodeName, name());
return mNodeRegistry.at(nodeName);
}
const NodePtr operator[](const std::string& nodeName) const;
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
......@@ -82,14 +78,14 @@ public:
* @brief Name of the node.
* @return std::string
*/
std::string name() const;
inline std::string name() const noexcept { return mName; }
/**
* @brief Set the node name.
* @warning Undefined behaviour when several Nodes have the same name.
* @param name New name for the node.
*/
void setName(const std::string &name);
inline void setName(const std::string &name) { mName = name; }
/**
* @brief Save the GraphView as a Mermaid graph in a .md file at the
......@@ -105,11 +101,9 @@ public:
* @param nodePtr Node to check
* @return bool True if nodePtr belongs to the GraphView.
*/
inline bool inView(NodePtr nodePtr) const {
return mNodes.find(nodePtr) != mNodes.end();
}
bool inView(const NodePtr& nodePtr) const;
NodePtr getRootNode() {
inline NodePtr rootNode() const noexcept {
return mRootNode;
}
......@@ -120,41 +114,32 @@ public:
///////////////////////////////////////////////////////
public:
/** @brief Get reference to the set of input Nodes. */
inline std::set<NodePtr> inputNodes() const noexcept {
std::set<NodePtr> nodes;
for (auto node : mInputNodes) {
if (node.first != nullptr) {
nodes.insert(node.first);
}
}
return nodes;
}
std::set<NodePtr> inputNodes() const;
/** @brief Get reference to the set of output Nodes. */
inline std::set<NodePtr> outputNodes() const noexcept {
std::set<NodePtr> nodes;
for (auto node : mOutputNodes) {
if (node.first != nullptr) {
nodes.insert(node.first);
}
}
return nodes;
}
std::set<NodePtr> outputNodes() const;
/** @brief Assess if the given Node is an input Node of the GraphView object. */
inline bool isInputNode(NodePtr nodePtr) const {
const auto nodes = inputNodes();
return (nodes.find(nodePtr) != nodes.end()) ? true : false;
}
bool isInputNode(const NodePtr& nodePtr) const;
/** @brief Assess if the given Node is an output Node of the GraphView object. */
inline bool isOutputNode(NodePtr nodePtr) const {
const auto nodes = outputNodes();
return (nodes.find(nodePtr) != nodes.end()) ? true : false;
}
bool isOutputNode(const NodePtr& nodePtr) const;
void setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs);
void setOrderedOutputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& outputs);
inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() const { return mInputNodes; };
inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() const { return mOutputNodes; };
/**
* @brief Get inputs of the current GraphView with their associated id.
* The rank of the nodes are their rank in the vector.
* @return const std::vector<std::pair<NodePtr, IOIndex_t>>&
*/
inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedInputs() const noexcept { return mInputNodes; };
/**
* @brief Get outputs of the current GraphView with their associated id.
* The rank of the nodes are their rank in the vector.
* @return const std::vector<std::pair<NodePtr, IOIndex_t>>&
*/
inline const std::vector<std::pair<NodePtr, IOIndex_t>>& getOrderedOutputs() const noexcept { return mOutputNodes; };
/**
* @brief List outside data input connections of the GraphView.
......@@ -225,9 +210,9 @@ public:
void forwardDims(const std::vector<std::vector<DimSize_t>> dims = {});
/** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
void setBackend(const std::string &backend, DeviceIdx_t device = 0);
void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const;
/** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
void setDataType(const DataType &datatype);
void setDataType(const DataType& datatype) const;
///////////////////////////////////////////////////////
// TOPOLOGY
......
......@@ -28,6 +28,8 @@ namespace Aidge {
*/
std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview);
// TODO: change for every Tensor of Operator Producer not constant
/**
* @brief Getter for every ``Tensor`` owned by an ``Operator`` inside the provided ``GraphView``.
* @note An ``Operator`` owns its output ``Tensor``s.
......@@ -37,6 +39,8 @@ std::set<std::shared_ptr<Tensor>> producers(std::shared_ptr<GraphView> graphview
*/
std::set<std::shared_ptr<Tensor>> parameters(std::shared_ptr<GraphView> graphview);
void compile_gradient(std::shared_ptr<Aidge::GraphView> gv);
} // namespace Aidge
#endif /* AIDGE_CORE_UTILS_GRAPHVIEWHELPER_H_ */
......@@ -69,7 +69,7 @@ public:
/**
* @brief Place the data tensors into the data input tensors of the graphView. In case of multiple data input tensors, they are mapped to producers in the order given by the graph.
*
*
* @param data data input tensors
*/
void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data);
......@@ -79,6 +79,11 @@ public:
*/
void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
/**
* @brief Run a backward pass on the provided Computational Graph, seeding the output gradients with the given batch of data
*/
void backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instantiateGrad = true, bool verbose = false);
/**
* @brief Save in a Markdown file the order of layers execution.
* @param fileName Name of the generated file.
......
......@@ -76,6 +76,7 @@ void init_Tensor(py::module& m){
.def("set_datatype", &Tensor::setDataType, py::arg("datatype"), py::arg("copyCast") = true)
.def("set_backend", &Tensor::setBackend, py::arg("name"), py::arg("device") = 0, py::arg("copyFrom") = true)
.def("dims", (const std::vector<DimSize_t>& (Tensor::*)()const) &Tensor::dims)
.def("grad", &Tensor::grad)
.def("dtype", &Tensor::dataType)
.def("size", &Tensor::size)
.def("resize", (void (Tensor::*)(const std::vector<DimSize_t>&, std::vector<DimSize_t>)) &Tensor::resize)
......
......@@ -31,6 +31,8 @@ void init_GraphView(py::module& m) {
:type path: str
)mydelimiter")
.def("log_outputs", &GraphView::logOutputs, py::arg("path"))
.def("get_ordered_inputs", &GraphView::getOrderedInputs)
.def("get_ordered_outputs", &GraphView::getOrderedOutputs)
.def("get_output_nodes", &GraphView::outputNodes,
R"mydelimiter(
Get set of output Nodes.
......
......@@ -68,6 +68,7 @@ void init_GraphRegex(py::module&);
void init_MatchSolution(py::module&);
void init_Recipes(py::module&);
void init_GraphViewHelper(py::module&);
void init_Scheduler(py::module&);
void init_TensorUtils(py::module&);
......@@ -129,6 +130,7 @@ void init_Aidge(py::module& m) {
init_MatchSolution(m);
init_Recipes(m);
init_GraphViewHelper(m);
init_Scheduler(m);
init_TensorUtils(m);
init_Filler(m);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include <memory>
#include <set>
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/recipes/GraphViewHelper.hpp"
namespace py = pybind11;
namespace Aidge {
// Register the GraphViewHelper free functions on the aidge Python module.
void init_GraphViewHelper(py::module &m) {
// Binds Aidge::producers: returns the set of producer Tensors of a GraphView.
// NOTE(review): the GraphViewHelper header also declares parameters() and
// compile_gradient(); they are not bound here — confirm whether that is intentional.
m.def("producers", &producers, py::arg("graphview"));
}
} // namespace Aidge
......@@ -21,6 +21,7 @@ void init_Scheduler(py::module& m){
py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>>(m, "SequentialScheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false, py::arg("data")=std::vector<Tensor>())
.def("backward", &SequentialScheduler::backward, py::arg("data"), py::arg("instanciate_grad")=true, py::arg("verbose")=false)
.def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("resetScheduling", &SequentialScheduler::resetScheduling)
.def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false)
......
......@@ -35,6 +35,11 @@
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
/**
 * @brief Look up a Node of the GraphView by name.
 * @param nodeName Name the Node was registered under.
 * @return The matching Node, or nullptr when no Node has that name.
 */
const std::shared_ptr<Aidge::Node> Aidge::GraphView::operator[](const std::string& nodeName) const {
    // Single registry lookup instead of the previous find() + at() double lookup.
    const auto it = mNodeRegistry.find(nodeName);
    return (it != mNodeRegistry.cend()) ? it->second : nullptr;
}
///////////////////////////////////////////////////////
// FUNCTIONAL DESCRIPTION
///////////////////////////////////////////////////////
......@@ -64,9 +69,10 @@ Aidge::Connector Aidge::GraphView::operator()(
// INNER
///////////////////////////////////////////////////////
std::string Aidge::GraphView::name() const { return mName; }
// Membership test: true when the given Node belongs to this GraphView.
bool Aidge::GraphView::inView(const std::shared_ptr<Aidge::Node>& nodePtr) const {
    return mNodes.count(nodePtr) > 0;
}
void Aidge::GraphView::setName(const std::string &name) { mName = name; }
void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProducers) const {
auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((path + ".mmd").c_str(), "w"), &std::fclose);
......@@ -125,8 +131,8 @@ void Aidge::GraphView::save(const std::string& path, bool verbose, bool showProd
continue;
}
IOIndex_t outputIdx = 0;
for (auto childs : node_ptr->getOrderedChildren()) {
for (auto child : childs) {
for (const auto& childs : node_ptr->getOrderedChildren()) {
for (const auto& child : childs) {
if (child != nullptr) {
IOIndex_t inputIdx = 0;
for (auto parent : child->inputs()) {
......@@ -233,6 +239,33 @@ void Aidge::GraphView::setRootNode(NodePtr node) {
// TENSOR MANAGEMENT
///////////////////////////////////////////////////////
/**
 * @brief Compute the set of distinct input Nodes of the GraphView.
 * @return Set of Nodes appearing in the ordered input list, without duplicates.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::inputNodes() const {
    std::set<std::shared_ptr<Aidge::Node>> nodes;
    for (const auto& node : mInputNodes) {
        // mInputNodes may contain nullptr placeholders; skip them so the
        // returned set never holds a null Node (the previous inline
        // implementation filtered these out as well).
        if (node.first != nullptr) {
            nodes.insert(node.first);
        }
    }
    return nodes;
}
/**
 * @brief Compute the set of distinct output Nodes of the GraphView.
 * @return Set of Nodes appearing in the ordered output list, without duplicates.
 */
std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::outputNodes() const {
    std::set<std::shared_ptr<Aidge::Node>> nodes;
    for (const auto& node : mOutputNodes) {
        // mOutputNodes may contain nullptr placeholders; skip them so the
        // returned set never holds a null Node (the previous inline
        // implementation filtered these out as well).
        if (node.first != nullptr) {
            nodes.insert(node.first);
        }
    }
    return nodes;
}
// True when the given Node is one of this GraphView's input Nodes.
bool Aidge::GraphView::isInputNode(const std::shared_ptr<Aidge::Node>& nodePtr) const {
    return inputNodes().count(nodePtr) > 0;
}
// True when the given Node is one of this GraphView's output Nodes.
bool Aidge::GraphView::isOutputNode(const std::shared_ptr<Aidge::Node>& nodePtr) const {
    return outputNodes().count(nodePtr) > 0;
}
void Aidge::GraphView::setOrderedInputs(const std::vector<std::pair<NodePtr, IOIndex_t>>& inputs) {
size_t nbInputs = 0;
std::vector<std::pair<NodePtr, IOIndex_t>> ignoredInputs(mInputNodes);
......@@ -425,14 +458,14 @@ void Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
while (!listNodes.empty());
}
void Aidge::GraphView::setBackend(const std::string &backend, DeviceIdx_t device) {
for (auto node : getNodes()) {
void Aidge::GraphView::setBackend(const std::string &backend, const DeviceIdx_t device) const {
for (const auto& node : getNodes()) {
node->getOperator()->setBackend(backend, device);
}
}
/**
 * @brief Set the same DataType on every Operator of the view's Nodes.
 * @param datatype DataType to propagate.
 */
void Aidge::GraphView::setDataType(const Aidge::DataType& datatype) const {
    // The diff-mash in this span left both the old and new signature/loop
    // lines in place; this is the single, clean definition.
    for (const auto& node : getNodes()) {
        node->getOperator()->setDataType(datatype);
    }
}
......@@ -666,11 +699,9 @@ bool Aidge::GraphView::add(std::pair<NodePtr, std::set<NodePtr>> nodes, bool inc
}
/**
 * @brief Merge another GraphView's Nodes into this view.
 * @param graph View whose Nodes are added (inclusion of parents is disabled).
 * @return Result of adding the Node set.
 */
bool Aidge::GraphView::add(std::shared_ptr<GraphView> graph) {
    // The diff-mash in this span duplicated the old and new bodies; this is
    // the single, clean definition.
    // Adopt the other view's root Node if this view has none yet.
    if (mRootNode == nullptr) {
        mRootNode = graph->rootNode();
    }
    return add(graph->getNodes(), false);
}
void Aidge::GraphView::addChild(std::shared_ptr<Node> toOtherNode,
......
......@@ -9,16 +9,15 @@
*
********************************************************************************/
#include "aidge/recipes/GraphViewHelper.hpp"
#include <memory>
#include <set>
#include <vector>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/recipes/GraphViewHelper.hpp"
std::set<std::shared_ptr<Aidge::Tensor>> Aidge::producers(std::shared_ptr<Aidge::GraphView> graphview) {
......@@ -45,3 +44,14 @@ std::set<std::shared_ptr<Aidge::Tensor>> Aidge::parameters(std::shared_ptr<Aidge
}
return res;
}
/**
 * @brief Allocate a gradient Tensor for every output of every Node in the view,
 * so a subsequent backward pass has storage to write into.
 * @param gv GraphView whose Nodes' output gradients are initialized.
 */
void Aidge::compile_gradient(std::shared_ptr<Aidge::GraphView> gv) {
    for (const auto& node : gv->getNodes()) {
        // Only Tensor-based Operators can carry gradients.
        AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor,
                     "Cannot instantiate gradient of an Operator ({}) that doesn't use Tensor.",
                     node->getOperator()->type());
        // The assert above guarantees the concrete type, so a static cast
        // suffices (matches the pattern used elsewhere in this commit).
        const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
        for (std::size_t o = 0; o < node->nbOutputs(); ++o) {
            op->getOutput(o)->initGradient();
        }
    }
}
\ No newline at end of file
......@@ -21,8 +21,9 @@
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/utils/Types.h"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"
#include "aidge/recipes/GraphViewHelper.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/Memorize.hpp"
#include "aidge/operator/MetaOperator.hpp"
......@@ -71,14 +72,14 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
do {
// 2) From the current consumers list, check if any prior consumer node
// is needed. A prior will generally be required for any node consuming
// is needed. A prior will generally be required for any node consuming
// parameters (weights and bias) that is not an input node.
// If for a given node, only parent producers (at any depth) are needed
// to satisfy its required data, it becomes a prior.
// If the prior node is a producer, it is added to the list of required
// producers.
// If the prior node is of another type, it replaces the initial consumer
// in the new priorConsumers list. The initial consumer will become
// in the new priorConsumers list. The initial consumer will become
// again a consumer later, by construction.
if (verbose) fmt::print("List of consumers with their priors:\n");
std::set<std::shared_ptr<Node>> requiredProducers;
......@@ -129,7 +130,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
}
// 5) Find runnable consumers.
// A consumer is runnable if the required data is available for all of
// A consumer is runnable if the required data is available for all of
// its inputs. At this point, not all consumers are necessarily
// runnable because some may depend on the execution of others (when
// there is multiple successive priors for example).
......@@ -153,7 +154,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
fmt::print("{}", consumer->getOperator()->getNbProducedData(static_cast<IOIndex_t>(consumer->nbOutputs()) - 1));
fmt::print("\n");
}
bool isRunnable = true;
for (IOIndex_t inputIdx = 0; inputIdx < consumer->nbInputs(); ++inputIdx) {
if (/*consumer->getOperator()->getNbRequiredData(inputIdx) > 0
......@@ -189,7 +190,7 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
// 6) Push runnable consumers in the list of nodes to run and update the
// consumer producer system.
// At this point, simultaneously runnable consumers have no data
// At this point, simultaneously runnable consumers have no data
// dependency and could be run in parallel!
for (const auto& runnable : runnableConsumers) {
if (verbose) fmt::print("Runnable: {}\n", namePtrTable[runnable]);
......@@ -323,7 +324,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer
memManager.releaseDependencies(node);
continue;
}
const auto childs = node->getChildren();
AIDGE_ASSERT(node->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
......@@ -347,7 +348,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer
length = op->getOutput(outputIdx)->dims().end()[-1];
count = op->getOutput(outputIdx)->dims().end()[-2];
}
// Check if wrap around buffer is possible for this node
// (re-using previous node outputs memory for this node outputs).
// => only if this node is the only child of its parent(s)
......@@ -355,7 +356,7 @@ Aidge::MemoryManager Aidge::SequentialScheduler::generateMemory(bool incProducer
size_t wrapAroundExtra = 0;
wrapAroundMemPlane.push_back(nullptr);
// Select the best parent among all allocable nodes for
// Select the best parent among all allocable nodes for
// reallocation, which is the one with most memory (in order
// to minimize the reallocation size).
IOIndex_t inputIdx = 0;
......@@ -426,7 +427,7 @@ void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge
// Assert that the number of input data producers corresponds to the number of data input
assert(data.size() == inputNodes.size() && "Scheduler connectInput error - Inconsistent number of graph inputs and inputs passed to the graph");
for (std::size_t i = 0; i < data.size(); ++i){
// TODO : maybe shallow copy instead of deepcopy
inputNodes[i].first->getOperator()->setInput(inputNodes[i].second, data[i]);
......@@ -435,7 +436,7 @@ void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge
void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
// Collect all data input of the graph (that are producers)
if (!data.empty()){
connectInputs(data);
......@@ -475,6 +476,59 @@ void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::ve
}
}
/**
 * @brief Run one backward pass over the scheduled graph.
 *
 * The provided tensors are copied into the gradient Tensors of the GraphView's
 * ordered outputs, then the Nodes of the current static-schedule step are run
 * in reverse scheduling order via their Operator's backward().
 *
 * @param data One gradient Tensor per ordered output of the GraphView.
 * @param instanciateGrad If true, allocate gradient storage first (compile_gradient).
 * @param verbose If true, print each executed Node; otherwise show a progress bar.
 */
void Aidge::SequentialScheduler::backward(std::vector<std::shared_ptr<Aidge::Tensor>> data, bool instanciateGrad, bool verbose) {
    // Create gradient Tensors if requested, then seed them below.
    if (instanciateGrad) { compile_gradient(mGraphView); }

    const auto& ordered_outputs = mGraphView->getOrderedOutputs();
    // Adjacent string literals instead of backslash line continuations: the
    // old message embedded the source indentation inside the runtime string.
    AIDGE_ASSERT(ordered_outputs.size() == data.size(),
                 "You must provide the right number of data objects to run the "
                 "backward function. {} outputs detected for the current "
                 "GraphView when {} were provided.",
                 ordered_outputs.size(), data.size());

    // Seed each output gradient with the corresponding provided Tensor.
    for (std::size_t i = 0; i < ordered_outputs.size(); ++i) {
        const std::shared_ptr<OperatorTensor> op_ = std::dynamic_pointer_cast<OperatorTensor>(ordered_outputs[i].first->getOperator());
        const std::shared_ptr<Tensor> t_grad = op_->getOutput(ordered_outputs[i].second)->grad();
        AIDGE_ASSERT(data[i]->dims() == t_grad->dims(), "Wrong gradient size.");
        *t_grad = data[i]->clone();
    }

    // Generate scheduling *only if empty*.
    // If scheduling was already generated (in one or several steps, i.e. one or
    // several successive calls to generateScheduling()), do not generate it twice.
    if (mStaticSchedule.empty()) {
        this->generateScheduling();
    }

    // Map of node <-> display name used for verbose output / the progress bar.
    const auto namePtrTable = mGraphView->getRankedNodesName("{0} ({1}#{3})");

    // Clear previous scheduling results.
    mScheduling.clear();

    std::size_t cpt = 0;
    // Run the scheduled operators of the current step in reverse order.
    const auto& runnableList = mStaticSchedule.at(mStaticScheduleStep);
    for (auto runnable = runnableList.crbegin(); runnable != runnableList.crend(); ++runnable) {
        if (verbose)
            fmt::print("run: {}\n", namePtrTable.at(*runnable));
        else
            // BUGFIX: the completion ratio must be measured against the number
            // of nodes in this step (runnableList), not the number of schedule
            // steps (mStaticSchedule) — the old ratio could exceed 1.
            drawProgressBar(static_cast<float>(cpt) / static_cast<float>(runnableList.size()), 50,
                            (std::string("running ") + namePtrTable.at(*runnable)));
        const auto tStart = std::chrono::high_resolution_clock::now();
        (*runnable)->backward();
        const auto tEnd = std::chrono::high_resolution_clock::now();
        mScheduling.push_back(SchedulingElement(*runnable, tStart, tEnd));
        cpt++;
    }
    if (!verbose) drawProgressBar(1.0, 50, " ");
    fmt::print("\n");

    // Advance to the next static-schedule step, wrapping around at the end.
    ++mStaticScheduleStep;
    if (mStaticScheduleStep == mStaticSchedule.size()) {
        mStaticScheduleStep = 0;
    }
}
void Aidge::SequentialScheduler::saveSchedulingDiagram(const std::string& fileName) const {
auto fp = std::unique_ptr<FILE, decltype(&std::fclose)>(std::fopen((fileName + ".mmd").c_str(), "w"), &std::fclose);
......@@ -540,7 +594,7 @@ Aidge::NbElts_t Aidge::SequentialScheduler::getNbAvailableData(const std::shared
const auto upperInput = upperNode->inputs()[nodeInputIdx];
if (upperInput.first) {
return upperInput.first->getOperator()->getNbProducedData(upperInput.second);
}
}
}
++nodeInputIdx;
}
......
......@@ -122,7 +122,7 @@ TEST_CASE("[core/operators] MetaOperator", "[Operator][MetaOperator]") {
expandMetaOps(g);
g->setRootNode(pop);
REQUIRE(g->getRootNode() == pop);
REQUIRE(g->rootNode() == pop);
g->save("lstm_expanded", true, true);
REQUIRE(g->getNodes().size() == 41);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment