diff --git a/.gitignore b/.gitignore
index f37378e300efeb5362882eb8d6eb59f028563a0e..ff07895b0e4e52c6c6a21ae0984fd3341ab972f7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,4 +19,7 @@ dist*/
 xml*/
 
 # ONNX
-*.onnx
\ No newline at end of file
+*.onnx
+
+# log
+*.log
diff --git a/.gitlab/ci/build.gitlab-ci.yml b/.gitlab/ci/build.gitlab-ci.yml
index d8472ca464e09814273ce13cdf340a92d2d7004c..baea3ca70266f1d3fcd489824da681106c85b73c 100644
--- a/.gitlab/ci/build.gitlab-ci.yml
+++ b/.gitlab/ci/build.gitlab-ci.yml
@@ -23,9 +23,15 @@ build:ubuntu_python:
     - DEPENDENCY_NAME="aidge_backend_cpu"
     - !reference [.download_dependency, script]
 
+    # aidge_learning
+    - DEPENDENCY_NAME="aidge_learning"
+    - !reference [.download_dependency, script]
+
+
     - python3 -m pip install virtualenv
     - virtualenv venv
     - source venv/bin/activate
+    - python3 -m pip install numpy coverage onnxruntime # used for tests
     - python3 -m pip install -r requirements.txt
     - python3 -m pip install . -v
   artifacts:
diff --git a/.gitlab/ci/coverage.gitlab-ci.yml b/.gitlab/ci/coverage.gitlab-ci.yml
index e3f06d0f592543d2b0661ee829674c908e90d030..88dcd3bea4cbd10e72fd0e3c3fa47fd2b67b4236 100644
--- a/.gitlab/ci/coverage.gitlab-ci.yml
+++ b/.gitlab/ci/coverage.gitlab-ci.yml
@@ -5,10 +5,10 @@ coverage:ubuntu_python:
     - docker
   script:
     - source venv/bin/activate
-    - python3 -m pip install numpy coverage
    - cd ${CI_PROJECT_NAME}
     # Retrieve the installation path of the module, since it is installed with pip.
     - export MODULE_LOCATION=`python -c "import ${CI_PROJECT_NAME} as _; print(_.__path__[0])"`
+    - echo $MODULE_LOCATION
     - python3 -m coverage run --source=$MODULE_LOCATION -m unittest discover -s unit_tests/ -v -b
     - python3 -m coverage report
     - python3 -m coverage xml
diff --git a/aidge_interop_torch/unit_tests/test_pytorch.py b/aidge_interop_torch/unit_tests/test_pytorch.py
index 30ff4d4fe7f846062abb1e5212f70b5eb21aa4b8..463b9ef5ec0207011caad8adaeeb050e1e868296 100755
--- a/aidge_interop_torch/unit_tests/test_pytorch.py
+++ b/aidge_interop_torch/unit_tests/test_pytorch.py
@@ -1,19 +1,21 @@
 import torch
 import aidge_interop_torch
 import aidge_core
+import aidge_learning
 import unittest
 import numpy as np
 
 torch.backends.cudnn.enabled = True
 torch.backends.cudnn.benchmark = True
+torch.set_printoptions(precision=7)
 
 weight_value = 0.05
-batch_size = 10
+batch_size = 2
 learning_rate = 0.01
 comparison_precision = 0.001
 absolute_presision = 0.0001
 epochs = 10
-
+# aidge_core.Log.set_console_level(aidge_core.Level.Debug)
 
 # TODO : add tensor test later ...
 # class test_tensor_conversion(unittest.TestCase):
@@ -44,7 +46,7 @@ class Test_Networks():
 
     def __init__(self, model1, model2, name="", test_backward=True, eval_mode=False, epochs=10, relative_precision=0.001, absolute_presision=0.0001, learning_rate=0.01, cuda=False):
         self.relative_precision = relative_precision
-        self.absolute_presision = absolute_presision
+        self.absolute_precision = absolute_presision
         self.epochs = epochs
         self.test_backward = test_backward
         self.model1 = model1
@@ -60,7 +62,6 @@ class Test_Networks():
             self.model1.train()
             self.model2.train()
         self.name = name
-
         if self.test_backward:
             self.optimizer1 = torch.optim.SGD(
                 self.model1.parameters(), lr=learning_rate)
@@ -74,14 +75,20 @@ class Test_Networks():
             i = i.item()
             j = j.item()
             if j != 0:
-                if abs(i-j) > self.relative_precision * abs(j) + self.absolute_presision:
+                if abs(i-j) > self.relative_precision * abs(j) + self.absolute_precision:
+                    return -1
+            elif i != 0:
+                if abs(i-j) > self.relative_precision * abs(i) + self.absolute_precision:
                     return -1
         return 0
 
     def unit_test(self, input_tensor, label):
         torch_tensor1 = input_tensor
         torch_tensor2 = input_tensor.detach().clone()
-        print(torch_tensor1)
+
+        torch_tensor1.requires_grad = True
+        torch_tensor2.requires_grad = True
+
         if self.test_backward:
             label1 = label
             label2 = label.detach().clone()
@@ -104,11 +111,23 @@ class Test_Networks():
 
             loss2 = self.criterion2(output2, label2)
             self.optimizer2.zero_grad()
+            loss2.backward()
             self.optimizer2.step()
 
+            NaN_flag = False
+            if loss1.isnan():
+                print("Loss1 is NaN")
+                NaN_flag = True
+            if loss2.isnan():
+                print("Loss2 is NaN")
+                NaN_flag = True
+            if NaN_flag: return -1
             if self.compare_tensor(loss1, loss2):
                 print("Different loss : ", loss1.item(), "|", loss2.item())
                 return -1
+            if self.compare_tensor(torch_tensor1.grad, torch_tensor2.grad):
+                print(f"Different input gradient:\nGrad1\n{torch_tensor1.grad}\nGrad2\n{torch_tensor2.grad}\n")
+                return -1
         return 0
 
     def test_multiple_step(self, input_size, label_size):
@@ -135,8 +154,8 @@ class TorchLeNet(torch.nn.Module):
         c1 = torch.nn.Conv2d(1, 6, 5, bias=False)
         c2 = torch.nn.Conv2d(6, 16, 5, bias=False)
         c3 = torch.nn.Conv2d(16, 120, 5, bias=False)
-        l1 = torch.nn.Linear(120, 84)
-        l2 = torch.nn.Linear(84, 10)
+        l1 = torch.nn.Linear(120, 84, bias=True)
+        l2 = torch.nn.Linear(84, 10, bias=True)
 
         torch.nn.init.constant_(c1.weight, weight_value)
         torch.nn.init.constant_(c2.weight, weight_value)
@@ -171,6 +190,46 @@ class TorchLeNet(torch.nn.Module):
         return x
 
 
+class Easy_graph(torch.nn.Module):
+
+    def __init__(self):
+        super(Easy_graph, self).__init__()
+
+        self.layer = torch.nn.Sequential(
+            torch.nn.Flatten(),
+            torch.nn.Linear(3, 4),
+            torch.nn.ReLU(),
+            torch.nn.Linear(4, 4),
+            torch.nn.ReLU(),
+            torch.nn.Linear(4, 4)
+        )
+
+    def forward(self, x):
+        x = self.layer(x)
+        return x
+# class Easy_graph(torch.nn.Module):
+
+#     def __init__(self):
+#         super(Easy_graph, self).__init__()
+
+#         self.layer = torch.nn.Sequential(
+#             torch.nn.Flatten(),
+#             torch.nn.Linear(32*32*3, 512),
+#             torch.nn.ReLU(),
+#             torch.nn.Linear(512, 256),
+#             torch.nn.ReLU(),
+#             torch.nn.Linear(256, 128),
+#             torch.nn.ReLU(),
+#             torch.nn.Linear(128, 10)
+#         )
+
+#     def forward(self, x):
+#         x = self.layer(x)
+#         return x
+
+
+
+
 class test_interop(unittest.TestCase):
 
     def tearDown(self):
@@ -193,6 +252,24 @@
         res = tester.test_multiple_step(input_size, (batch_size, 10))
         self.assertNotEqual(res, -1, msg="LeNet CPU eval failed")
 
+    def test_aidge_backward_CPU(self):
+        print('=== Testing aidge backward CPU ===')
+        input_size = (batch_size, 3, 1, 1)
+        torch_model = Easy_graph()
+
+        aidge_model = aidge_interop_torch.wrap(torch_model, input_size)
+
+        opt = aidge_learning.SGD()
+        lrs = aidge_learning.constant_lr(0.01)
+        opt.set_learning_rate_scheduler(lrs)
+        opt.set_parameters(list(aidge_core.producers(aidge_model._graph_view)))
+        aidge_model.set_optimizer(opt)
+
+        tester = Test_Networks(torch_model, aidge_model, eval_mode=False, epochs=epochs)
+        res = tester.test_multiple_step(input_size, (batch_size, 4))
+
+        self.assertNotEqual(res, -1, msg="CPU train failed")
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/aidge_interop_torch/utils.py b/aidge_interop_torch/utils.py
index b1c31749afdef9845ea9a6e44a592285844ee3b7..a685f57a594b1cca6e67f15c33dd2e73aa2fe3e5 100644
--- a/aidge_interop_torch/utils.py
+++ b/aidge_interop_torch/utils.py
@@ -6,12 +6,20 @@
 import aidge_core
 import aidge_backend_cpu
 import aidge_onnx
+import aidge_learning
 
 from onnxsim import simplify
 from typing import Union
 
 
 def convert_tensor(tensor):
+    """Convert a torch tensor to :py:class:`aidge_core.Tensor` and vice versa.
+
+    :param tensor: Tensor to convert.
+    :type tensor: torch.Tensor or :py:class:`aidge_core.Tensor`
+    :return: Converted tensor
+    :rtype: torch.Tensor or :py:class:`aidge_core.Tensor`
+    """
     if isinstance(tensor, torch.Tensor):
         return torch_tensor_to_aidge(tensor)
     elif isinstance(tensor, aidge_core.Tensor):
@@ -21,11 +29,15 @@ def convert_tensor(tensor):
         f"Object of type {type(tensor)} is not convertible.")
 
 
-def aidge_tensor_to_torch(aidge_tensor):
-    """
-    Convert Aidge.Tensor -> torch.Tensor
-    The conversion creates a GPU memory copy if the tensor is CUDA.
-    This method also convert the shape of the tensor to follow torch convention.
+def aidge_tensor_to_torch(aidge_tensor: aidge_core.Tensor) -> torch.Tensor:
+    """Convert :py:class:`aidge_core.Tensor` -> torch.Tensor
+
+    CUDA tensors are not handled yet; once they are, the conversion will create a GPU memory copy when the tensor is CUDA.
+
+    :param aidge_tensor: Tensor to convert
+    :type aidge_tensor: aidge_core.Tensor
+    :return: Converted tensor
+    :rtype: torch.Tensor
     """
     torch_tensor = torch.from_numpy(np.array(aidge_tensor))
     # TODO : handle cuda case !
@@ -34,7 +46,16 @@ def aidge_tensor_to_torch(aidge_tensor):
     return torch_tensor
 
 
-def torch_tensor_to_aidge(torch_tensor):
+def torch_tensor_to_aidge(torch_tensor: torch.Tensor) -> aidge_core.Tensor:
+    """Convert torch.Tensor -> :py:class:`aidge_core.Tensor`
+
+    CUDA tensors are not handled yet; once they are, the conversion will not create a GPU memory copy when the tensor is CUDA.
+
+    :param torch_tensor: Tensor to convert
+    :type torch_tensor: torch.Tensor
+    :return: Converted tensor
+    :rtype: aidge_core.Tensor
+    """
     aidge_tensor = None
     numpy_tensor = torch_tensor.cpu().detach().numpy()
     # This operation creates a CPU memory copy.
@@ -53,10 +74,9 @@ class AidgeModule(torch.nn.Module):
 
     def __init__(self, graph_view, batch_size=None):
         """
-        :param block: Aidge block object to interface with PyTorch
-        :type block: :py:class:`aidge_core.GraphView`
+        :param graph_view: Aidge graph to interface with PyTorch
+        :type graph_view: :py:class:`aidge_core.GraphView`
         """
-        print("TEST")
         super().__init__()
         if not isinstance(graph_view, aidge_core.GraphView):
             raise TypeError(
@@ -65,7 +85,6 @@
 
         # TODO : better handling of backend ?
         # maybe a function set_backend similar to PyTorch ?
- print("Set cpu !") graph_view.set_backend("cpu") # We need to add a random parameter to the module else pytorch refuse to compute gradient @@ -78,6 +97,11 @@ class AidgeModule(torch.nn.Module): # TODO support of multi input graph later ... self.input_nodes = [None] self.scheduler = None + self.optimizer = None + self.grad_compiled = False + + def set_optimizer(self, opt): + self.optimizer = opt def forward(self, inputs: torch.Tensor): """ @@ -108,14 +132,13 @@ class AidgeModule(torch.nn.Module): else: self.input_nodes[0].get_operator( ).set_output(0, aidge_tensor) - # TODO: create one scheduler ! maybe in init ? if not self.scheduler: self.scheduler = aidge_core.SequentialScheduler( self._graph_view) # Run inference ! - self.scheduler.forward(verbose=True) + self.scheduler.forward(forward_dims=True) # TODO: support for multi output later ? if len(self._graph_view.get_output_nodes()) > 1: @@ -128,7 +151,34 @@ class AidgeModule(torch.nn.Module): @staticmethod def backward(ctx, grad_output): - raise RuntimeError("Backward is not yet supported with Aidge.") + if not self.grad_compiled: aidge_core.compile_gradient(self._graph_view) + + if self.multi_outputs_flag: + raise RuntimeError( + "Backward is not possible if the model has multi-outputs") + if len(self.input_nodes) != 1: + raise RuntimeError( + "Multi-input is not handled for now in pytorch backpropagation") + # convert the output gradient to an AIDGE Tensor + aidge_grad_output = torch_tensor_to_aidge(grad_output) + + if len(self._graph_view.get_output_nodes()) != 1: + RuntimeError( + f"We only support one output got {len(self._graph_view.get_output_nodes())}") + output_node = list(self._graph_view.get_output_nodes())[0] + output_tensor = output_node.get_operator().get_output(0) + output_tensor.set_grad(aidge_grad_output) + + # run the backpropagation + # TODO: remove update from the backprop + self.optimizer.reset_grad() + self.scheduler.backward() + self.optimizer.update() + # get grad of first layer no handling of multi input + aidge_out_grad = self.input_nodes[0].get_operator().get_output(0).grad() + # convert grad to torch + torch_out_grad = aidge_tensor_to_torch(aidge_out_grad) + return torch_out_grad # If the layer is at the beginning of the network requires grad is False. inputs.requires_grad = True @@ -174,7 +224,6 @@ class ContextNoBatchNormFuse: if current_bn.affine: # Bias and Weights only created if affine=True saved_bias = current_bn.bias.detach().clone() saved_weight = current_bn.weight.detach().clone() - print(current_bn.running_mean.shape) # Real batchnorm forward output_tensor = torch.nn.functional.batch_norm( inputs, @@ -244,7 +293,8 @@ def wrap(torch_model: torch.nn.Module, try: torch_device = next(torch_model.parameters()).device except StopIteration: - torch_device = torch.device('cpu') # Model has no parameter, defaulting to cpu + # Model has no parameter, defaulting to cpu + torch_device = torch.device('cpu') dummy_in = torch.zeros(input_size).to(torch_device)