diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py index a9190b459536bf0b1d78abd89b7f2a3003ca72d0..7a32864001f926653d062fe7672bf2c53271e805 100644 --- a/aidge_core/export_utils/node_export.py +++ b/aidge_core/export_utils/node_export.py @@ -11,24 +11,24 @@ def get_chan(tensor: aidge_core.Tensor) -> int: """ dformat = tensor.dformat() dims = tensor.dims() - if dformat == aidge_core.dformat.Default: + if dformat == aidge_core.dformat.default: if len(dims) == 4: # Suppose NCHW return dims[1] elif len(dims) == 2: # Suppose NC return dims[1] else: return None - elif dformat == aidge_core.dformat.NCHW: + elif dformat == aidge_core.dformat.nchw: return dims[1] - elif dformat == aidge_core.dformat.NHWC: + elif dformat == aidge_core.dformat.nhwc: return dims[3] - elif dformat == aidge_core.dformat.CHWN: + elif dformat == aidge_core.dformat.chwn: return dims[0] - elif dformat == aidge_core.dformat.NCDHW: + elif dformat == aidge_core.dformat.ncdhw: return dims[1] - elif dformat == aidge_core.dformat.NDHWC: + elif dformat == aidge_core.dformat.ndhwc: return dims[4] - elif dformat == aidge_core.dformat.CDHWN: + elif dformat == aidge_core.dformat.cdhwn: return dims[0] else: raise RuntimeError(f"Unknown dataformat: {dformat}") @@ -37,24 +37,24 @@ def get_chan(tensor: aidge_core.Tensor) -> int: def get_height(tensor: aidge_core.Tensor) -> int: dformat = tensor.dformat() dims = tensor.dims() - if dformat == aidge_core.dformat.Default: + if dformat == aidge_core.dformat.default: if len(dims) == 4: # Suppose NCHW return dims[2] elif len(dims) == 2: # Suppose NC return 1 else: return None - elif dformat == aidge_core.dformat.NCHW: + elif dformat == aidge_core.dformat.nchw: return dims[2] - elif dformat == aidge_core.dformat.NHWC: + elif dformat == aidge_core.dformat.nhwc: return dims[1] - elif dformat == aidge_core.dformat.CHWN: + elif dformat == aidge_core.dformat.chwn: return dims[1] - elif dformat == aidge_core.dformat.NCDHW: + elif dformat == aidge_core.dformat.ncdhw: return dims[3] - elif dformat == aidge_core.dformat.NDHWC: + elif dformat == aidge_core.dformat.ndhwc: return dims[2] - elif dformat == aidge_core.dformat.CDHWN: + elif dformat == aidge_core.dformat.cdhwn: return dims[2] else: raise RuntimeError(f"Unknown dataformat: {dformat}") @@ -63,24 +63,24 @@ def get_height(tensor: aidge_core.Tensor) -> int: def get_width(tensor: aidge_core.Tensor) -> int: dformat = tensor.dformat() dims = tensor.dims() - if dformat == aidge_core.dformat.Default: + if dformat == aidge_core.dformat.default: if len(dims) == 4: # Suppose NCHW return dims[3] elif len(dims) == 2: # Suppose NC return 1 else: return None - elif dformat == aidge_core.dformat.NCHW: + elif dformat == aidge_core.dformat.nchw: return dims[3] - elif dformat == aidge_core.dformat.NHWC: + elif dformat == aidge_core.dformat.nhwc: return dims[2] - elif dformat == aidge_core.dformat.CHWN: + elif dformat == aidge_core.dformat.chwn: return dims[2] - elif dformat == aidge_core.dformat.NCDHW: + elif dformat == aidge_core.dformat.ncdhw: return dims[4] - elif dformat == aidge_core.dformat.NDHWC: + elif dformat == aidge_core.dformat.ndhwc: return dims[3] - elif dformat == aidge_core.dformat.CDHWN: + elif dformat == aidge_core.dformat.cdhwn: return dims[3] else: raise RuntimeError(f"Unknown dataformat: {dformat}")