From fe58aba03030daa7175bbe98d10f866869c74e0a Mon Sep 17 00:00:00 2001
From: Wissam Boussella <wissam.boussella@cea.fr>
Date: Thu, 23 Jan 2025 17:14:34 +0100
Subject: [PATCH] Conv fwd_dims both for nhwc and nchw in input and output

---
 include/aidge/operator/Conv.hpp |  4 ++
 src/operator/Conv.cpp           | 77 ++++++++++++++++++++-------------
 2 files changed, 50 insertions(+), 31 deletions(-)

diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp
index 135ff8860..e2faeb6ac 100644
--- a/include/aidge/operator/Conv.hpp
+++ b/include/aidge/operator/Conv.hpp
@@ -172,6 +172,8 @@ public:
         if (!getInput(1)) {
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of input channel imposed.");
         }
+        if(getInput(1)->dataFormat()==Aidge::DataFormat::NHWC) 
+            return getInput(1)->template dims<DIM+2>()[DIM+1];
         return getInput(1)->template dims<DIM+2>()[1];
     }
 
@@ -184,6 +186,8 @@ public:
         if (!getInput(1)) {
             AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of output channel imposed.");
         }
+        if(getInput(1)->dataFormat()==Aidge::DataFormat::NHWC) 
+            return getInput(1)->template dims<DIM+2>()[DIM+1];
         return getInput(1)->template dims<DIM+2>()[0];
     }
 
diff --git a/src/operator/Conv.cpp b/src/operator/Conv.cpp
index 836c47645..746c32dd4 100644
--- a/src/operator/Conv.cpp
+++ b/src/operator/Conv.cpp
@@ -40,42 +40,57 @@ Aidge::Conv_Op<DIM>::Conv_Op(const Aidge::Conv_Op<DIM>& op)
 
 template <Aidge::DimIdx_t DIM>
 bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) {
-    if (inputsAssociated()) {
-        // first check weight since it defines inChannels and outChannels
-        AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)),
-                    "Wrong weight Tensor dimension: {} for Conv{}D operator. Expected number of dimensions is {}.", getInput(1)->nbDims(), DIM, DIM+2);
-        // check data
+    if (!inputsAssociated()) 
+        return false;
+    // first check weight since it defines inChannels and outChannels
+    if(getInput(0)->dataFormat() == Aidge::DataFormat::NHWC){
         AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) &&
-                    (getInput(0)->template dims<DIM+2>()[1] == inChannels()),
-                    "Wrong input size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), inChannels(), fmt::join(std::vector<std::string>(DIM, "x"), ", "));
-        // check optional bias
-        if(getInput(2))
-            AIDGE_ASSERT((getInput(2)->nbDims() == (1)) &&
-                    (getInput(2)->template dims<1>()[0] == outChannels()),
-                    "Wrong bias size ({}) for Conv operator. Expected dims are [{}].", getInput(2)->dims(), outChannels());
-
-        std::array<DimSize_t, DIM + 2> outputDims{};
-        const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>());
-
-        for (std::size_t dim = 0; dim < mAttributes->template getAttr<ConvAttr::KernelDims>().size() ; ++dim) {
-            const DimSize_t kernelExtent = mAttributes->template getAttr<ConvAttr::DilationDims>()[dim] *
-                                                    (mAttributes->template getAttr<ConvAttr::KernelDims>()[dim] - 1) +
-                                            1;
-
-            outputDims[dim+2] = 1 + static_cast<DimSize_t>(
-                    floor(static_cast<float>(inputDims[dim+2] - kernelExtent) /
-                            static_cast<float>(mAttributes->template getAttr<ConvAttr::StrideDims>()[dim])));
-        }
+                (getInput(0)->template dims<DIM+2>()[DIM+1] == inChannels()),
+                "Wrong input size ({}) for Conv operator. Expected dims are [{}, {}, x].", getInput(0)->dims(), inChannels(), fmt::join(std::vector<std::string>(DIM, "x"), ", "));
+    }
+    else{ //For dataFormat in NCHW or Default Format
+        AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) &&
+                (getInput(0)->template dims<DIM+2>()[1] == inChannels()),
+                "Wrong input size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), inChannels(), fmt::join(std::vector<std::string>(DIM, "x"), ", "));
+    }
 
-        outputDims[1] = outChannels();
-        outputDims[0] = inputDims[0];
-        mOutputs[0]->resize(outputDims);
-        return true;
+    AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)),
+                "Wrong weight Tensor dimension: {} for Conv{}D operator. Expected number of dimensions is {}.", getInput(1)->nbDims(), DIM, DIM+2);
+
+    if(getInput(2))
+        AIDGE_ASSERT((getInput(2)->nbDims() == (1)) &&
+                (getInput(2)->template dims<1>()[0] == outChannels()),
+                "Wrong bias size ({}) for Conv operator. Expected dims are [{}].", getInput(2)->dims(), outChannels());
+
+    const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>());
+    std::array<DimSize_t, DIM + 2> outputDims;
+
+    
+    unsigned int in_dims_index = (getInput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2;
+    unsigned int out_dims_index = (getOutput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2;
+
+    for (std::size_t dim = 0; dim < mAttributes->template getAttr<ConvAttr::KernelDims>().size(); ++dim) {
+        const DimSize_t kernelExtent = mAttributes->template getAttr<ConvAttr::DilationDims>()[dim] *
+                                    (mAttributes->template getAttr<ConvAttr::KernelDims>()[dim] - 1) +
+                                    1;
+        
+        outputDims[dim + out_dims_index] = 1 + static_cast<DimSize_t>(
+            floor(static_cast<float>(inputDims[dim + in_dims_index] - kernelExtent) /
+                static_cast<float>(mAttributes->template getAttr<ConvAttr::StrideDims>()[dim]))
+        );
     }
 
-    return false;
-}
+    if(getOutput(0)->dataFormat() == Aidge::DataFormat::NHWC) 
+        outputDims[DIM+1] = outChannels();
+    else 
+        outputDims[1] = outChannels();
 
+    outputDims[0] = inputDims[0];
+    mOutputs[0]->resize(outputDims);
+    return true;
+    
+    
+}
 
 template <Aidge::DimIdx_t DIM>
 std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>>
-- 
GitLab