Skip to content
Snippets Groups Projects

feat/release_pip

Closed Grégoire Kubler requested to merge feat/release_pip into dev
Files
8
@@ -98,7 +98,6 @@ class AidgeModule(torch.nn.Module):
@@ -98,7 +98,6 @@ class AidgeModule(torch.nn.Module):
self.input_nodes = [None]
self.input_nodes = [None]
self.scheduler = None
self.scheduler = None
self.optimizer = None
self.optimizer = None
self.grad_compiled = False
def set_optimizer(self, opt):
def set_optimizer(self, opt):
self.optimizer = opt
self.optimizer = opt
@@ -118,13 +117,13 @@ class AidgeModule(torch.nn.Module):
@@ -118,13 +117,13 @@ class AidgeModule(torch.nn.Module):
# TODO: add a system to avoid creating a new node everytime
# TODO: add a system to avoid creating a new node everytime
if self.input_nodes[0] == None:
if self.input_nodes[0] is None:
self.input_nodes[0] = aidge_core.Producer(
self.input_nodes[0] = aidge_core.Producer(
aidge_tensor, "Input_0")
aidge_tensor, "Input_0")
# TODO: get datatype & backend from graph view
# TODO: get datatype & backend from graph view
self.input_nodes[0].get_operator().set_datatype(
self.input_nodes[0].get_operator().set_datatype(
aidge_core.DataType.Float32)
aidge_core.dtype.float32)
self.input_nodes[0].get_operator().set_backend("cpu")
self.input_nodes[0].get_operator().set_backend("cpu")
self.input_nodes[0].add_child(self._graph_view)
self.input_nodes[0].add_child(self._graph_view)
@@ -151,8 +150,6 @@ class AidgeModule(torch.nn.Module):
@@ -151,8 +150,6 @@ class AidgeModule(torch.nn.Module):
@staticmethod
@staticmethod
def backward(ctx, grad_output):
def backward(ctx, grad_output):
if not self.grad_compiled: aidge_core.compile_gradient(self._graph_view)
if self.multi_outputs_flag:
if self.multi_outputs_flag:
raise RuntimeError(
raise RuntimeError(
"Backward is not possible if the model has multi-outputs")
"Backward is not possible if the model has multi-outputs")
@@ -264,7 +261,7 @@ class ContextNoBatchNormFuse:
@@ -264,7 +261,7 @@ class ContextNoBatchNormFuse:
def wrap(torch_model: torch.nn.Module,
def wrap(torch_model: torch.nn.Module,
input_size: Union[list, tuple],
input_size: Union[list, tuple],
opset_version: int = 11,
opset_version: int = 15,
in_names: list = None,
in_names: list = None,
out_names: list = None,
out_names: list = None,
verbose: bool = False) -> AidgeModule:
verbose: bool = False) -> AidgeModule:
Loading