Commit c61595f4 authored by Thibault Allenet

Modify DataProvider to change the backend of the tensors to cpu with the generic setBackend()

parent efb68e6a
1 merge request: !72 Draft: Data provider use setBackend
Pipeline #45644 canceled
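In practice, this change lets a DataProvider be fed by a Database whose tensors live on any backend: instead of asserting that every sample is already on cpu, readBatch() now converts each sample with the generic setBackend("cpu"). A hypothetical usage sketch follows; the header paths, the Aidge namespace, and MyGpuDatabase are illustrative assumptions, not part of this commit.

```cpp
#include <memory>
#include <vector>

#include "aidge/data/DataProvider.hpp"   // assumed header locations
#include "aidge/data/Database.hpp"
#include "aidge/data/Tensor.hpp"

int main() {
    // MyGpuDatabase stands in for any Database implementation whose getItem()
    // returns tensors on a non-cpu backend (e.g. "cuda").
    MyGpuDatabase db;

    Aidge::DataProvider provider(db, /*batchSize=*/32);

    // With this commit, each sample is moved to the "cpu" backend inside
    // readBatch() before being gathered into the batch tensors.
    std::vector<std::shared_ptr<Aidge::Tensor>> batch = provider.readBatch(0);

    return 0;
}
```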
@@ -36,7 +36,6 @@ protected:
size_t mNumberModality;
std::vector<std::vector<std::size_t>> mDataSizes;
std::vector<std::string> mDataBackends;
std::vector<DataType> mDataTypes;
// Desired size of the produced batches
@@ -9,15 +9,13 @@ DataProvider::DataProvider(Database& database, std::size_t batchSize)
mDatabase(database),
mBatchSize(batchSize)
{
// Get the tensor dimensions, datatype and backend of each modality to ensure each data have the same
auto item = mDatabase.getItem(0);
mNumberModality = item.size();
// Iterating on each data modality in the database
for (std::size_t i = 0; i < mNumberModality; ++i) {
mDataSizes.push_back(item[i]->dims());
// assert(std::strcmp(item[i]->getImpl()->backend(), "cpu") == 0 && "DataProvider currently only supports cpu backend tensors");
// mDataBackends.push_back(item[i]->getImpl()->backend());
mDataTypes.push_back(item[i]->dataType());
}
}
@@ -42,7 +40,6 @@ std::vector<std::shared_ptr<Tensor>> DataProvider::readBatch(std::size_t startIn
dataBatchSize[i].insert(dataBatchSize[i].begin(), current_batch_size);
auto batchData = std::make_shared<Tensor>();
batchData->resize(dataBatchSize[i]);
// batchData->setBackend(mDataBackends[i]);
batchData->setBackend("cpu");
batchData->setDataType(mDataTypes[i]);
batchTensors.push_back(batchData);
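As a side note on the hunk above: each batch tensor is shaped by prepending the current batch size (which may be smaller than the nominal one for the last batch of an epoch) to the per-modality sample dimensions cached by the constructor. A self-contained illustration of that shape assembly, using plain standard-library types and example values:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

int main() {
    // Dimensions of one sample for a given modality (example values).
    std::vector<std::size_t> sampleDims = {3, 224, 224};

    // Current batch size; the last batch of an epoch may be truncated.
    std::size_t currentBatchSize = 16;

    // Same operation as dataBatchSize[i].insert(dataBatchSize[i].begin(), current_batch_size):
    // the batch dimension is prepended to the sample dimensions.
    std::vector<std::size_t> batchDims = sampleDims;
    batchDims.insert(batchDims.begin(), currentBatchSize);   // -> {16, 3, 224, 224}

    for (std::size_t d : batchDims) {
        std::cout << d << ' ';
    }
    std::cout << '\n';
    return 0;
}
```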
@@ -61,9 +58,12 @@ std::vector<std::shared_ptr<Tensor>> DataProvider::readBatch(std::size_t startIn
// Assert tensor sizes
assert(dataSample->dims() == mDataSizes[j] && "DataProvider readBatch : corrupted Data size");
// Assert implementation backend
// assert(dataSample->getImpl()->backend() == mDataBackends[j] && "DataProvider readBatch : corrupted data backend");
// Check the backend of the tensor from the database
// If not cpu then change the backend to cpu
if (strcmp(dataSample->getImpl()->backend(), "cpu") != 0) {
dataSample->setBackend("cpu");
}
// Assert DataType
assert(dataSample->dataType() == mDataTypes[j] && "DataProvider readBatch : corrupted data DataType");
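The core of the commit is the check-and-convert step above. The following self-contained mock isolates that pattern; ToyTensor is an invented stand-in for illustration only, not the real Tensor class (whose setBackend() is also expected to migrate the underlying data).

```cpp
#include <cstring>
#include <iostream>
#include <string>

// ToyTensor mimics just enough of a tensor's interface to show the pattern:
// query the current backend, and switch it when it is not "cpu".
struct ToyTensor {
    std::string mBackend = "cuda";
    const char* backend() const { return mBackend.c_str(); }
    void setBackend(const std::string& name) { mBackend = name; }
};

// If the sample does not already live on the cpu backend, move it there,
// mirroring the check added in readBatch().
void ensureCpuBackend(ToyTensor& sample) {
    if (std::strcmp(sample.backend(), "cpu") != 0) {
        sample.setBackend("cpu");
    }
}

int main() {
    ToyTensor sample;                        // starts on a non-cpu backend
    ensureCpuBackend(sample);
    std::cout << sample.backend() << '\n';   // prints "cpu"
    return 0;
}
```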