From 7ce2a2ef1695c05ffed2acb09741dd98915fbbc0 Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Fri, 20 Sep 2024 11:19:06 +0000 Subject: [PATCH] Adapt export node to data_format in lowercase. --- aidge_core/export_utils/node_export.py | 42 +++++++++++++------------- 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py index a9190b459..7a3286400 100644 --- a/aidge_core/export_utils/node_export.py +++ b/aidge_core/export_utils/node_export.py @@ -11,24 +11,24 @@ def get_chan(tensor: aidge_core.Tensor) -> int: """ dformat = tensor.dformat() dims = tensor.dims() - if dformat == aidge_core.dformat.Default: + if dformat == aidge_core.dformat.default: if len(dims) == 4: # Suppose NCHW return dims[1] elif len(dims) == 2: # Suppose NC return dims[1] else: return None - elif dformat == aidge_core.dformat.NCHW: + elif dformat == aidge_core.dformat.nchw: return dims[1] - elif dformat == aidge_core.dformat.NHWC: + elif dformat == aidge_core.dformat.nhwc: return dims[3] - elif dformat == aidge_core.dformat.CHWN: + elif dformat == aidge_core.dformat.chwn: return dims[0] - elif dformat == aidge_core.dformat.NCDHW: + elif dformat == aidge_core.dformat.ncdhw: return dims[1] - elif dformat == aidge_core.dformat.NDHWC: + elif dformat == aidge_core.dformat.ndhwc: return dims[4] - elif dformat == aidge_core.dformat.CDHWN: + elif dformat == aidge_core.dformat.cdhwn: return dims[0] else: raise RuntimeError(f"Unknown dataformat: {dformat}") @@ -37,24 +37,24 @@ def get_chan(tensor: aidge_core.Tensor) -> int: def get_height(tensor: aidge_core.Tensor) -> int: dformat = tensor.dformat() dims = tensor.dims() - if dformat == aidge_core.dformat.Default: + if dformat == aidge_core.dformat.default: if len(dims) == 4: # Suppose NCHW return dims[2] elif len(dims) == 2: # Suppose NC return 1 else: return None - elif dformat == aidge_core.dformat.NCHW: + elif dformat == aidge_core.dformat.nchw: return dims[2] - elif dformat == aidge_core.dformat.NHWC: + elif dformat == aidge_core.dformat.nhwc: return dims[1] - elif dformat == aidge_core.dformat.CHWN: + elif dformat == aidge_core.dformat.chwn: return dims[1] - elif dformat == aidge_core.dformat.NCDHW: + elif dformat == aidge_core.dformat.ncdhw: return dims[3] - elif dformat == aidge_core.dformat.NDHWC: + elif dformat == aidge_core.dformat.ndhwc: return dims[2] - elif dformat == aidge_core.dformat.CDHWN: + elif dformat == aidge_core.dformat.cdhwn: return dims[2] else: raise RuntimeError(f"Unknown dataformat: {dformat}") @@ -63,24 +63,24 @@ def get_height(tensor: aidge_core.Tensor) -> int: def get_width(tensor: aidge_core.Tensor) -> int: dformat = tensor.dformat() dims = tensor.dims() - if dformat == aidge_core.dformat.Default: + if dformat == aidge_core.dformat.default: if len(dims) == 4: # Suppose NCHW return dims[3] elif len(dims) == 2: # Suppose NC return 1 else: return None - elif dformat == aidge_core.dformat.NCHW: + elif dformat == aidge_core.dformat.nchw: return dims[3] - elif dformat == aidge_core.dformat.NHWC: + elif dformat == aidge_core.dformat.nhwc: return dims[2] - elif dformat == aidge_core.dformat.CHWN: + elif dformat == aidge_core.dformat.chwn: return dims[2] - elif dformat == aidge_core.dformat.NCDHW: + elif dformat == aidge_core.dformat.ncdhw: return dims[4] - elif dformat == aidge_core.dformat.NDHWC: + elif dformat == aidge_core.dformat.ndhwc: return dims[3] - elif dformat == aidge_core.dformat.CDHWN: + elif dformat == aidge_core.dformat.cdhwn: return dims[3] else: raise RuntimeError(f"Unknown dataformat: {dformat}") -- GitLab