From 3dd2b5486c941f2f2014d250bc9c636c5a7c4960 Mon Sep 17 00:00:00 2001
From: thibault allenet <thibault.allenet@cea.fr>
Date: Mon, 15 Jan 2024 15:24:50 +0000
Subject: [PATCH] Add getNbModalities function in MNIST

---
 include/aidge/backend/opencv/database/MNIST.hpp | 3 ++-
 src/database/MNIST.cpp                          | 7 ++++++-
 2 files changed, 8 insertions(+), 2 deletions(-)

diff --git a/include/aidge/backend/opencv/database/MNIST.hpp b/include/aidge/backend/opencv/database/MNIST.hpp
index c1803fb..e3a9b6b 100644
--- a/include/aidge/backend/opencv/database/MNIST.hpp
+++ b/include/aidge/backend/opencv/database/MNIST.hpp
@@ -37,7 +37,6 @@ inline bool isBigEndian()
 
 class MNIST : public Database {
 public:
-    MNIST() = delete;
     MNIST(const std::string& dataPath, 
             // const GraphView transformations, 
             bool train,
@@ -65,6 +64,8 @@ public:
 
     std::size_t getLen() override;
 
+    std::size_t getNbModalities() override;
+
     union MagicNumber {
         unsigned int value;
         unsigned char byte[4];
diff --git a/src/database/MNIST.cpp b/src/database/MNIST.cpp
index ae745e9..d9145f6 100644
--- a/src/database/MNIST.cpp
+++ b/src/database/MNIST.cpp
@@ -127,7 +127,7 @@ void Aidge::MNIST::uncompress(const std::string& dataPath,
 std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::MNIST::getItem(std::size_t index) {
     std::vector<std::shared_ptr<Tensor>> item;
     // Load the digit tensor 
-    // TODO : Currently converts the tensor Opencv but this operation will be carried by a convert operator later
+    // TODO : Currently converts the tensor Opencv but this operation will be carried by a convert operator in the preprocessing graph
     item.push_back(Aidge::convertCpu((std::get<0>(mStimulis.at(index))).load()));
     // item.push_back((std::get<0>(mStimulis.at(index))).load());
     // Load the label tensor 
@@ -138,4 +138,9 @@ std::vector<std::shared_ptr<Aidge::Tensor>> Aidge::MNIST::getItem(std::size_t in
 
 std::size_t  Aidge::MNIST::getLen(){
     return mStimulis.size();
+}
+
+std::size_t  Aidge::MNIST::getNbModalities(){
+    size_t tupleSize = std::tuple_size<decltype(mStimulis)::value_type>::value;
+    return tupleSize;
 }
\ No newline at end of file
-- 
GitLab