#!/bin/bash
# Generate the LeNet MNIST artifacts (ONNX model, one sample digit, and the
# model's reference output for it) in a throwaway Python 3.8 virtualenv,
# unless all three artifacts already exist next to this script.
script_directory="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
echo "Script directory: $script_directory"

ONNX="$script_directory/LeNet_MNIST.onnx"
DIGIT="$script_directory/digit_lenet.npy"
OUTPUT="$script_directory/output_digit_lenet.npy"

# Regenerate only when any of the three artifacts is missing.
if [ ! -e "$ONNX" ] || [ ! -e "$DIGIT" ] || [ ! -e "$OUTPUT" ]
then
    virtualenv -p python3.8 "$script_directory/py3_8"
    source "$script_directory/py3_8/bin/activate"
    pip3 install --quiet -U torch torchvision --index-url https://download.pytorch.org/whl/cpu
    pip install --quiet -U onnx
    pip install --quiet -U torchsummary

    # Fix: use an absolute path so this works from any working directory
    # (the previous './torch_LeNet.py' only worked when run from this folder),
    # and spell '--epochs' out instead of relying on argparse prefix matching.
    python "$script_directory/torch_LeNet.py" --epochs 1

    deactivate
    # Clean up the temporary virtualenv and the downloaded MNIST data.
    rm -r "$script_directory/py3_8"
    rm -r "$script_directory/data"
else
    echo "$ONNX $DIGIT $OUTPUT exist."
fi
"""Train a LeNet-5 style CNN on MNIST and export it to ONNX.

Besides the ONNX model, this script saves one correctly-classified digit
(``digit_lenet.npy``) and the model's output for it
(``output_digit_lenet.npy``) — the file names the companion
``generate_LeNet.sh`` script checks for.
"""
import argparse
import os

import numpy as np
import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import datasets, transforms


class LeNet(torch.nn.Module):
    """LeNet-5 style network: two conv/pool stages followed by a 3-layer MLP.

    Expects single-channel 28x28 inputs (MNIST); the flattened feature map is
    then 16 channels x 4x4 spatial = 256 values per sample.
    """

    def __init__(self):
        super(LeNet, self).__init__()
        # feature extractor CNN
        self._feature_extractor = torch.nn.Sequential(
            torch.nn.Conv2d(1, 6, 5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2),
            torch.nn.Conv2d(6, 16, 5),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2, 2))
        # classifier MLP
        self._classifier = torch.nn.Sequential(
            torch.nn.Linear(256, 120),
            torch.nn.ReLU(),
            torch.nn.Linear(120, 84),
            torch.nn.ReLU(),
            torch.nn.Linear(84, 10))

    def forward(self, x):
        """Return the class logits for a batch of shape (N, 1, 28, 28)."""
        features = self._feature_extractor(x)
        # flatten the 3d tensor (2d space x channels = features) per sample
        features = torch.flatten(features, start_dim=1)
        return self._classifier(features)


def train(model, train_loader, epoch, optimizer, criterion, log_interval=200):
    """Run one training epoch over ``train_loader``.

    Args:
        model: network to train (modified in place).
        train_loader: DataLoader yielding (data, target) batches.
        epoch: 1-based epoch index, used only for logging.
        optimizer: optimizer updating ``model``'s parameters.
        criterion: loss function taking (output, target).
        log_interval: log progress every this many batches.
    """
    # Set model to training mode
    model.train()

    for batch_idx, (data, target) in enumerate(train_loader):
        # Zero gradient buffers, forward, loss, backward, step
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()

        if batch_idx % log_interval == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                100. * batch_idx / len(train_loader), loss.data.item()))


def validate(model, validation_loader, criterion):
    """Evaluate ``model`` on ``validation_loader`` and return accuracy (%)."""
    model.eval()
    # summary(model, (1, 28, 28))
    val_loss, correct = 0, 0
    for data, target in validation_loader:
        output = model(data)
        val_loss += criterion(output, target).data.item()
        pred = output.data.max(1)[1]  # get the index of the max log-probability
        correct += pred.eq(target.data).cpu().sum()

    # Average loss over the number of batches (not samples).
    val_loss /= len(validation_loader)

    accuracy = 100. * correct.to(torch.float32) / len(validation_loader.dataset)

    print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        val_loss, correct, len(validation_loader.dataset), accuracy))

    return accuracy


def main():
    parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=20, metavar='N',
                        help='input batch size for training (default: 20)')
    parser.add_argument('--epochs', type=int, default=15, metavar='N',
                        help='number of epochs to train (default: 15)')
    parser.add_argument('--test', action='store_true', default=False,
                        help='test the model Best_mnist_LeNet.pt if it exists')
    args = parser.parse_args()

    folder_path = os.path.dirname(os.path.abspath(__file__))
    data_path = os.path.join(folder_path, "data")
    model_path = os.path.join(folder_path, "Best_mnist_LeNet.pt")
    onnx_path = os.path.join(folder_path, "LeNet_MNIST.onnx")
    # np.save appends '.npy'; these stems match the file names that
    # generate_LeNet.sh checks for (the old 'digit'/'output_digit' stems
    # never matched, so the artifacts were regenerated on every run).
    digit_path = os.path.join(folder_path, "digit_lenet")
    output_path = os.path.join(folder_path, "output_digit_lenet")

    trf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ])

    train_dataset = datasets.MNIST(data_path,
                                   train=True,
                                   download=True,
                                   transform=trf)

    validation_dataset = datasets.MNIST(data_path,
                                        train=False,
                                        transform=trf)

    train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=True)

    validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset,
                                                    batch_size=args.batch_size,
                                                    shuffle=False)

    ####### TRAIN ########
    model = LeNet().cpu()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
    criterion = nn.CrossEntropyLoss()

    if args.test:
        # Fix: the isfile() result used to be silently discarded; fail loudly
        # instead of crashing inside torch.load with a less helpful error.
        if not os.path.isfile(model_path):
            raise FileNotFoundError(model_path)
        model.load_state_dict(torch.load(model_path))
        validate(model, validation_loader, criterion)
    else:
        best_acc = 0
        for epoch in range(1, args.epochs + 1):
            train(model, train_loader, epoch, optimizer, criterion)
            acc = validate(model, validation_loader, criterion)
            if acc > best_acc:
                best_acc = acc
                print('New best accuracy : ', best_acc)
                torch.save(model.state_dict(), model_path)
                print('-------------- Best model saved --------------\n')

    # Reload the best checkpoint BEFORE searching, so the saved digit is one
    # that the exported model (not the last in-memory state) classifies
    # correctly.
    model.load_state_dict(torch.load(model_path))
    model.eval()

    # Find one digit correctly predicted.
    # Fixes two bugs in the original search: the loop condition was inverted
    # (it stopped at the first MISclassified digit), and the post-loop index
    # pointed one sample past the match.
    x = None
    for i in range(len(train_dataset)):
        sample, target = train_dataset[i]
        batch = torch.unsqueeze(sample, 0)
        pred = model(batch).data.max(1)[1]  # index of the max logit
        if pred.item() == target:
            x = batch
            break
    if x is None:
        raise RuntimeError('No correctly predicted digits')

    # Save the digit & the model output for it.
    output = model(x)
    np.save(digit_path, x)
    np.save(output_path, output.detach().numpy())

    ####### EXPORT ONNX ########
    torch.onnx.export(model, x, onnx_path, verbose=True,
                      input_names=["actual_input"], output_names=["output"])


if __name__ == '__main__':
    main()