diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index c6bbd0e40b8cf1a3a6902ca18d60203da117eef3..f9c9109282cb90dadfa9b26d6f830faf9fdecd7c 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -173,8 +173,10 @@ public: AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of input channel imposed."); } + // check format if(getInput(1)->dataFormat()==Aidge::DataFormat::NHWC) return getInput(1)->template dims<DIM+2>()[DIM+1]; + // default format is NCHW return getInput(1)->template dims<DIM+2>()[1]; } @@ -187,6 +189,7 @@ public: if (!getInput(1)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of output channel imposed."); } + // first weight dimension for both NCHW (Cout,Cin,H,W) and NHWC (Cout,H,W,Cin) data format return getInput(1)->template dims<DIM+2>()[0]; } diff --git a/src/operator/Conv.cpp b/src/operator/Conv.cpp index 91aaad8ee8cf8723b24b8692c59e39f34d158f18..2077cab52f613780e77bba80efacb41d06a7f3cf 100644 --- a/src/operator/Conv.cpp +++ b/src/operator/Conv.cpp @@ -46,7 +46,7 @@ bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { if(getInput(0)->dataFormat() == Aidge::DataFormat::NHWC){ AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && (getInput(0)->template dims<DIM+2>()[DIM+1] == inChannels()), - "Wrong input size ({}) for Conv operator. Expected dims are [{}, {}, x].", getInput(0)->dims(), inChannels(), fmt::join(std::vector<std::string>(DIM, "x"), ", ")); + "Wrong input size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), fmt::join(std::vector<std::string>(DIM, "x"), ", "), inChannels()); } else{ //For dataFormat in NCHW or Default Format AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) &&