From a9acb2f86f949d74f7a660d1eddcbc6c4045a2a7 Mon Sep 17 00:00:00 2001 From: Wissam Boussella <wissam.boussella@cea.fr> Date: Wed, 19 Feb 2025 16:06:52 +0100 Subject: [PATCH] new comments for Conv.hpp and Conv.cpp --- include/aidge/operator/Conv.hpp | 3 +++ src/operator/Conv.cpp | 2 +- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index c6bbd0e40..f9c910928 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -173,8 +173,10 @@ public: AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of input channel imposed."); } + // check format if(getInput(1)->dataFormat()==Aidge::DataFormat::NHWC) return getInput(1)->template dims<DIM+2>()[DIM+1]; + // default format is NCHW return getInput(1)->template dims<DIM+2>()[1]; } @@ -187,6 +189,7 @@ public: if (!getInput(1)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of output channel imposed."); } + // first weight dimension for both NCHW (Cout,Cin,H,W) and NHWC (Cout,H,W,Cin) data format return getInput(1)->template dims<DIM+2>()[0]; } diff --git a/src/operator/Conv.cpp b/src/operator/Conv.cpp index 91aaad8ee..2077cab52 100644 --- a/src/operator/Conv.cpp +++ b/src/operator/Conv.cpp @@ -46,7 +46,7 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { if(getInput(0)->dataFormat() == Aidge::DataFormat::NHWC){ AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && (getInput(0)->template dims<DIM+2>()[DIM+1] == inChannels()), - "Wrong input size ({}) for Conv operator. Expected dims are [{}, {}, x].", getInput(0)->dims(), inChannels(), fmt::join(std::vector<std::string>(DIM, "x"), ", ")); + "Wrong input size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), fmt::join(std::vector<std::string>(DIM, "x"), ", "), inChannels()); } else{ //For dataFormat in NCHW or Default Format AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && -- GitLab