Skip to content
Snippets Groups Projects
Commit 7ce2a2ef authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Adapt export node to data_format in lowercase.

parent 38642389
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!163Export refactor
...@@ -11,24 +11,24 @@ def get_chan(tensor: aidge_core.Tensor) -> int: ...@@ -11,24 +11,24 @@ def get_chan(tensor: aidge_core.Tensor) -> int:
""" """
dformat = tensor.dformat() dformat = tensor.dformat()
dims = tensor.dims() dims = tensor.dims()
if dformat == aidge_core.dformat.Default: if dformat == aidge_core.dformat.default:
if len(dims) == 4: # Suppose NCHW if len(dims) == 4: # Suppose NCHW
return dims[1] return dims[1]
elif len(dims) == 2: # Suppose NC elif len(dims) == 2: # Suppose NC
return dims[1] return dims[1]
else: else:
return None return None
elif dformat == aidge_core.dformat.NCHW: elif dformat == aidge_core.dformat.nchw:
return dims[1] return dims[1]
elif dformat == aidge_core.dformat.NHWC: elif dformat == aidge_core.dformat.nhwc:
return dims[3] return dims[3]
elif dformat == aidge_core.dformat.CHWN: elif dformat == aidge_core.dformat.chwn:
return dims[0] return dims[0]
elif dformat == aidge_core.dformat.NCDHW: elif dformat == aidge_core.dformat.ncdhw:
return dims[1] return dims[1]
elif dformat == aidge_core.dformat.NDHWC: elif dformat == aidge_core.dformat.ndhwc:
return dims[4] return dims[4]
elif dformat == aidge_core.dformat.CDHWN: elif dformat == aidge_core.dformat.cdhwn:
return dims[0] return dims[0]
else: else:
raise RuntimeError(f"Unknown dataformat: {dformat}") raise RuntimeError(f"Unknown dataformat: {dformat}")
...@@ -37,24 +37,24 @@ def get_chan(tensor: aidge_core.Tensor) -> int: ...@@ -37,24 +37,24 @@ def get_chan(tensor: aidge_core.Tensor) -> int:
def get_height(tensor: aidge_core.Tensor) -> int: def get_height(tensor: aidge_core.Tensor) -> int:
dformat = tensor.dformat() dformat = tensor.dformat()
dims = tensor.dims() dims = tensor.dims()
if dformat == aidge_core.dformat.Default: if dformat == aidge_core.dformat.default:
if len(dims) == 4: # Suppose NCHW if len(dims) == 4: # Suppose NCHW
return dims[2] return dims[2]
elif len(dims) == 2: # Suppose NC elif len(dims) == 2: # Suppose NC
return 1 return 1
else: else:
return None return None
elif dformat == aidge_core.dformat.NCHW: elif dformat == aidge_core.dformat.nchw:
return dims[2] return dims[2]
elif dformat == aidge_core.dformat.NHWC: elif dformat == aidge_core.dformat.nhwc:
return dims[1] return dims[1]
elif dformat == aidge_core.dformat.CHWN: elif dformat == aidge_core.dformat.chwn:
return dims[1] return dims[1]
elif dformat == aidge_core.dformat.NCDHW: elif dformat == aidge_core.dformat.ncdhw:
return dims[3] return dims[3]
elif dformat == aidge_core.dformat.NDHWC: elif dformat == aidge_core.dformat.ndhwc:
return dims[2] return dims[2]
elif dformat == aidge_core.dformat.CDHWN: elif dformat == aidge_core.dformat.cdhwn:
return dims[2] return dims[2]
else: else:
raise RuntimeError(f"Unknown dataformat: {dformat}") raise RuntimeError(f"Unknown dataformat: {dformat}")
...@@ -63,24 +63,24 @@ def get_height(tensor: aidge_core.Tensor) -> int: ...@@ -63,24 +63,24 @@ def get_height(tensor: aidge_core.Tensor) -> int:
def get_width(tensor: aidge_core.Tensor) -> int: def get_width(tensor: aidge_core.Tensor) -> int:
dformat = tensor.dformat() dformat = tensor.dformat()
dims = tensor.dims() dims = tensor.dims()
if dformat == aidge_core.dformat.Default: if dformat == aidge_core.dformat.default:
if len(dims) == 4: # Suppose NCHW if len(dims) == 4: # Suppose NCHW
return dims[3] return dims[3]
elif len(dims) == 2: # Suppose NC elif len(dims) == 2: # Suppose NC
return 1 return 1
else: else:
return None return None
elif dformat == aidge_core.dformat.NCHW: elif dformat == aidge_core.dformat.nchw:
return dims[3] return dims[3]
elif dformat == aidge_core.dformat.NHWC: elif dformat == aidge_core.dformat.nhwc:
return dims[2] return dims[2]
elif dformat == aidge_core.dformat.CHWN: elif dformat == aidge_core.dformat.chwn:
return dims[2] return dims[2]
elif dformat == aidge_core.dformat.NCDHW: elif dformat == aidge_core.dformat.ncdhw:
return dims[4] return dims[4]
elif dformat == aidge_core.dformat.NDHWC: elif dformat == aidge_core.dformat.ndhwc:
return dims[3] return dims[3]
elif dformat == aidge_core.dformat.CDHWN: elif dformat == aidge_core.dformat.cdhwn:
return dims[3] return dims[3]
else: else:
raise RuntimeError(f"Unknown dataformat: {dformat}") raise RuntimeError(f"Unknown dataformat: {dformat}")
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment