Skip to content
Snippets Groups Projects
Commit a71e2361 authored by Thibault Allenet's avatar Thibault Allenet
Browse files

Add ConnectInputs function in Scheduler forward and Update pybind scheduler

parent a343e118
No related branches found
No related tags found
No related merge requests found
...@@ -18,6 +18,8 @@ ...@@ -18,6 +18,8 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "aidge/data/Tensor.hpp"
namespace Aidge { namespace Aidge {
class Node; class Node;
class GraphView; class GraphView;
...@@ -49,11 +51,17 @@ public: ...@@ -49,11 +51,17 @@ public:
mScheduling.clear(); mScheduling.clear();
mStaticSchedule.clear(); mStaticSchedule.clear();
} }
/**
* @brief Place the data tensors inside in the data input tensor of the graphView. In case of multiple data input tensors, they are mapped to producers in the order given by the graph.
*
* @param data data input tensors
*/
void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data);
/** /**
* @brief Run the provided Computational Graph with a batch of data * @brief Run the provided Computational Graph with a batch of data
*/ */
void forward(bool forwardDims = true, bool verbose = false); void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {});
/** /**
* @brief Save in a Markdown file the order of layers execution. * @brief Save in a Markdown file the order of layers execution.
......
...@@ -13,13 +13,14 @@ ...@@ -13,13 +13,14 @@
#include <pybind11/stl.h> #include <pybind11/stl.h>
#include "aidge/scheduler/Scheduler.hpp" #include "aidge/scheduler/Scheduler.hpp"
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Scheduler(py::module& m){ void init_Scheduler(py::module& m){
py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>>(m, "SequentialScheduler") py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>>(m, "SequentialScheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false) .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false, py::arg("data")=std::vector<Tensor>())
.def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("resetScheduling", &SequentialScheduler::resetScheduling) .def("resetScheduling", &SequentialScheduler::resetScheduling)
.def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false) .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false)
......
...@@ -174,8 +174,28 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { ...@@ -174,8 +174,28 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) {
} }
void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data){
// This version of connect inputs only connects tensor inputs in input data producers.
auto inputNodes = mGraphView->getOrderedInputs();
// Assert that the number of input data producers corresponds to the number of data input
assert(data.size() == inputNodes.size() && "Scheduler connectInput error - Inconsistent number of graph inputs and inputs passed to the graph");
for (std::size_t i = 0; i < data.size(); ++i){
// TODO : maybe shallow copy instead of deepcopy
inputNodes[i].first->getOperator()->setInput(inputNodes[i].second, data[i]);
}
}
// TODO: handle multiple inputs/outputs // TODO: handle multiple inputs/outputs
void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) {
// Collect all data input of the graph (that are producers)
if (!data.empty()){
connectInputs(data);
}
// Forward dims (if allowed) // Forward dims (if allowed)
if (forwardDims) {mGraphView->forwardDims(); } if (forwardDims) {mGraphView->forwardDims(); }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment