From 9885ceaf8a2cb17fca42fb95d7ab7ea1c716f11b Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Fri, 12 Jul 2024 07:08:06 +0000
Subject: [PATCH] Add back data format when chosing how to retrieve nb chan ...

---
 aidge_core/export_utils/node_export.py | 116 ++++++++++++-------------
 1 file changed, 56 insertions(+), 60 deletions(-)

diff --git a/aidge_core/export_utils/node_export.py b/aidge_core/export_utils/node_export.py
index d09a74742..310d0ff8a 100644
--- a/aidge_core/export_utils/node_export.py
+++ b/aidge_core/export_utils/node_export.py
@@ -9,83 +9,79 @@ def get_chan(tensor: aidge_core.Tensor) -> int:
     """
     dformat = tensor.dformat()
     dims = tensor.dims()
-    if len(dims) == 4: # Suppose NCHW
+    if dformat == aidge_core.dformat.Default:
+        if len(dims) == 4: # Suppose NCHW
+            return dims[1]
+        elif len(dims) == 2: # Suppose NC
+            return dims[1]
+        else:
+            return None
+    elif dformat == aidge_core.dformat.NCHW:
         return dims[1]
-    elif len(dims) == 2: # Suppose NC
+    elif dformat == aidge_core.dformat.NHWC:
+        return dims[3]
+    elif dformat == aidge_core.dformat.CHWN:
+        return dims[0]
+    elif dformat == aidge_core.dformat.NCDHW:
         return dims[1]
+    elif dformat == aidge_core.dformat.NDHWC:
+        return dims[4]
+    elif dformat == aidge_core.dformat.CDHWN:
+        return dims[0]
     else:
-        return None
-    # if dformat == aidge_core.dformat.Default:
-    #     return None
-    # elif dformat == aidge_core.dformat.NCHW:
-    #     return dims[1]
-    # elif dformat == aidge_core.dformat.NHWC:
-    #     return dims[3]
-    # elif dformat == aidge_core.dformat.CHWN:
-    #     return dims[0]
-    # elif dformat == aidge_core.dformat.NCDHW:
-    #     return dims[1]
-    # elif dformat == aidge_core.dformat.NDHWC:
-    #     return dims[4]
-    # elif dformat == aidge_core.dformat.CDHWN:
-    #     return dims[0]
-    # else:
-    #     raise RuntimeError(f"Unknown dataformat: {dformat}")
+        raise RuntimeError(f"Unknown dataformat: {dformat}")
 
 
 def get_height(tensor: aidge_core.Tensor) -> int:
     dformat = tensor.dformat()
     dims = tensor.dims()
-    if len(dims) == 4: # Suppose NCHW
+    if dformat == aidge_core.dformat.Default:
+        if len(dims) == 4: # Suppose NCHW
+            return dims[2]
+        elif len(dims) == 2: # Suppose NC
+            return 1
+        else:
+            return None
+    elif dformat == aidge_core.dformat.NCHW:
+        return dims[2]
+    elif dformat == aidge_core.dformat.NHWC:
+        return dims[1]
+    elif dformat == aidge_core.dformat.CHWN:
+        return dims[1]
+    elif dformat == aidge_core.dformat.NCDHW:
+        return dims[3]
+    elif dformat == aidge_core.dformat.NDHWC:
+        return dims[2]
+    elif dformat == aidge_core.dformat.CDHWN:
         return dims[2]
-    elif len(dims) == 2: # Suppose NC
-        return 1
     else:
-        return None
-    # TODO: use when dformat is fully supported
-    # if dformat == aidge_core.dformat.Default:
-    #     return None
-    # elif dformat == aidge_core.dformat.NCHW:
-    #     return dims[2]
-    # elif dformat == aidge_core.dformat.NHWC:
-    #     return dims[1]
-    # elif dformat == aidge_core.dformat.CHWN:
-    #     return dims[1]
-    # elif dformat == aidge_core.dformat.NCDHW:
-    #     return dims[3]
-    # elif dformat == aidge_core.dformat.NDHWC:
-    #     return dims[2]
-    # elif dformat == aidge_core.dformat.CDHWN:
-    #     return dims[2]
-    # else:
-    #     raise RuntimeError(f"Unknown dataformat: {dformat}")
+        raise RuntimeError(f"Unknown dataformat: {dformat}")
 
 
 def get_width(tensor: aidge_core.Tensor) -> int:
     dformat = tensor.dformat()
     dims = tensor.dims()
-    if len(dims) == 4: # Suppose NCHW
+    if dformat == aidge_core.dformat.Default:
+        if len(dims) == 4: # Suppose NCHW
+            return dims[3]
+        elif len(dims) == 2: # Suppose NC
+            return 1
+        else:
+            return None
+    elif dformat == aidge_core.dformat.NCHW:
+        return dims[3]
+    elif dformat == aidge_core.dformat.NHWC:
+        return dims[2]
+    elif dformat == aidge_core.dformat.CHWN:
+        return dims[2]
+    elif dformat == aidge_core.dformat.NCDHW:
+        return dims[4]
+    elif dformat == aidge_core.dformat.NDHWC:
+        return dims[3]
+    elif dformat == aidge_core.dformat.CDHWN:
         return dims[3]
-    elif len(dims) == 2: # Suppose NC
-        return 1
     else:
-        return None
-    # if dformat == aidge_core.dformat.Default:
-    #     return None
-    # elif dformat == aidge_core.dformat.NCHW:
-    #     return dims[3]
-    # elif dformat == aidge_core.dformat.NHWC:
-    #     return dims[2]
-    # elif dformat == aidge_core.dformat.CHWN:
-    #     return dims[2]
-    # elif dformat == aidge_core.dformat.NCDHW:
-    #     return dims[4]
-    # elif dformat == aidge_core.dformat.NDHWC:
-    #     return dims[3]
-    # elif dformat == aidge_core.dformat.CDHWN:
-    #     return dims[3]
-    # else:
-    #     raise RuntimeError(f"Unknown dataformat: {dformat}")
+        raise RuntimeError(f"Unknown dataformat: {dformat}")
 
 
 class ExportNode(ABC):
-- 
GitLab