From 2c3257fe46c8017fc184ac272ec3120b5c89d8ce Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Tue, 29 Aug 2023 14:36:07 +0000 Subject: [PATCH] Add Database abstract class --- include/aidge/aidge.hpp | 3 +++ include/aidge/data/Database.hpp | 43 +++++++++++++++++++++++++++++++++ 2 files changed, 46 insertions(+) create mode 100644 include/aidge/data/Database.hpp diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index 7f32d695a..e04912c63 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -16,6 +16,9 @@ #include "aidge/backend/TensorImpl.hpp" #include "aidge/data/Data.hpp" #include "aidge/data/Tensor.hpp" +#include "aidge/data/Database.hpp" +#include "aidge/data/DatabaseTensor.hpp" +#include "aidge/data/Dataloader.hpp" #include "aidge/graph/Connector.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" diff --git a/include/aidge/data/Database.hpp b/include/aidge/data/Database.hpp new file mode 100644 index 000000000..6b22bdef6 --- /dev/null +++ b/include/aidge/data/Database.hpp @@ -0,0 +1,43 @@ +#ifndef Database_H_ +#define Database_H_ + +#include <cstring> + +#include "aidge/data/Tensor.hpp" + +namespace Aidge{ + +/** + * @brief Database. An abstract class representing a database. All databases should inherit from this class. All subclasses should overwrite ```get_item()``` to fetch data from a given index. + * @todo Make the dataset generic. Currently supprting only tensor. Always ground truth. + */ +class Database { + +public: + + virtual ~Database() = default; + + /** + * @brief Fetch a data sample and its corresponding ground_truth + * @param index index of the pair (```data```, ```ground truth```) to fetch from the database + * @return A pair of pointers to the data (first) and its corresping ground truth (second) + */ + virtual std::pair<std::shared_ptr<Tensor>,std::shared_ptr<Tensor>> get_item(unsigned int index) = 0; + + /** + * @return The number of data samples in the database. + */ + virtual unsigned int get_len() = 0; + + // void load(const std::string& /*dataPath*/, const std::string& labelPath = ""); + +protected: + + std::vector<std::shared_ptr<Tensor>> mData; + std::vector<std::shared_ptr<Tensor>> mLabel; + +}; + +} + +#endif /* Database_H_ */ \ No newline at end of file -- GitLab