diff --git a/include/aidge/data/DataProvider.hpp b/include/aidge/data/DataProvider.hpp index 055d6e9261b8ba8ee66d8f01e9b409ac18a869dc..b96118c164d5fa08e53d9ee9eeda155e14e9ea44 100644 --- a/include/aidge/data/DataProvider.hpp +++ b/include/aidge/data/DataProvider.hpp @@ -36,7 +36,6 @@ protected: size_t mNumberModality; std::vector<std::vector<std::size_t>> mDataSizes; - std::vector<std::string> mDataBackends; std::vector<DataType> mDataTypes; // Desired size of the produced batches diff --git a/src/data/DataProvider.cpp b/src/data/DataProvider.cpp index 07804be8ec21f13d2c8ec77efafbae49ebed4da8..10c016b8a77db357792aa34339fca046af299385 100644 --- a/src/data/DataProvider.cpp +++ b/src/data/DataProvider.cpp @@ -9,15 +9,13 @@ DataProvider::DataProvider(Database& database, std::size_t batchSize) mDatabase(database), mBatchSize(batchSize) { - // Get the tensor dimensions, datatype and backend of each modality to ensure each data have the same + // Get the tensor dimensions, datatype and backend of each modality to ensure each data have the samereadBatch auto item = mDatabase.getItem(0); mNumberModality = item.size(); // Iterating on each data modality in the database for (std::size_t i = 0; i < mNumberModality; ++i) { mDataSizes.push_back(item[i]->dims()); - // assert(std::strcmp(item[i]->getImpl()->backend(), "cpu") == 0 && "DataProvider currently only supports cpu backend tensors"); - // mDataBackends.push_back(item[i]->getImpl()->backend()); mDataTypes.push_back(item[i]->dataType()); } } @@ -42,7 +40,6 @@ std::vector<std::shared_ptr<Tensor>> DataProvider::readBatch(std::size_t startIn dataBatchSize[i].insert(dataBatchSize[i].begin(), current_batch_size); auto batchData = std::make_shared<Tensor>(); batchData->resize(dataBatchSize[i]); - // batchData->setBackend(mDataBackends[i]); batchData->setBackend("cpu"); batchData->setDataType(mDataTypes[i]); batchTensors.push_back(batchData); @@ -61,9 +58,12 @@ std::vector<std::shared_ptr<Tensor>> DataProvider::readBatch(std::size_t startIn // Assert tensor sizes assert(dataSample->dims() == mDataSizes[j] && "DataProvider readBatch : corrupted Data size"); - - // Assert implementation backend - // assert(dataSample->getImpl()->backend() == mDataBackends[j] && "DataProvider readBatch : corrupted data backend"); + + // Check the backend of the tensor from the database + // If not cpu then change the backend to cpu + if (strcmp(dataSample->getImpl()->backend(), "cpu") != 0) { + dataSample->setBackend("cpu"); + } // Assert DataType assert(dataSample->dataType() == mDataTypes[j] && "DataProvider readBatch : corrupted data DataType");