diff --git a/aidge_onnx/node_import/onnx_converters/reshape.py b/aidge_onnx/node_import/onnx_converters/reshape.py index e0eb7b195b22222b767a629b82cf9b5896e9b4b4..636d875bdb520dfaec62696f2c9b8b813a50d831 100644 --- a/aidge_onnx/node_import/onnx_converters/reshape.py +++ b/aidge_onnx/node_import/onnx_converters/reshape.py @@ -26,10 +26,14 @@ def import_reshape(onnx_node:onnx.NodeProto, input_nodes:List[Tuple[aidge_core.N attrs = {attr.name : attr for attr in onnx_node.attribute} shape = [] - if opset < 5: - if 'shape' in attrs: - shape = attrs['shape'].ints - del attrs['shape'] + if 'shape' in attrs: + shape = attrs['shape'].ints + del attrs['shape'] + + shape_tensor = aidge_core.Tensor(shape) + shape_node = aidge_core.Producer(shape_tensor, "shape") + intput_node = (shape_node, 0) + input_nodes.append(intput_node) if len(attrs) > 0: print(f"Warning: unsupported attribute(s): {attrs.keys()} for operator transpose.") @@ -47,5 +51,6 @@ def import_reshape(onnx_node:onnx.NodeProto, input_nodes:List[Tuple[aidge_core.N print(f"- {node_name} ({onnx_node.op_type})") return my_node else: - print(f"warning, bad shape initialization") - return None \ No newline at end of file + my_node = aidge_core.Reshape(name=node_name) + print(f"- {node_name} ({onnx_node.op_type})") + return my_node \ No newline at end of file