Skip to content

Python database

Cyril Moineau requested to merge PythonDB into dev

Add possibility to define a database using Python

import torch
import torchvision
import torchvision.transforms as transforms
import aidge_core

device = torch.device("cpu")

transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
                                        shuffle=True, num_workers=2)

aidge_inputs, aidge_labels = [], []
for data in trainloader:
    inputs, labels = data[0].to(device), data[1].to(device)
    aidge_inputs.append(aidge_core.Tensor(inputs.numpy()))
    aidge_labels.append(aidge_core.Tensor(labels.numpy()))

class aidge_cifar10(aidge_core.Database):
    def __init__(self, inputs, labels):
        aidge_core.Database.__init__(self)
        assert(len(inputs) == len(labels))
        self.inputs = inputs
        self.labels = labels
    def get_item(self, idx):
        return [self.inputs[idx], self.labels[idx]]
    def len(self):
        return len(self.inputs)
    def get_nb_modalities(self):
        return 2

aidge_database  = aidge_cifar10(aidge_inputs, aidge_labels)

aidge_dataprovider = aidge_core.DataProvider(aidge_database,
                                           batch_size=4,
                                           shuffle=True,
                                           drop_last=False)

for i, (data_batch, lbl_batch) in enumerate(aidge_dataprovider):
     print(lbl_batch)

Merge request reports