From a232eaba95959264f4fa02a75001aa3a21e814c7 Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Mon, 15 Jan 2024 15:30:11 +0000 Subject: [PATCH] Add python biding for MNIST database --- python_binding/database/pybind_MNIST.cpp | 33 ++++++++++++++++++++++++ python_binding/pybind_opencv.cpp | 7 ++--- 2 files changed, 35 insertions(+), 5 deletions(-) create mode 100644 python_binding/database/pybind_MNIST.cpp diff --git a/python_binding/database/pybind_MNIST.cpp b/python_binding/database/pybind_MNIST.cpp new file mode 100644 index 0000000..3c73a5f --- /dev/null +++ b/python_binding/database/pybind_MNIST.cpp @@ -0,0 +1,33 @@ +#include <pybind11/pybind11.h> +#include <pybind11/stl.h> +#include "aidge/backend/opencv/database/MNIST.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_MNIST(py::module& m){ + + py::class_<MNIST, std::shared_ptr<MNIST>, Database>(m, "MNIST") + .def(py::init<const std::string&, bool, bool>(), py::arg("dataPath"), py::arg("train"), py::arg("load_data_in_memory")=false) + .def("get_item", &MNIST::getItem, py::arg("index"), + R"mydelimiter( + Return samples of each data modality for the given index. + + :param index: Database index corresponding to one item + :type index: int + )mydelimiter") + + .def("get_len", &MNIST::getLen, + R"mydelimiter( + Return the number of items in the database. + + )mydelimiter") + + .def("get_nb_modalities", &MNIST::getNbModalities, + R"mydelimiter( + Return the number of modalities in one item of the database. + + )mydelimiter"); + +} +} diff --git a/python_binding/pybind_opencv.cpp b/python_binding/pybind_opencv.cpp index 0e5507f..276467d 100644 --- a/python_binding/pybind_opencv.cpp +++ b/python_binding/pybind_opencv.cpp @@ -6,12 +6,9 @@ namespace py = pybind11; namespace Aidge { -void init_Aidge(py::module& /*m*/){ - -} +void init_MNIST(py::module&); PYBIND11_MODULE(aidge_backend_opencv, m) { - init_Aidge(m); + init_MNIST(m); } } - -- GitLab