Skip to content
Snippets Groups Projects
Commit 9885ceaf authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add back data format when chosing how to retrieve nb chan ...

parent bd723d05
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!163Export refactor
Pipeline #50796 failed
......@@ -9,83 +9,79 @@ def get_chan(tensor: aidge_core.Tensor) -> int:
"""
dformat = tensor.dformat()
dims = tensor.dims()
if len(dims) == 4: # Suppose NCHW
if dformat == aidge_core.dformat.Default:
if len(dims) == 4: # Suppose NCHW
return dims[1]
elif len(dims) == 2: # Suppose NC
return dims[1]
else:
return None
elif dformat == aidge_core.dformat.NCHW:
return dims[1]
elif len(dims) == 2: # Suppose NC
elif dformat == aidge_core.dformat.NHWC:
return dims[3]
elif dformat == aidge_core.dformat.CHWN:
return dims[0]
elif dformat == aidge_core.dformat.NCDHW:
return dims[1]
elif dformat == aidge_core.dformat.NDHWC:
return dims[4]
elif dformat == aidge_core.dformat.CDHWN:
return dims[0]
else:
return None
# if dformat == aidge_core.dformat.Default:
# return None
# elif dformat == aidge_core.dformat.NCHW:
# return dims[1]
# elif dformat == aidge_core.dformat.NHWC:
# return dims[3]
# elif dformat == aidge_core.dformat.CHWN:
# return dims[0]
# elif dformat == aidge_core.dformat.NCDHW:
# return dims[1]
# elif dformat == aidge_core.dformat.NDHWC:
# return dims[4]
# elif dformat == aidge_core.dformat.CDHWN:
# return dims[0]
# else:
# raise RuntimeError(f"Unknown dataformat: {dformat}")
raise RuntimeError(f"Unknown dataformat: {dformat}")
def get_height(tensor: aidge_core.Tensor) -> int:
dformat = tensor.dformat()
dims = tensor.dims()
if len(dims) == 4: # Suppose NCHW
if dformat == aidge_core.dformat.Default:
if len(dims) == 4: # Suppose NCHW
return dims[2]
elif len(dims) == 2: # Suppose NC
return 1
else:
return None
elif dformat == aidge_core.dformat.NCHW:
return dims[2]
elif dformat == aidge_core.dformat.NHWC:
return dims[1]
elif dformat == aidge_core.dformat.CHWN:
return dims[1]
elif dformat == aidge_core.dformat.NCDHW:
return dims[3]
elif dformat == aidge_core.dformat.NDHWC:
return dims[2]
elif dformat == aidge_core.dformat.CDHWN:
return dims[2]
elif len(dims) == 2: # Suppose NC
return 1
else:
return None
# TODO: use when dformat is fully supported
# if dformat == aidge_core.dformat.Default:
# return None
# elif dformat == aidge_core.dformat.NCHW:
# return dims[2]
# elif dformat == aidge_core.dformat.NHWC:
# return dims[1]
# elif dformat == aidge_core.dformat.CHWN:
# return dims[1]
# elif dformat == aidge_core.dformat.NCDHW:
# return dims[3]
# elif dformat == aidge_core.dformat.NDHWC:
# return dims[2]
# elif dformat == aidge_core.dformat.CDHWN:
# return dims[2]
# else:
# raise RuntimeError(f"Unknown dataformat: {dformat}")
raise RuntimeError(f"Unknown dataformat: {dformat}")
def get_width(tensor: aidge_core.Tensor) -> int:
dformat = tensor.dformat()
dims = tensor.dims()
if len(dims) == 4: # Suppose NCHW
if dformat == aidge_core.dformat.Default:
if len(dims) == 4: # Suppose NCHW
return dims[3]
elif len(dims) == 2: # Suppose NC
return 1
else:
return None
elif dformat == aidge_core.dformat.NCHW:
return dims[3]
elif dformat == aidge_core.dformat.NHWC:
return dims[2]
elif dformat == aidge_core.dformat.CHWN:
return dims[2]
elif dformat == aidge_core.dformat.NCDHW:
return dims[4]
elif dformat == aidge_core.dformat.NDHWC:
return dims[3]
elif dformat == aidge_core.dformat.CDHWN:
return dims[3]
elif len(dims) == 2: # Suppose NC
return 1
else:
return None
# if dformat == aidge_core.dformat.Default:
# return None
# elif dformat == aidge_core.dformat.NCHW:
# return dims[3]
# elif dformat == aidge_core.dformat.NHWC:
# return dims[2]
# elif dformat == aidge_core.dformat.CHWN:
# return dims[2]
# elif dformat == aidge_core.dformat.NCDHW:
# return dims[4]
# elif dformat == aidge_core.dformat.NDHWC:
# return dims[3]
# elif dformat == aidge_core.dformat.CDHWN:
# return dims[3]
# else:
# raise RuntimeError(f"Unknown dataformat: {dformat}")
raise RuntimeError(f"Unknown dataformat: {dformat}")
class ExportNode(ABC):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment