Python database
Add possibility to define a database using Python
import torch
import torchvision
import torchvision.transforms as transforms
import aidge_core
device = torch.device("cpu")
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=1,
shuffle=True, num_workers=2)
aidge_inputs, aidge_labels = [], []
for data in trainloader:
inputs, labels = data[0].to(device), data[1].to(device)
aidge_inputs.append(aidge_core.Tensor(inputs.numpy()))
aidge_labels.append(aidge_core.Tensor(labels.numpy()))
class aidge_cifar10(aidge_core.Database):
def __init__(self, inputs, labels):
aidge_core.Database.__init__(self)
assert(len(inputs) == len(labels))
self.inputs = inputs
self.labels = labels
def get_item(self, idx):
return [self.inputs[idx], self.labels[idx]]
def len(self):
return len(self.inputs)
def get_nb_modalities(self):
return 2
aidge_database = aidge_cifar10(aidge_inputs, aidge_labels)
aidge_dataprovider = aidge_core.DataProvider(aidge_database,
batch_size=4,
shuffle=True,
drop_last=False)
for i, (data_batch, lbl_batch) in enumerate(aidge_dataprovider):
print(lbl_batch)