diff --git a/include/aidge/aidge.hpp b/include/aidge/aidge.hpp index 7f32d695a41d954e9f31c6682e3cc6fc0226aed9..e04912c637d5339600c1708a7fb3b68c3ddb494c 100644 --- a/include/aidge/aidge.hpp +++ b/include/aidge/aidge.hpp @@ -16,6 +16,9 @@ #include "aidge/backend/TensorImpl.hpp" #include "aidge/data/Data.hpp" #include "aidge/data/Tensor.hpp" +#include "aidge/data/Database.hpp" +#include "aidge/data/DatabaseTensor.hpp" +#include "aidge/data/Dataloader.hpp" #include "aidge/graph/Connector.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" diff --git a/include/aidge/data/Database.hpp b/include/aidge/data/Database.hpp new file mode 100644 index 0000000000000000000000000000000000000000..6b22bdef6a7706ed339fd7c94243f266df53ea63 --- /dev/null +++ b/include/aidge/data/Database.hpp @@ -0,0 +1,43 @@ +#ifndef Database_H_ +#define Database_H_ + +#include <cstring> + +#include "aidge/data/Tensor.hpp" + +namespace Aidge{ + +/** + * @brief Database. An abstract class representing a database. All databases should inherit from this class. All subclasses should overwrite ```get_item()``` to fetch data from a given index. + * @todo Make the dataset generic. Currently supprting only tensor. Always ground truth. + */ +class Database { + +public: + + virtual ~Database() = default; + + /** + * @brief Fetch a data sample and its corresponding ground_truth + * @param index index of the pair (```data```, ```ground truth```) to fetch from the database + * @return A pair of pointers to the data (first) and its corresping ground truth (second) + */ + virtual std::pair<std::shared_ptr<Tensor>,std::shared_ptr<Tensor>> get_item(unsigned int index) = 0; + + /** + * @return The number of data samples in the database. + */ + virtual unsigned int get_len() = 0; + + // void load(const std::string& /*dataPath*/, const std::string& labelPath = ""); + +protected: + + std::vector<std::shared_ptr<Tensor>> mData; + std::vector<std::shared_ptr<Tensor>> mLabel; + +}; + +} + +#endif /* Database_H_ */ \ No newline at end of file