diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py index d09a7474258940b6a648ff87b0157425890e335b..310d0ff8a62b83d19497574dad8a0cadef08e942 100644 --- a/aidge_core/export_utils/node_export.py +++ b/aidge_core/export_utils/node_export.py @@ -9,83 +9,79 @@ def get_chan(tensor: aidge_core.Tensor) -> int: """ dformat = tensor.dformat() dims = tensor.dims() - if len(dims) == 4: # Suppose NCHW + if dformat == aidge_core.dformat.Default: + if len(dims) == 4: # Suppose NCHW + return dims[1] + elif len(dims) == 2: # Suppose NC + return dims[1] + else: + return None + elif dformat == aidge_core.dformat.NCHW: return dims[1] - elif len(dims) == 2: # Suppose NC + elif dformat == aidge_core.dformat.NHWC: + return dims[3] + elif dformat == aidge_core.dformat.CHWN: + return dims[0] + elif dformat == aidge_core.dformat.NCDHW: return dims[1] + elif dformat == aidge_core.dformat.NDHWC: + return dims[4] + elif dformat == aidge_core.dformat.CDHWN: + return dims[0] else: - return None - # if dformat == aidge_core.dformat.Default: - # return None - # elif dformat == aidge_core.dformat.NCHW: - # return dims[1] - # elif dformat == aidge_core.dformat.NHWC: - # return dims[3] - # elif dformat == aidge_core.dformat.CHWN: - # return dims[0] - # elif dformat == aidge_core.dformat.NCDHW: - # return dims[1] - # elif dformat == aidge_core.dformat.NDHWC: - # return dims[4] - # elif dformat == aidge_core.dformat.CDHWN: - # return dims[0] - # else: - # raise RuntimeError(f"Unknown dataformat: {dformat}") + raise RuntimeError(f"Unknown dataformat: {dformat}") def get_height(tensor: aidge_core.Tensor) -> int: dformat = tensor.dformat() dims = tensor.dims() - if len(dims) == 4: # Suppose NCHW + if dformat == aidge_core.dformat.Default: + if len(dims) == 4: # Suppose NCHW + return dims[2] + elif len(dims) == 2: # Suppose NC + return 1 + else: + return None + elif dformat == aidge_core.dformat.NCHW: + return dims[2] + elif dformat == aidge_core.dformat.NHWC: + return dims[1] + elif dformat == aidge_core.dformat.CHWN: + return dims[1] + elif dformat == aidge_core.dformat.NCDHW: + return dims[3] + elif dformat == aidge_core.dformat.NDHWC: + return dims[2] + elif dformat == aidge_core.dformat.CDHWN: return dims[2] - elif len(dims) == 2: # Suppose NC - return 1 else: - return None - # TODO: use when dformat is fully supported - # if dformat == aidge_core.dformat.Default: - # return None - # elif dformat == aidge_core.dformat.NCHW: - # return dims[2] - # elif dformat == aidge_core.dformat.NHWC: - # return dims[1] - # elif dformat == aidge_core.dformat.CHWN: - # return dims[1] - # elif dformat == aidge_core.dformat.NCDHW: - # return dims[3] - # elif dformat == aidge_core.dformat.NDHWC: - # return dims[2] - # elif dformat == aidge_core.dformat.CDHWN: - # return dims[2] - # else: - # raise RuntimeError(f"Unknown dataformat: {dformat}") + raise RuntimeError(f"Unknown dataformat: {dformat}") def get_width(tensor: aidge_core.Tensor) -> int: dformat = tensor.dformat() dims = tensor.dims() - if len(dims) == 4: # Suppose NCHW + if dformat == aidge_core.dformat.Default: + if len(dims) == 4: # Suppose NCHW + return dims[3] + elif len(dims) == 2: # Suppose NC + return 1 + else: + return None + elif dformat == aidge_core.dformat.NCHW: + return dims[3] + elif dformat == aidge_core.dformat.NHWC: + return dims[2] + elif dformat == aidge_core.dformat.CHWN: + return dims[2] + elif dformat == aidge_core.dformat.NCDHW: + return dims[4] + elif dformat == aidge_core.dformat.NDHWC: + return dims[3] + elif dformat == aidge_core.dformat.CDHWN: return dims[3] - elif len(dims) == 2: # Suppose NC - return 1 else: - return None - # if dformat == aidge_core.dformat.Default: - # return None - # elif dformat == aidge_core.dformat.NCHW: - # return dims[3] - # elif dformat == aidge_core.dformat.NHWC: - # return dims[2] - # elif dformat == aidge_core.dformat.CHWN: - # return dims[2] - # elif dformat == aidge_core.dformat.NCDHW: - # return dims[4] - # elif dformat == aidge_core.dformat.NDHWC: - # return dims[3] - # elif dformat == aidge_core.dformat.CDHWN: - # return dims[3] - # else: - # raise RuntimeError(f"Unknown dataformat: {dformat}") + raise RuntimeError(f"Unknown dataformat: {dformat}") class ExportNode(ABC):