From a71e2361cd847a8465b0d9533bdcc8efb6b17c84 Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Mon, 15 Jan 2024 15:01:35 +0000 Subject: [PATCH] Add ConnectInputs function in Scheduler forward and Update pybind scheduler --- include/aidge/scheduler/Scheduler.hpp | 10 ++++++++- python_binding/scheduler/pybind_Scheduler.cpp | 3 ++- src/scheduler/Scheduler.cpp | 22 ++++++++++++++++++- 3 files changed, 32 insertions(+), 3 deletions(-) diff --git a/include/aidge/scheduler/Scheduler.hpp b/include/aidge/scheduler/Scheduler.hpp index 6dcec5aaa..7a81503c9 100644 --- a/include/aidge/scheduler/Scheduler.hpp +++ b/include/aidge/scheduler/Scheduler.hpp @@ -18,6 +18,8 @@ #include <string> #include <vector> +#include "aidge/data/Tensor.hpp" + namespace Aidge { class Node; class GraphView; @@ -49,11 +51,17 @@ public: mScheduling.clear(); mStaticSchedule.clear(); } + /** + * @brief Place the data tensors inside in the data input tensor of the graphView. In case of multiple data input tensors, they are mapped to producers in the order given by the graph. + * + * @param data data input tensors + */ + void connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data); /** * @brief Run the provided Computational Graph with a batch of data */ - void forward(bool forwardDims = true, bool verbose = false); + void forward(bool forwardDims = true, bool verbose = false, std::vector<std::shared_ptr<Aidge::Tensor>> data = {}); /** * @brief Save in a Markdown file the order of layers execution. diff --git a/python_binding/scheduler/pybind_Scheduler.cpp b/python_binding/scheduler/pybind_Scheduler.cpp index d963b81d5..4eb715e79 100644 --- a/python_binding/scheduler/pybind_Scheduler.cpp +++ b/python_binding/scheduler/pybind_Scheduler.cpp @@ -13,13 +13,14 @@ #include <pybind11/stl.h> #include "aidge/scheduler/Scheduler.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/data/Tensor.hpp" namespace py = pybind11; namespace Aidge { void init_Scheduler(py::module& m){ py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>>(m, "SequentialScheduler") .def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view")) - .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false) + .def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("verbose")=false, py::arg("data")=std::vector<Tensor>()) .def("save_scheduling_diagram", &SequentialScheduler::saveSchedulingDiagram, py::arg("file_name")) .def("resetScheduling", &SequentialScheduler::resetScheduling) .def("generate_scheduling", &SequentialScheduler::generateScheduling, py::arg("verbose")=false) diff --git a/src/scheduler/Scheduler.cpp b/src/scheduler/Scheduler.cpp index 3afbcd044..380ff8bf3 100644 --- a/src/scheduler/Scheduler.cpp +++ b/src/scheduler/Scheduler.cpp @@ -174,8 +174,28 @@ void Aidge::SequentialScheduler::generateScheduling(bool verbose) { } +void Aidge::SequentialScheduler::connectInputs(std::vector<std::shared_ptr<Aidge::Tensor>> data){ + // This version of connect inputs only connects tensor inputs in input data producers. + auto inputNodes = mGraphView->getOrderedInputs(); + + // Assert that the number of input data producers corresponds to the number of data input + assert(data.size() == inputNodes.size() && "Scheduler connectInput error - Inconsistent number of graph inputs and inputs passed to the graph"); + + for (std::size_t i = 0; i < data.size(); ++i){ + // TODO : maybe shallow copy instead of deepcopy + inputNodes[i].first->getOperator()->setInput(inputNodes[i].second, data[i]); + } +} + + // TODO: handle multiple inputs/outputs -void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose) { +void Aidge::SequentialScheduler::forward(bool forwardDims, bool verbose, std::vector<std::shared_ptr<Aidge::Tensor>> data) { + + // Collect all data input of the graph (that are producers) + if (!data.empty()){ + connectInputs(data); + } + // Forward dims (if allowed) if (forwardDims) {mGraphView->forwardDims(); } -- GitLab