From 4453a122db1870b15f3bdfe66db9a56c6d785131 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Tue, 23 Apr 2024 14:59:54 +0200 Subject: [PATCH] support input and/or attribute for reshape --- .../node_import/onnx_converters/reshape.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/aidge_onnx/node_import/onnx_converters/reshape.py b/aidge_onnx/node_import/onnx_converters/reshape.py index e0eb7b1..636d875 100644 --- a/aidge_onnx/node_import/onnx_converters/reshape.py +++ b/aidge_onnx/node_import/onnx_converters/reshape.py @@ -26,10 +26,14 @@ def import_reshape(onnx_node:onnx.NodeProto, input_nodes:List[Tuple[aidge_core.N attrs = {attr.name : attr for attr in onnx_node.attribute} shape = [] - if opset < 5: - if 'shape' in attrs: - shape = attrs['shape'].ints - del attrs['shape'] + if 'shape' in attrs: + shape = attrs['shape'].ints + del attrs['shape'] + + shape_tensor = aidge_core.Tensor(shape) + shape_node = aidge_core.Producer(shape_tensor, "shape") + intput_node = (shape_node, 0) + input_nodes.append(intput_node) if len(attrs) > 0: print(f"Warning: unsupported attribute(s): {attrs.keys()} for operator transpose.") @@ -47,5 +51,6 @@ def import_reshape(onnx_node:onnx.NodeProto, input_nodes:List[Tuple[aidge_core.N print(f"- {node_name} ({onnx_node.op_type})") return my_node else: - print(f"warning, bad shape initialization") - return None \ No newline at end of file + my_node = aidge_core.Reshape(name=node_name) + print(f"- {node_name} ({onnx_node.op_type})") + return my_node \ No newline at end of file -- GitLab