Skip to content
Snippets Groups Projects
Commit 3de49bde authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Scheduler] Add Scheduler.getNodeScheduling method and fix sequential function.

parent a0bf5e8d
No related branches found
No related tags found
No related merge requests found
......@@ -55,7 +55,7 @@ public:
* @param inputs List of Node and GraphView to link sequentially.
* @return std::shared_ptr<GraphView> Pointer to the generated view.
*/
std::shared_ptr<GraphView> Sequential(std::initializer_list<OpArgs> inputs);
std::shared_ptr<GraphView> Sequential(std::vector<OpArgs> inputs);
/////////////////////////////
// Parallel
......@@ -65,7 +65,7 @@ std::shared_ptr<GraphView> Sequential(std::initializer_list<OpArgs> inputs);
* @param inputs List of Node and GraphView to link sequentially.
* @return std::shared_ptr<GraphView> pointer to the generated view.
*/
std::shared_ptr<GraphView> Parallel(std::initializer_list<OpArgs> inputs);
std::shared_ptr<GraphView> Parallel(std::vector<OpArgs> inputs);
/////////////////////////////
// Residual
......@@ -79,8 +79,8 @@ std::shared_ptr<GraphView> Parallel(std::initializer_list<OpArgs> inputs);
* @param inputs List of Node and GraphView to link sequentially.
* @return std::shared_ptr<GraphView> pointer to the generated view.
*/
std::shared_ptr<GraphView> Residual(std::initializer_list<OpArgs> inputs);
std::shared_ptr<GraphView> Residual(std::vector<OpArgs> inputs);
}
#endif /* AIDGE_CORE_GRAPH_OPARGS_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_GRAPH_OPARGS_H_ */
......@@ -54,6 +54,19 @@ public:
*/
void saveSchedulingDiagram(const std::string& fileName) const;
/**
* @brief Return a vector of Node ordered by the order they are called by the scheduler
*
* @return std::vector<std::shared_ptr<Node>>
*/
std::vector<std::shared_ptr<Node>> getNodeScheduling(){
std::vector<std::shared_ptr<Node>> nodeScheduling = {};
for(SchedulingElement & scheduleElt: mScheduling){
nodeScheduling.push_back(scheduleElt.node);
}
return nodeScheduling;
}
private:
/**
* @brief Set of layers receiving an input from currently processing layers
......@@ -68,4 +81,4 @@ private:
};
} // namespace Aidge
#endif /* AIDGE_SCHEDULER_H_ */
\ No newline at end of file
#endif /* AIDGE_SCHEDULER_H_ */
......@@ -10,19 +10,20 @@
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/graph/OpArgs.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include <pybind11/stl.h>
#include <pybind11/complex.h>
#include <pybind11/functional.h>
#include <pybind11/chrono.h>
namespace py = pybind11;
namespace Aidge {
void init_OpArgs(py::module& m){
py::class_<OpArgs, std::shared_ptr<OpArgs>>(m, "OpArgs")
.def(py::init<const std::shared_ptr<GraphView>&>(), py::arg("view_"))
.def(py::init<const std::shared_ptr<Node>&>(), py::arg("node_"))
.def("node", &OpArgs::node)
.def("view", &OpArgs::view)
;
......
......@@ -10,6 +10,7 @@
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/graph/GraphView.hpp"
......@@ -20,6 +21,7 @@ void init_Scheduler(py::module& m){
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false)
.def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("get_node_scheduling", &SequentialScheduler::getNodeScheduling)
;
}
}
......
......@@ -14,13 +14,13 @@
#include "aidge/graph/OpArgs.hpp"
std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::initializer_list<OpArgs> inputs) {
std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::vector<OpArgs> inputs) {
std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
for (const OpArgs& elt : inputs) {
if(elt.node() != nullptr) {
// >= to allow incomplete graphViews
assert(static_cast<std::size_t>(elt.node()->getNbFreeDataInputs()) >= gv->outputNodes().size());
/*
/*
* /!\ mn.view()->outputNodes() is a set, order of Nodes cannot be guaranted.
* Prefer a functional description for detailed inputs
*/
......@@ -44,7 +44,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::Sequential(std::initializer_list<OpArgs
}
std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::initializer_list<OpArgs> inputs) {
std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::vector<OpArgs> inputs) {
std::shared_ptr<GraphView> gv = std::make_shared<GraphView>();
for(const OpArgs& elt : inputs) {
if (elt.node()!=nullptr)
......@@ -56,7 +56,7 @@ std::shared_ptr<Aidge::GraphView> Aidge::Parallel(std::initializer_list<OpArgs>
}
std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::initializer_list<OpArgs> inputs) {
std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::vector<OpArgs> inputs) {
std::shared_ptr<GraphView> gv = Sequential(inputs);
assert(gv->outputNodes().size() == 1U && "Zero or more than one output Node for the GraphView, don't know which one to choose from for the residual connection");
std::shared_ptr<Node> lastNode = *gv->outputNodes().begin();
......@@ -70,4 +70,4 @@ std::shared_ptr<Aidge::GraphView> Aidge::Residual(std::initializer_list<OpArgs>
assert(lastNode->getNbFreeDataInputs()>=1);
gv->addChild(lastNode, firstNode, 0U, gk_IODefaultIndex);
return gv;
}
\ No newline at end of file
}
......@@ -34,8 +34,8 @@ void drawProgressBar(double progress, int barWidth, const std::string& additiona
}
// TODO: handle multiple inputs/outputs
void Aidge::SequentialScheduler::forward(bool frowardDims, bool verbose) {
if (frowardDims) {mGraphView->forwardDims(); }
void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) {
if (forwardDims) {mGraphView->forwardDims(); }
mScheduling.clear();
......@@ -231,4 +231,4 @@ std::set<std::shared_ptr<Aidge::Node>> Aidge::SequentialScheduler::getConsumers(
}
return consumers;
}
\ No newline at end of file
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment