Forked from
Eclipse Projects / aidge / aidge_core
1292 commits behind the upstream repository.
-
Cyril Moineau authoredCyril Moineau authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
pybind_Scheduler.cpp 2.10 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>
#include "aidge/scheduler/MemoryManager.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/ParallelScheduler.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/data/Tensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Scheduler(py::module& m){
py::class_<Scheduler, std::shared_ptr<Scheduler>>(m, "Scheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("graph_view", &Scheduler::graphView)
.def("save_scheduling_diagram", &Scheduler::saveSchedulingDiagram, py::arg("file_name"))
.def("resetScheduling", &Scheduler::resetScheduling)
.def("generate_scheduling", &Scheduler::generateScheduling)
.def("get_static_scheduling", &Scheduler::getStaticScheduling, py::arg("step") = 0)
.def("generate_memory", &Scheduler::generateMemory, py::arg("inc_producers") = false, py::arg("wrap_around_buffer") = false)
;
py::class_<SequentialScheduler, std::shared_ptr<SequentialScheduler>, Scheduler>(m, "SequentialScheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &SequentialScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>())
.def("backward", &SequentialScheduler::backward, py::arg("instanciate_grad")=true)
;
py::class_<ParallelScheduler, std::shared_ptr<ParallelScheduler>, Scheduler>(m, "ParallelScheduler")
.def(py::init<std::shared_ptr<GraphView>&>(), py::arg("graph_view"))
.def("forward", &ParallelScheduler::forward, py::arg("forward_dims")=true, py::arg("data")=std::vector<Tensor>())
;
}
}