From c61595f4994f45e2ed6dd87f2c5d0017daf0eaa1 Mon Sep 17 00:00:00 2001 From: thibault allenet <thibault.allenet@cea.fr> Date: Tue, 23 Jan 2024 14:43:25 +0000 Subject: [PATCH] Modify DataProvider to change backend of the tensors to cpu with the generic set_Backend() --- include/aidge/data/DataProvider.hpp | 1 - src/data/DataProvider.cpp | 14 +++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/include/aidge/data/DataProvider.hpp b/include/aidge/data/DataProvider.hpp index 055d6e926..b96118c16 100644 --- a/include/aidge/data/DataProvider.hpp +++ b/include/aidge/data/DataProvider.hpp @@ -36,7 +36,6 @@ protected: size_t mNumberModality; std::vector<std::vector<std::size_t>> mDataSizes; - std::vector<std::string> mDataBackends; std::vector<DataType> mDataTypes; // Desired size of the produced batches diff --git a/src/data/DataProvider.cpp b/src/data/DataProvider.cpp index 07804be8e..10c016b8a 100644 --- a/src/data/DataProvider.cpp +++ b/src/data/DataProvider.cpp @@ -9,15 +9,13 @@ DataProvider::DataProvider(Database& database, std::size_t batchSize) mDatabase(database), mBatchSize(batchSize) { - // Get the tensor dimensions, datatype and backend of each modality to ensure each data have the same + // Get the tensor dimensions, datatype and backend of each modality to ensure each data have the same auto item = mDatabase.getItem(0); mNumberModality = item.size(); // Iterating on each data modality in the database for (std::size_t i = 0; i < mNumberModality; ++i) { mDataSizes.push_back(item[i]->dims()); - // assert(std::strcmp(item[i]->getImpl()->backend(), "cpu") == 0 && "DataProvider currently only supports cpu backend tensors"); - // mDataBackends.push_back(item[i]->getImpl()->backend()); mDataTypes.push_back(item[i]->dataType()); } } @@ -42,7 +40,6 @@ std::vector<std::shared_ptr<Tensor>> DataProvider::readBatch(std::size_t startIn dataBatchSize[i].insert(dataBatchSize[i].begin(), current_batch_size); auto batchData = 
std::make_shared<Tensor>(); batchData->resize(dataBatchSize[i]); - // batchData->setBackend(mDataBackends[i]); batchData->setBackend("cpu"); batchData->setDataType(mDataTypes[i]); batchTensors.push_back(batchData); @@ -61,9 +58,12 @@ std::vector<std::shared_ptr<Tensor>> DataProvider::readBatch(std::size_t startIn // Assert tensor sizes assert(dataSample->dims() == mDataSizes[j] && "DataProvider readBatch : corrupted Data size"); - - // Assert implementation backend - // assert(dataSample->getImpl()->backend() == mDataBackends[j] && "DataProvider readBatch : corrupted data backend"); + + // Check the backend of the tensor from the database + // If not cpu then change the backend to cpu + if (strcmp(dataSample->getImpl()->backend(), "cpu") != 0) { + dataSample->setBackend("cpu"); + } // Assert DataType assert(dataSample->dataType() == mDataTypes[j] && "DataProvider readBatch : corrupted data DataType"); -- GitLab