diff --git a/python_binding/database/pybind_MNIST.cpp b/python_binding/database/pybind_MNIST.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3c73a5f068b8b7376bcd8024d22cb6f0de899562 --- /dev/null +++ b/python_binding/database/pybind_MNIST.cpp @@ -0,0 +1,33 @@ +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include "aidge/backend/opencv/database/MNIST.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_MNIST(py::module& m){ + + py::class_<MNIST, std::shared_ptr<MNIST>, Database>(m, "MNIST") + .def(py::init<const std::string&, bool, bool>(), py::arg("dataPath"), py::arg("train"), py::arg("load_data_in_memory")=false) + .def("get_item", &MNIST::getItem, py::arg("index"), + R"mydelimiter( + Return samples of each data modality for the given index. + + :param index: Database index corresponding to one item + :type index: int + )mydelimiter") + + .def("get_len", &MNIST::getLen, + R"mydelimiter( + Return the number of items in the database. + + )mydelimiter") + + .def("get_nb_modalities", &MNIST::getNbModalities, + R"mydelimiter( + Return the number of modalities in one item of the database. + + )mydelimiter"); + +} +} diff --git a/python_binding/pybind_opencv.cpp b/python_binding/pybind_opencv.cpp index 0e5507fb757a2a72c9122bc7f46585fbfefd4faa..276467d9b87d0a6e6c90b32294748de9a91f7a88 100644 --- a/python_binding/pybind_opencv.cpp +++ b/python_binding/pybind_opencv.cpp @@ -6,12 +6,9 @@ namespace py = pybind11; namespace Aidge { -void init_Aidge(py::module& /*m*/){ - -} +void init_MNIST(py::module&); PYBIND11_MODULE(aidge_backend_opencv, m) { - init_Aidge(m); + init_MNIST(m); } } -