Commit 2f7d46f1, authored by Inna Kucher

adding LeNet generation

Parent: 9f1ae844
Merge request !1: Adding LeNet generation scripts
#!/bin/bash
script_directory="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )"
echo "Script directory: $script_directory"
ONNX="$script_directory/LeNet_MNIST.onnx"
DIGIT="$script_directory/digit_lenet.npy"
OUTPUT="$script_directory/output_digit_lenet.npy"
if [ ! -e "$ONNX" ] || [ ! -e "$DIGIT" ] || [ ! -e "$OUTPUT" ]
then
    virtualenv -p python3.8 "$script_directory/py3_8"
    source "$script_directory/py3_8/bin/activate"
    pip install --quiet -U torch torchvision --index-url https://download.pytorch.org/whl/cpu
    pip install --quiet -U onnx
    pip install --quiet -U torchsummary
    # Use the script's own directory and the flag name argparse actually defines
    python "$script_directory/torch_LeNet.py" --epochs 1
    deactivate
    rm -r "$script_directory/py3_8"
    rm -r "$script_directory/data"
else
    echo "$ONNX, $DIGIT and $OUTPUT already exist; nothing to do."
fi
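For reference, a typical first run (the shell script's filename is not shown on this commit page, so generate_lenet.sh below is a placeholder): it builds a throwaway Python 3.8 virtualenv, trains LeNet for one epoch, exports the model, and leaves three artifacts next to the script; subsequent runs are no-ops.

bash generate_lenet.sh
# on success, next to the script:
#   LeNet_MNIST.onnx  digit_lenet.npy  output_digit_lenet.npy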
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import datasets, transforms
from torchsummary import summary
import numpy as np
import onnx
import argparse
import os
class LeNet(torch.nn.Module):
def __init__(self):
super(LeNet, self).__init__()
# feature extractor CNN
self._feature_extractor = torch.nn.Sequential(
torch.nn.Conv2d(1,6,5),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2,2),
torch.nn.Conv2d(6,16,5),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2,2) )
# classifier MLP
self._classifier = torch.nn.Sequential(
torch.nn.Linear(256,120),
torch.nn.ReLU(),
torch.nn.Linear(120,84),
torch.nn.ReLU(),
torch.nn.Linear(84,10) )
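        # The 256 input features above follow from the conv/pool arithmetic on a
        # 28x28 MNIST input: 28 -> 24 (5x5 conv) -> 12 (2x2 pool) -> 8 (5x5 conv)
        # -> 4 (2x2 pool), i.e. 16 channels * 4 * 4 = 256 values after flattening.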
def forward(self, x):
# extract features
features = self._feature_extractor(x)
# flatten the 3d tensor (2d space x channels = features)
features = torch.flatten(features, start_dim=1)
#features = features.view(-1, np.prod(features.size()[1:]))
# classify and return
return self._classifier(features)
def train(model, train_loader, epoch, optimizer, criterion, log_interval=200):
# Set model to training mode
model.train()
# Loop over each batch from the training set
for batch_idx, (data, target) in enumerate(train_loader):
# Zero gradient buffers
optimizer.zero_grad()
# Pass data through the network
output = model(data)
# Calculate loss
loss = criterion(output, target)
# Backpropagate
loss.backward()
# Update weights
optimizer.step()
if batch_idx % log_interval == 0:
print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
epoch, batch_idx * len(data), len(train_loader.dataset),
100. * batch_idx / len(train_loader), loss.data.item()))
def validate(model, validation_loader, criterion):
    # Set model to evaluation mode and disable gradient tracking for inference
    model.eval()
    #summary(model, (1, 28, 28))
    val_loss, correct = 0, 0
    with torch.no_grad():
        for data, target in validation_loader:
            output = model(data)
            val_loss += criterion(output, target).item()
            pred = output.data.max(1)[1] # get the index of the max log-probability
            correct += pred.eq(target.data).cpu().sum()
    val_loss /= len(validation_loader)
    accuracy = 100. * correct.to(torch.float32) / len(validation_loader.dataset)
    print('\nValidation set: Average loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(
        val_loss, correct, len(validation_loader.dataset), accuracy))
    return accuracy
def main():
parser = argparse.ArgumentParser(description='PyTorch MNIST Example')
    parser.add_argument('--batch-size', type=int, default=20, metavar='N',
                        help='input batch size for training (default: 20)')
    parser.add_argument('--epochs', type=int, default=15, metavar='N',
                        help='number of epochs to train (default: 15)')
    parser.add_argument('--test', action='store_true', default=False,
                        help='test the model Best_mnist_LeNet.pt if it exists')
args = parser.parse_args()
folder_path = os.path.dirname(os.path.abspath(__file__))
data_path = os.path.join(folder_path, "data")
model_path = os.path.join(folder_path, "Best_mnist_LeNet.pt")
onnx_path = os.path.join(folder_path, "LeNet_MNIST.onnx")
    # np.save appends ".npy", so these stems must match the artifact names the
    # shell script checks for (digit_lenet.npy / output_digit_lenet.npy)
    digit_path = os.path.join(folder_path, "digit_lenet")
    output_path = os.path.join(folder_path, "output_digit_lenet")
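    # The constants below (0.1307, 0.3081) are the commonly used mean and std
    # of the MNIST training set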
trf=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
])
train_dataset = datasets.MNIST(data_path,
train=True,
download=True,
transform=trf)
validation_dataset = datasets.MNIST(data_path,
train=False,
transform=trf)
train_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=args.batch_size,
shuffle=True)
validation_loader = torch.utils.data.DataLoader(dataset=validation_dataset,
batch_size=args.batch_size,
shuffle=False)
####### TRAIN ########
model = LeNet().cpu()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.5)
criterion = nn.CrossEntropyLoss()
    if args.test:
        # Fail loudly if no trained model has been saved yet (the original
        # os.path.isfile() call discarded its result)
        if not os.path.isfile(model_path):
            raise FileNotFoundError('No saved model found at {}'.format(model_path))
        model.load_state_dict(torch.load(model_path))
        validate(model, validation_loader, criterion)
else:
best_acc = 0
for epoch in range(1, args.epochs + 1):
train(model, train_loader, epoch, optimizer, criterion)
acc = validate(model, validation_loader, criterion)
if acc > best_acc:
best_acc = acc
print('New best accuracy : ', best_acc)
torch.save(model.state_dict(), model_path)
print('-------------- Best model saved --------------\n')
    # Find one digit correctly predicted by the best saved model. Load that
    # model first so the digit and the saved output come from the same weights.
    model.load_state_dict(torch.load(model_path))
    model.eval()
    found = False
    i = 0
    with torch.no_grad():
        while not found:
            if i >= len(train_dataset):
                raise RuntimeError('No correctly predicted digits')
            x, t = train_dataset[i]
            x = torch.unsqueeze(x, 0)  # add the batch dimension: (1, 1, 28, 28)
            out = model(x)
            pred = out.data.max(1)[1]  # get the index of the max
            # Stop once the prediction matches the target (the original loop
            # inverted this test and then re-fetched the item one index too far)
            found = bool(pred.eq(t).item())
            i += 1
    # Save the digit & the model output for that digit
    np.save(digit_path, x.numpy())
    np.save(output_path, out.numpy())
####### EXPORT ONNX ########
torch.onnx.export(model, x, onnx_path, verbose=True, input_names=[ "actual_input" ], output_names=[ "output" ])
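    # Optional sanity check (added; the `onnx` package is imported above but
    # otherwise unused): verify the exported graph is structurally valid
    onnx.checker.check_model(onnx.load(onnx_path))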
if __name__ == '__main__':
main()
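The saved digit/output pair doubles as a regression fixture for the ONNX export. A minimal cross-check sketch, assuming onnxruntime is installed (it is not part of the virtualenv the shell script sets up):

import numpy as np
import onnxruntime as ort

# Run the exported graph on the saved digit and compare with the PyTorch output
sess = ort.InferenceSession("LeNet_MNIST.onnx")
x = np.load("digit_lenet.npy")
ref = np.load("output_digit_lenet.npy")
(out,) = sess.run(["output"], {"actual_input": x})
assert np.allclose(out, ref, atol=1e-5)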