From fe58aba03030daa7175bbe98d10f866869c74e0a Mon Sep 17 00:00:00 2001 From: Wissam Boussella <wissam.boussella@cea.fr> Date: Thu, 23 Jan 2025 17:14:34 +0100 Subject: [PATCH] Conv fwd_dims both for nhwc and nchw in input and output --- include/aidge/operator/Conv.hpp | 4 ++ src/operator/Conv.cpp | 77 ++++++++++++++++++++------------- 2 files changed, 50 insertions(+), 31 deletions(-) diff --git a/include/aidge/operator/Conv.hpp b/include/aidge/operator/Conv.hpp index 135ff8860..e2faeb6ac 100644 --- a/include/aidge/operator/Conv.hpp +++ b/include/aidge/operator/Conv.hpp @@ -172,6 +172,8 @@ public: if (!getInput(1)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of input channel imposed."); } + if(getInput(1)->dataFormat()==Aidge::DataFormat::NHWC) + return getInput(1)->template dims<DIM+2>()[DIM+1]; return getInput(1)->template dims<DIM+2>()[1]; } @@ -184,6 +186,8 @@ public: if (!getInput(1)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Convolution operator has no weight Tensor associated so no specific number of output channel imposed."); } + if(getInput(1)->dataFormat()==Aidge::DataFormat::NHWC) + return getInput(1)->template dims<DIM+2>()[0]; // NOTE(review): NHWC weights keep output channels at index 0 ([O, ..., I]); [DIM+1] here would return inChannels() instead return getInput(1)->template dims<DIM+2>()[0]; } diff --git a/src/operator/Conv.cpp b/src/operator/Conv.cpp index 836c47645..746c32dd4 100644 --- a/src/operator/Conv.cpp +++ b/src/operator/Conv.cpp @@ -40,42 +40,57 @@ Aidge::Conv_Op<DIM>::Conv_Op(const Aidge::Conv_Op<DIM>& op) template <Aidge::DimIdx_t DIM> bool Aidge::Conv_Op<DIM>::forwardDims(bool /*allowDataDependency*/) { - if (inputsAssociated()) { - // first check weight since it defines inChannels and outChannels - AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)), - "Wrong weight Tensor dimension: {} for Conv{}D operator. 
Expected number of dimensions is {}.", getInput(1)->nbDims(), DIM, DIM+2); - // check data + if(getInput(0)->dataFormat() == Aidge::DataFormat::NHWC){ AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && - (getInput(0)->template dims<DIM+2>()[1] == inChannels()), - "Wrong input size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), inChannels(), fmt::join(std::vector<std::string>(DIM, "x"), ", ")); - // check optional bias - if(getInput(2)) - AIDGE_ASSERT((getInput(2)->nbDims() == (1)) && - (getInput(2)->template dims<1>()[0] == outChannels()), - "Wrong bias size ({}) for Conv operator. Expected dims are [{}].", getInput(2)->dims(), outChannels()); - - std::array<DimSize_t, DIM + 2> outputDims{}; - const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>()); - - for (std::size_t dim = 0; dim < mAttributes->template getAttr<ConvAttr::KernelDims>().size() ; ++dim) { - const DimSize_t kernelExtent = mAttributes->template getAttr<ConvAttr::DilationDims>()[dim] * - (mAttributes->template getAttr<ConvAttr::KernelDims>()[dim] - 1) + - 1; - - outputDims[dim+2] = 1 + static_cast<DimSize_t>( - floor(static_cast<float>(inputDims[dim+2] - kernelExtent) / - static_cast<float>(mAttributes->template getAttr<ConvAttr::StrideDims>()[dim]))); - } + (getInput(0)->template dims<DIM+2>()[DIM+1] == inChannels()), + "Wrong input size ({}) for Conv operator. Expected dims are [x, {}, {}].", getInput(0)->dims(), fmt::join(std::vector<std::string>(DIM, "x"), ", "), inChannels()); + } + else{ //For dataFormat in NCHW or Default Format + AIDGE_ASSERT((getInput(0)->nbDims() == (DIM+2)) && + (getInput(0)->template dims<DIM+2>()[1] == inChannels()), + "Wrong input size ({}) for Conv operator. 
Expected dims are [x, {}, {}].", getInput(0)->dims(), inChannels(), fmt::join(std::vector<std::string>(DIM, "x"), ", ")); + } - outputDims[1] = outChannels(); - outputDims[0] = inputDims[0]; - mOutputs[0]->resize(outputDims); - return true; + AIDGE_ASSERT((getInput(1)->nbDims() == (DIM+2)), + "Wrong weight Tensor dimension: {} for Conv{}D operator. Expected number of dimensions is {}.", getInput(1)->nbDims(), DIM, DIM+2); + + if(getInput(2)) + AIDGE_ASSERT((getInput(2)->nbDims() == (1)) && + (getInput(2)->template dims<1>()[0] == outChannels()), + "Wrong bias size ({}) for Conv operator. Expected dims are [{}].", getInput(2)->dims(), outChannels()); + + const std::array<DimSize_t, DIM + 2> inputDims(getInput(0)->template dims<DIM+2>()); + std::array<DimSize_t, DIM + 2> outputDims{}; + + + unsigned int in_dims_index = (getInput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2; + unsigned int out_dims_index = (getOutput(0)->dataFormat() == Aidge::DataFormat::NHWC) ? 1 : 2; + + for (std::size_t dim = 0; dim < mAttributes->template getAttr<ConvAttr::KernelDims>().size(); ++dim) { + const DimSize_t kernelExtent = mAttributes->template getAttr<ConvAttr::DilationDims>()[dim] * + (mAttributes->template getAttr<ConvAttr::KernelDims>()[dim] - 1) + + 1; + + outputDims[dim + out_dims_index] = 1 + static_cast<DimSize_t>( + floor(static_cast<float>(inputDims[dim + in_dims_index] - kernelExtent) / + static_cast<float>(mAttributes->template getAttr<ConvAttr::StrideDims>()[dim])) + ); } - return false; -} + if(getOutput(0)->dataFormat() == Aidge::DataFormat::NHWC) + outputDims[DIM+1] = outChannels(); + else + outputDims[1] = outChannels(); + outputDims[0] = inputDims[0]; + mOutputs[0]->resize(outputDims); + return true; + + +} template <Aidge::DimIdx_t DIM> std::vector<std::pair<std::vector<Aidge::DimSize_t>, std::vector<Aidge::DimSize_t>>> -- GitLab