diff --git a/include/aidge/backend/cpu.hpp b/include/aidge/backend/cpu.hpp index 80574b4a46fef0c843c9511836f162e02de5aab3..5c1f9b111f41a435aa477d0647fa66fb29a058fb 100644 --- a/include/aidge/backend/cpu.hpp +++ b/include/aidge/backend/cpu.hpp @@ -27,6 +27,7 @@ #include "aidge/backend/cpu/operator/ClipImpl.hpp" #include "aidge/backend/cpu/operator/ConvDepthWiseImpl.hpp" #include "aidge/backend/cpu/operator/ConvImpl.hpp" +#include "aidge/backend/cpu/operator/ConvTransposeImpl.hpp" #include "aidge/backend/cpu/operator/ConstantOfShapeImpl.hpp" #include "aidge/backend/cpu/operator/CryptoHashImpl.hpp" #include "aidge/backend/cpu/operator/DivImpl.hpp" diff --git a/include/aidge/backend/cpu/operator/ConvImpl.hpp b/include/aidge/backend/cpu/operator/ConvImpl.hpp index c06d0912f419909013f930867ce3c3238c1a5555..e480697b6452440f043901140a07cb643f3cbdb6 100644 --- a/include/aidge/backend/cpu/operator/ConvImpl.hpp +++ b/include/aidge/backend/cpu/operator/ConvImpl.hpp @@ -13,45 +13,64 @@ #define AIDGE_CPU_OPERATOR_CONVIMPL_H_ #include <array> -#include <memory> -#include <tuple> -#include <vector> #include "aidge/backend/cpu/operator/OperatorImpl.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" -#include "aidge/backend/cpu/data/GetCPUPtr.h" namespace Aidge { + // Operator implementation entry point for the backend using Conv1D_Op = Conv_Op<1>; using ConvImpl1D_cpu = OperatorImpl_cpu<Conv_Op<1>, - void(const std::array<DimSize_t, 1>&, - const std::array<DimSize_t, 1>&, - const std::array<DimSize_t, 1>&, - const std::array<DimSize_t, 3> &, - DimSize_t, - const void *, - const void *, - const void *, - void *)>; + void(const std::array<DimSize_t, 1> &, + const std::array<DimSize_t, 1> &, + const std::array<DimSize_t, 1> &, + const std::array<DimSize_t, 3> &, + DimSize_t, + const void *, + const void *, + const void *, + void *), + void(const std::array<DimSize_t, 1> &, + const std::array<DimSize_t, 1> &, + const std::array<DimSize_t, 1> &, + const std::array<DimSize_t, 3> &, + const std::array<DimSize_t, 3> &, + const void *, + const void *, + const void *, + void *, + void *, + void *)>; using Conv2D_Op = Conv_Op<2>; -using ConvImpl2D_cpu = OperatorImpl_cpu<Conv_Op<2>, - void(const std::array<DimSize_t, 2>&, - const std::array<DimSize_t, 2>&, - const std::array<DimSize_t, 2>&, - const std::array<DimSize_t, 4> &, - DimSize_t, - const void *, - const void *, - const void *, - void *)>; +using ConvImpl2D_cpu = OperatorImpl_cpu<Conv2D_Op, + void(const std::array<DimSize_t, 2> &, + const std::array<DimSize_t, 2> &, + const std::array<DimSize_t, 2> &, + const std::array<DimSize_t, 4> &, + DimSize_t, + const void *, + const void *, + const void *, + void *), + void(const std::array<DimSize_t, 2> &, + const std::array<DimSize_t, 2> &, + const std::array<DimSize_t, 2> &, + const std::array<DimSize_t, 4> &, + const std::array<DimSize_t, 4> &, + const void *, + const void *, + const void *, + void *, + void *, + void *)>; // Implementation entry point registration to Operator REGISTRAR(Conv1D_Op, "cpu", Aidge::ConvImpl1D_cpu::create); REGISTRAR(Conv2D_Op, "cpu", Aidge::ConvImpl2D_cpu::create); -} // namespace Aidge +} // namespace Aidge #endif /* AIDGE_CPU_OPERATOR_CONVIMPL_H_ */ diff --git a/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp index 1229d5714e6b0cbae4e42ece9130c2c2305f133e..7ae9e45fe4f5d7436e3f08447c69bef3c16b6218 100644 --- a/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp 
+++ b/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp @@ -13,18 +13,16 @@ #define AIDGE_CPU_OPERATOR_CONVIMPL_KERNELS_H_ #include <array> -#include <memory> -#include <tuple> -#include <vector> +#include <cstdint> +#include <fmt/base.h> -#include "aidge/backend/cpu/operator/OperatorImpl.hpp" #include "aidge/backend/cpu/operator/ConvImpl.hpp" -#include "aidge/operator/Conv.hpp" #include "aidge/utils/Registrar.hpp" #include "aidge/utils/Types.h" -#include "aidge/backend/cpu/data/GetCPUPtr.h" namespace Aidge { +using std::array; + /** * @brief Forward kernel for 1D Convolution on CPU backend. * @tparam I Input data type. @@ -39,16 +37,15 @@ namespace Aidge { * @param output_ Output Tensor. */ template <class I, class W, class B, class O> -void ConvImpl1D_cpu_forward_kernel(const std::array<DimSize_t, 1>& strideDims, - const std::array<DimSize_t, 1>& dilationDims, - const std::array<DimSize_t, 1>& kernelDims, - const std::array<DimSize_t, 3>& inputDims, - DimSize_t outChannels, - const void *input_, - const void *weights_, - const void *biases_, - void *output_) -{ +void ConvImpl1D_cpu_forward_kernel(const array<DimSize_t, 1> &strideDim, + const array<DimSize_t, 1> &dilationDim, + const array<DimSize_t, 1> &kernelDim, + const std::array<DimSize_t, 3> &inputDims, + DimSize_t outChannels, + const void *input_, + const void *weights_, + const void *biases_, + void *output_) { // FIXME: missing convolution attributes as arguments const I *input = static_cast<const I *>(input_); const W *weights = static_cast<const W *>(weights_); @@ -56,38 +53,38 @@ void ConvImpl1D_cpu_forward_kernel(const std::array<DimSize_t, 1>& strideDims, O *output = static_cast<O *>(output_); // output H size - const std::size_t oxSize = - static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[2] - dilationDims[0]*(kernelDims[0] - 1) - 1 + strideDims[0]) / - static_cast<float>(strideDims[0]))); - const DimSize_t dilated_kernel_x = dilationDims[0]*(kernelDims[0] - 1) + 1; + const std::size_t oxSize = static_cast<std::size_t>(std::floor( + static_cast<float>(inputDims[2] - dilationDim[0] * (kernelDim[0] - 1) - + 1 + strideDim[0]) / + static_cast<float>(strideDim[0]))); + const DimSize_t dilated_kernel_x = dilationDim[0] * (kernelDim[0] - 1) + 1; - // TODO: kernel computation - // output (batch, outCh, Xout, Yout) - // input (batch, inCh, Xin, Yin) - // weight (outCh, inCh, kernelX, kernelY) - // does not take Dilation attribute into account using signedsize = std::make_signed<std::size_t>::type; for (std::size_t batch = 0; batch < inputDims[0]; ++batch) { for (std::size_t outCh = 0; outCh < outChannels; ++outCh) { - const std::size_t oIndex = (outCh + batch*outChannels) * oxSize; + const std::size_t oIndex = (outCh + batch * outChannels) * oxSize; // If bias = nullptr, set B(0) B biasVal = (biases != nullptr) ? 
biases[outCh] : B(0); - std::fill(output + oIndex, output+(oIndex+oxSize), biasVal); + std::fill(output + oIndex, output + (oIndex + oxSize), biasVal); for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) { - const std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2]; - const std::size_t wIndex = (inCh + outCh*inputDims[1]) * kernelDims[0]; + const std::size_t iIndex = + (inCh + batch * inputDims[1]) * inputDims[2]; + const std::size_t wIndex = + (inCh + outCh * inputDims[1]) * kernelDim[0]; for (std::size_t ox = 0; ox < oxSize; ++ox) { - // const signedsize difx = static_cast<signedsize>(- ox * strideDims[0]); - // const std::size_t sxMin = static_cast<std::size_t>(std::max(difx, signedsize(0))); - // const std::size_t sxMax = (static_cast<signedsize>(inputDims[2]) + difx) < 0 ? 0 : ((inputDims[2] + difx) > kernelDims[0] ? kernelDims[0] : inputDims[2] + difx); const std::size_t sxMin = 0; const std::size_t sxMax = dilated_kernel_x; const std::size_t oIndexFull = oIndex + ox; - const signedsize ix = static_cast<signedsize>(ox * strideDims[0]); + const signedsize ix = + static_cast<signedsize>(ox * strideDim[0]); - for (std::size_t sx = sxMin; sx*dilationDims[0] < sxMax; ++sx) { - output[oIndexFull] += weights[wIndex + sx] * - input[iIndex + static_cast<std::size_t>(ix+static_cast<signedsize>(sx*dilationDims[0]))]; + for (std::size_t sx = sxMin; sx * dilationDim[0] < sxMax; + ++sx) { + output[oIndexFull] += + weights[wIndex + sx] * + input[iIndex + static_cast<std::size_t>( + ix + static_cast<signedsize>( + sx * dilationDim[0]))]; } } } @@ -95,20 +92,342 @@ void ConvImpl1D_cpu_forward_kernel(const std::array<DimSize_t, 1>& strideDims, } } +/** + * @brief perform 1D backpropagation for the data input + * @note INPUT & OUTPUT convention is the same as in the + * forward function + * @note formula : + * for i in 0..input_size: + * for n in 0..weight_size: + * dL dYn dL + * ---- = ---- ---- + * dXi dXi Yn + * with : dYn / dXi = w_k + * for each input value + * for each weight + * for each output + * multiply the weight with the associated value + * @note kernel & stride are passed as single integers as they are just arrays + * of length 1 + * @note reminder that kernel dimensions are + * {outChannels, inChannels, {kernelDims}} + * <=> {oDims[1], iDims[1], kernelDim} + * @tparam I Input data type. + * @tparam W Weight data type. + * @tparam O Output data type. 
+ * @param[in] stride stride parameter of the convolution operator + * @param[in] dilation dilation parameter of the convolution operator + * @param[in] kDims dimension of the kernel + * @param[in] kStrides nb of elements contained per dimension of the kernel + * @param[in] weights kernel weights + * @param[in] oDims dimensions of the output + * @param[in] oStrides nb of elements contained per dimension of the output + * @param[in] oGrad output gradient + * @param[in] iDims input dimensions + * @param[in] iStrides nb of elements contained per dimension of the input + * @param[inout] iGrad gradients of the input to update + */ +template <class I, class W, class O> +void conv1DBackwardInput(const array<DimSize_t, 1> &stride, + const array<DimSize_t, 1> &dilation, + const array<DimSize_t, 1> &kDim, + const array<DimSize_t, 2> &kStrides, + const W *weights, + const array<DimSize_t, 3> &oDims, + const array<DimSize_t, 2> &oStrides, + const O *oGrad, + const array<DimSize_t, 3> &iDims, + const array<DimSize_t, 2> &iStrides, + I *iGrad) { + + array<DimSize_t, 2> iOffsets{0, 0}; + array<DimSize_t, 2> oOffsets{0, 0}; + array<DimSize_t, 2> kOffsets{0, 0}; + + for (std::size_t batch = 0; batch < iDims[0]; ++batch) { + iOffsets[0] = batch * iStrides[0]; + oOffsets[0] = batch * oStrides[0]; + + for (DimSize_t oChannel = 0; oChannel < oDims[1]; oChannel++) { + oOffsets[1] = (oChannel * oStrides[1]) + oOffsets[0]; + kOffsets[0] = oChannel * kStrides[0]; + + for (std::size_t iChannel = 0; iChannel < iDims[1]; ++iChannel) { + iOffsets[1] = (iChannel * iStrides[1]) + iOffsets[0]; + kOffsets[1] = iChannel * kStrides[1] + kOffsets[0]; + + for (DimSize_t oX = 0; oX < oDims[2]; ++oX) { + auto iX = oX * stride[0]; + auto inIdx = iX + iOffsets[1]; + + for (DimSize_t kX = 0; kX < kDim[0]; ++kX) { + auto dilatedKernelIdx = kX * dilation[0]; + + iGrad[inIdx + dilatedKernelIdx] += + weights[kOffsets[1] + kX] * + oGrad[oOffsets[1] + oX]; + } + } + } + } + } +} + +/** + * @brief computes weight backpropagation for conv1D + * @note INPUT & OUTPUT convention is the same as in the + * forward function + * weight grad + * for i in 0..weight_size: + * for n in 0..output_size: + * dL dYn dL + * ---- = ---- ---- + * dwi dwi Yn + * with : dYn / dwi = x_k + * @tparam I Input data type. + * @tparam W Weight data type. + * @tparam O Output data type. 
+ * @param[in] stride stride parameter of the convolution operator + * @param[in] dilation dilation parameter of the convolution operator + * @param[in] iDims input dimensions + * @param[in] iStrides nb of elements contained per dimension of the input + * @param[inout] iGrad gradients of the input to update + * @param[in] oDims dimensions of the output + * @param[in] oStrides nb of elements contained per dimension of the output + * @param[in] oGrad output gradient + * @param[in] kDims dimension of the kernel + * @param[in] kStrides nb of elements contained per dimension of the kernel + * @param[in] weights kernel weights + */ +template <class I, class W, class O> +static void conv1DBackwardWeights(const array<DimSize_t, 1> &stride, + const array<DimSize_t, 1> &dilation, + const array<DimSize_t, 3> &iDims, + const array<DimSize_t, 2> iStrides, + const I *input, + const array<DimSize_t, 3> &oDims, + const array<DimSize_t, 2> oStrides, + const O *oGrad, + const array<DimSize_t, 1> &kDim, + const array<DimSize_t, 2> kStrides, + W *weightsGrad) { + + array<DimSize_t, 2> iOffsets{0, 0}; + array<DimSize_t, 2> oOffsets{0, 0}; + array<DimSize_t, 2> kOffsets{0, 0}; + + for (DimSize_t batch = 0; batch < oDims[0]; ++batch) { + iOffsets[0] = batch * iStrides[0]; + oOffsets[0] = batch * oStrides[0]; + + for (DimSize_t oChannel = 0; oChannel < oDims[1]; ++oChannel) { + oOffsets[1] = oChannel * oStrides[1] + oOffsets[0]; + kOffsets[0] = oChannel * kStrides[0]; + + for (DimSize_t iChannel = 0; iChannel < iDims[1]; ++iChannel) { + kOffsets[1] = iChannel * kStrides[1] + kOffsets[0]; + iOffsets[1] = iChannel * iStrides[1] + iOffsets[0]; + oOffsets[1] = oChannel * oStrides[1] + oOffsets[0]; + + for (DimSize_t kX = 0; kX < kDim[0]; ++kX) { + + for (DimSize_t oX = 0; oX < oDims[2]; ++oX) { + const DimSize_t iX = oX * stride[0] + kX * dilation[0] ; + + weightsGrad[kOffsets[1] + kX] += + input[iOffsets[1] + iX] * oGrad[oOffsets[1] + oX]; + } + } + } + } + } +} + +/** + * @brief computes bias backpropagation for conv1D operation + * @note INPUT & OUTPUT convention is the same as in the + * forward function + * @note formula : + * Bias grad: + * for i in 0..bias_size: + * for n in 0..output_size: + * dL dYn dL + * ---- = ---- ---- + * dbi dbi Yn + * with : dYn / dbi = 1 + * + * Hence the partial derivative of the loss wrt bias is the + * output loss. Hence the bias grad is just the sum of the + * loss values over the batch + * @tparam I Input data type. + * @tparam W Weight data type. + * @tparam B Bias data type. + * @tparam O Output data type. + * @param[in] oDims output tensor dimensions + * @param[in] oStrides nb of elements contained per dimension of the output + * tensor + * @param[in] oGrad output tensor gradients + * @param[inout] biasesGrad biases gradients + */ +template <class B, class O> +static void conv1DBackwardBias(const array<DimSize_t, 3> &oDims, + const array<DimSize_t, 2> &oStrides, + const O *oGrad, + B *biasesGrad) { + array<DimSize_t, 2> oOffsets{0, 0}; + + for (DimSize_t batchIdx = 0; batchIdx < oDims[0]; ++batchIdx) { + oOffsets[0] = batchIdx * oStrides[0]; + + for (DimSize_t oChannel = 0; oChannel < oDims[1]; ++oChannel) { + oOffsets[1] = oChannel * oStrides[1] + oOffsets[0]; + + for (DimSize_t oIdx = 0; oIdx < oDims[2]; oIdx++) { + biasesGrad[oChannel] += oGrad[oOffsets[1] + oIdx]; + } + } + } +} + +/** + * @brief Backward kernel for 1D Convolution on CPU backend. + * @note INPUT & OUTPUT convention is the same as in the + * forward function + * + * @tparam I Input data type. 
+ * @tparam W Weight data type. + * @tparam B Bias data type. + * @tparam O Output data type. + * @param[in] stride stride attribute of the convolution operator + * @param[in] dilation dilation attribute of the convolution operator + * @param[in] kernelDim kernel dimensions + * @param[in] inputDims input data dimensions + * @param[in] outputDims output data dimensions + * @param[in] input_ const input Tensor + * @param[in] weights_ const weight Tensor + * @param[in] oGrad_ gradients of output data + * @param[inout] iGrad_ gradients of input data + * @param[inout] weightsGrad_ gradients of the kernel weights + * @param[inout] biasesGrad_ gradients of the kernel biases + */ +template <class I, class W, class B, class O> +void ConvImpl1D_cpu_backward_kernel(const array<DimSize_t,1> &stride, + const array<DimSize_t,1> &dilation, + const array<DimSize_t,1> &kernelDim, + const array<DimSize_t, 3> &inputDims, + const array<DimSize_t, 3> &outputDims, + const void *input_, + const void *weights_, + const void *oGrad_, + void *iGrad_, + void *weightsGrad_, + void *biasesGrad_) { + + const I *input = static_cast<const I *>(input_); + I *iGrad = static_cast<I *>(iGrad_); + const I *oGrad = static_cast<const I *>(oGrad_); + const W *weights = static_cast<const W *>(weights_); + W *weightsGrad = static_cast<W *>(weightsGrad_); + + ////////////////////////////// + // COMPUTING STRIDES + ////////////////////////////// + // NOTE: The ...Stride var represent the number of values contained in + // each dimension they will be used to compute the index offset of + // values while iterating on each tensor + // NOTE: They are 1 item shorter than their corresponding tensor as the + // number of total elements is not used except for gradient initialization + + // {batch_stride, channel_stride, dim0_stride, dim1_stride} + const array<DimSize_t, 2> inputStrides{inputDims[1] * inputDims[2], + inputDims[2]}; + const DimSize_t nbEltsInput = inputDims[0] * inputStrides[0]; + + // {batch_stride, channel_stride, dim0_stride, dim1_stride} + const array<DimSize_t, 2> outputStrides{outputDims[1] * outputDims[2], + outputDims[2]}; + + // NOTE: kernel dims = {iChannel, oChannel, kernelDim0, kernelDim1} + // kernel_strides = {iChannel, oChannel, kernelDim0} + const array<DimSize_t, 2> kernelStrides{ + inputDims[1] * kernelDim[0], + kernelDim[0], + }; + const DimSize_t nbEltsKernel = outputDims[1] * kernelStrides[0]; + + std::fill(iGrad, iGrad + nbEltsInput, I(0)); + std::fill(weightsGrad, weightsGrad + nbEltsKernel, W(0)); + + conv1DBackwardInput(stride, + dilation, + kernelDim, + kernelStrides, + weights, + outputDims, + outputStrides, + oGrad, + inputDims, + inputStrides, + iGrad); + + conv1DBackwardWeights(stride, + dilation, + inputDims, + inputStrides, + input, + outputDims, + outputStrides, + oGrad, + kernelDim, + kernelStrides, + weightsGrad); + + if (biasesGrad_ != nullptr) { + B *biasesGrad = static_cast<B *>(biasesGrad_); + std::fill(biasesGrad, biasesGrad + outputDims[1], B(0)); + conv1DBackwardBias(outputDims, outputStrides, oGrad, biasesGrad); + } +} + // Kernels registration to implementation entry point REGISTRAR(ConvImpl1D_cpu, - {{DataType::Any, DataFormat::NCHW}, {DataType::Float32, DataFormat::NCHW}}, - {ProdConso::inPlaceModel, Aidge::ConvImpl1D_cpu_forward_kernel<float, float, float, float>, nullptr}); + {{DataType::Any, DataFormat::NCHW}, + {DataType::Float32, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvImpl1D_cpu_forward_kernel<float, float,
float, float>, + ConvImpl1D_cpu_backward_kernel<float, float, float, float>}); REGISTRAR(ConvImpl1D_cpu, - {{DataType::Any, DataFormat::NCHW}, {DataType::Float16, DataFormat::NCHW}}, - {ProdConso::inPlaceModel, Aidge::ConvImpl1D_cpu_forward_kernel<half_float::half, half_float::half, half_float::half, half_float::half>, nullptr}); + {{DataType::Any, DataFormat::NCHW}, + {DataType::Float16, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvImpl1D_cpu_forward_kernel<half_float::half, + half_float::half, + half_float::half, + half_float::half>, + ConvImpl1D_cpu_backward_kernel<half_float::half, + half_float::half, + half_float::half, + half_float::half>}); REGISTRAR(ConvImpl1D_cpu, - {{DataType::Any, DataFormat::NCHW}, {DataType::Int32, DataFormat::NCHW}}, - {ProdConso::inPlaceModel, Aidge::ConvImpl1D_cpu_forward_kernel<int32_t, int32_t, int32_t, int32_t>, nullptr}); + {{DataType::Any, DataFormat::NCHW}, + {DataType::Float64, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvImpl1D_cpu_forward_kernel<double, double, double, double>, + ConvImpl1D_cpu_backward_kernel<double, double, double, double>}); REGISTRAR(ConvImpl1D_cpu, - {{DataType::Any, DataFormat::NCHW}, {DataType::Float64, DataFormat::NCHW}}, - {ProdConso::inPlaceModel, Aidge::ConvImpl1D_cpu_forward_kernel<double, double, double, double>, nullptr}); - + {{DataType::Any, DataFormat::NCHW}, + {DataType::Int32, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvImpl1D_cpu_forward_kernel<std::int32_t, + std::int32_t, + std::int32_t, + std::int32_t>, + ConvImpl1D_cpu_backward_kernel<std::int32_t, + std::int32_t, + std::int32_t, + std::int32_t>}); /** * @brief Forward kernel for 2D Convolution on CPU backend. @@ -124,16 +443,15 @@ REGISTRAR(ConvImpl1D_cpu, * @param output_ Output Tensor. */ template <class I, class W, class B, class O> -void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, - const std::array<DimSize_t, 2>& dilationDims, - const std::array<DimSize_t, 2>& kernelDims, - const std::array<DimSize_t, 4> &inputDims, - DimSize_t outChannels, - const void *input_, - const void *weights_, - const void *biases_, - void *output_) -{ +void ConvImpl2D_cpu_forward_kernel(const array<DimSize_t, 2> &strideDims, + const array<DimSize_t, 2> &dilationDims, + const array<DimSize_t, 2> &kernelDims, + const array<DimSize_t, 4> &inputDims, + DimSize_t outChannels, + const void *input_, + const void *weights_, + const void *biases_, + void *output_) { // FIXME: missing convolution attributes as arguments const I *input = static_cast<const I *>(input_); const W *weights = static_cast<const W *>(weights_); @@ -141,59 +459,102 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, O *output = static_cast<O *>(output_); // output H size - const DimSize_t dilated_kernel_x = dilationDims[0]*(kernelDims[0] - 1) + 1; - const std::size_t oxSize = - static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[2] - dilated_kernel_x + strideDims[0]) / - static_cast<float>(strideDims[0]))); + const DimSize_t dilated_kernel_x = + dilationDims[0] * (kernelDims[0] - 1) + 1; + const std::size_t oxSize = static_cast<std::size_t>(std::floor( + static_cast<float>(inputDims[2] - dilated_kernel_x + strideDims[0]) / + static_cast<float>(strideDims[0]))); // output W size - const DimSize_t dilated_kernel_y = dilationDims[1]*(kernelDims[1] - 1) + 1; - const std::size_t oySize = - static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[3] - dilated_kernel_y + strideDims[1]) / - 
static_cast<float>(strideDims[1]))); - + const DimSize_t dilated_kernel_y = + dilationDims[1] * (kernelDims[1] - 1) + 1; + const std::size_t oySize = static_cast<std::size_t>(std::floor( + static_cast<float>(inputDims[3] - dilated_kernel_y + strideDims[1]) / + static_cast<float>(strideDims[1]))); // TODO: kernel computation // output (batch, outCh, Xout, Yout) // input (batch, inCh, Xin, Yin) // weight (outCh, inCh, kernelX, kernelY) // does not take Dilation attribute into account - const std::size_t outChannels_s = oxSize * oySize; + const std::size_t outChannels_s = oxSize * oySize; if (dilated_kernel_x == 3 && dilated_kernel_y == 3) { for (std::size_t batch = 0; batch < inputDims[0]; ++batch) { for (std::size_t outCh = 0; outCh < outChannels; ++outCh) { // If bias = nullptr, set B(0) B biasVal = (biases != nullptr) ? biases[outCh] : B(0); - std::fill(output, output+outChannels_s, biasVal); + std::fill(output, output + outChannels_s, biasVal); for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) { - std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3]; - const std::size_t wIndex = (inCh + outCh*inputDims[1]) * 9; - if (strideDims[0] == 1 && strideDims[1]==1) { - for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex-=inputDims[3]) { + std::size_t iIndex = (inCh + batch * inputDims[1]) * + inputDims[2] * inputDims[3]; + const std::size_t wIndex = + (inCh + outCh * inputDims[1]) * 9; + if (strideDims[0] == 1 && strideDims[1] == 1) { + for (std::size_t ox = 0, oIndex = 0; ox < oxSize; + ++ox, oIndex += oySize, iIndex -= inputDims[3]) { for (std::size_t oy = 0; oy < oySize; ++oy) { - output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy]+weights[wIndex+1]*input[iIndex+oy+1]+weights[wIndex+2]*input[iIndex+oy+2]; + output[oIndex + oy] += + weights[wIndex + 0] * input[iIndex + oy] + + weights[wIndex + 1] * + input[iIndex + oy + 1] + + weights[wIndex + 2] * + input[iIndex + oy + 2]; } - iIndex+=inputDims[3]; + iIndex += inputDims[3]; for (std::size_t oy = 0; oy < oySize; ++oy) { - output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy]+weights[wIndex+4]*input[iIndex+oy+1]+weights[wIndex+5]*input[iIndex+oy+2]; + output[oIndex + oy] += + weights[wIndex + 3] * input[iIndex + oy] + + weights[wIndex + 4] * + input[iIndex + oy + 1] + + weights[wIndex + 5] * + input[iIndex + oy + 2]; } - iIndex+=inputDims[3]; + iIndex += inputDims[3]; for (std::size_t oy = 0; oy < oySize; ++oy) { - output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy]+weights[wIndex+7]*input[iIndex+oy+1]+weights[wIndex+8]*input[iIndex+oy+2]; + output[oIndex + oy] += + weights[wIndex + 6] * input[iIndex + oy] + + weights[wIndex + 7] * + input[iIndex + oy + 1] + + weights[wIndex + 8] * + input[iIndex + oy + 2]; } } } else { - for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=(strideDims[0]-2)*inputDims[3]) { + for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, + oIndex += oySize, + iIndex += (strideDims[0] - + 2) * inputDims[3]) { for (std::size_t oy = 0; oy < oySize; ++oy) { - output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy*strideDims[1]]+weights[wIndex+1]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+2]*input[iIndex+oy*strideDims[1]+2]; + output[oIndex + oy] += + weights[wIndex + 0] * + input[iIndex + oy * strideDims[1]] + + weights[wIndex + 1] * + input[iIndex + oy * strideDims[1] + + 1] + + weights[wIndex + 2] * + input[iIndex + oy * strideDims[1] + 2]; } - iIndex+=inputDims[3]; + iIndex += inputDims[3]; for (std::size_t oy = 
0; oy < oySize; ++oy) { - output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy*strideDims[1]]+weights[wIndex+4]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+5]*input[iIndex+oy*strideDims[1]+2]; + output[oIndex + oy] += + weights[wIndex + 3] * + input[iIndex + oy * strideDims[1]] + + weights[wIndex + 4] * + input[iIndex + oy * strideDims[1] + + 1] + + weights[wIndex + 5] * + input[iIndex + oy * strideDims[1] + 2]; } - iIndex+=inputDims[3]; + iIndex += inputDims[3]; for (std::size_t oy = 0; oy < oySize; ++oy) { - output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy*strideDims[1]]+weights[wIndex+7]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+8]*input[iIndex+oy*strideDims[1]+2]; + output[oIndex + oy] += + weights[wIndex + 6] * + input[iIndex + oy * strideDims[1]] + + weights[wIndex + 7] * + input[iIndex + oy * strideDims[1] + + 1] + + weights[wIndex + 8] * + input[iIndex + oy * strideDims[1] + 2]; } } } @@ -206,18 +567,26 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, for (std::size_t outCh = 0; outCh < outChannels; ++outCh) { // If bias = nullptr, set B(0) B biasVal = (biases != nullptr) ? biases[outCh] : B(0); - std::fill(output, output+outChannels_s, biasVal); + std::fill(output, output + outChannels_s, biasVal); for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) { - std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3]; - const std::size_t wIndex = (inCh + outCh*inputDims[1]); + std::size_t iIndex = (inCh + batch * inputDims[1]) * + inputDims[2] * inputDims[3]; + const std::size_t wIndex = (inCh + outCh * inputDims[1]); if (strideDims[0] == 1 && strideDims[1] == 1) { - for (std::size_t oIndex = 0; oIndex < oxSize*oySize; ++oIndex, ++iIndex) { + for (std::size_t oIndex = 0; oIndex < oxSize * oySize; + ++oIndex, ++iIndex) { output[oIndex] += weights[wIndex] * input[iIndex]; } - } else { - for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=inputDims[3]*strideDims[0]) { - for (std::size_t oy = 0, iy = 0; oy < oySize; ++oy, iy+=strideDims[1]) { - output[oIndex + oy] += weights[wIndex+0]*input[iIndex+iy]; + } else { + for (std::size_t ox = 0, oIndex = 0; ox < oxSize; + ++ox, + oIndex += oySize, + iIndex += + inputDims[3] * strideDims[0]) { + for (std::size_t oy = 0, iy = 0; oy < oySize; + ++oy, iy += strideDims[1]) { + output[oIndex + oy] += + weights[wIndex + 0] * input[iIndex + iy]; } } } @@ -230,21 +599,36 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, for (std::size_t outCh = 0; outCh < outChannels; ++outCh) { // If bias = nullptr, set B(0) B biasVal = (biases != nullptr) ? 
biases[outCh] : B(0); - std::fill(output, output+outChannels_s, biasVal); + std::fill(output, output + outChannels_s, biasVal); for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) { - std::size_t iIndex_channel = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3]; - const std::size_t wIndex = (inCh + outCh*inputDims[1]) * kernelDims[0] * kernelDims[1]; + std::size_t iIndex_channel = + (inCh + batch * inputDims[1]) * inputDims[2] * + inputDims[3]; + const std::size_t wIndex = (inCh + outCh * inputDims[1]) * + kernelDims[0] * kernelDims[1]; // loop over each ouput line - for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex_channel+=inputDims[3]*strideDims[0]) { + for (std::size_t ox = 0, oIndex = 0; ox < oxSize; + ++ox, + oIndex += oySize, + iIndex_channel += + inputDims[3] * strideDims[0]) { // loop over associated input line - for (std::size_t ky = 0, ix = 0; ky < kernelDims[0]; ++ky, ix += inputDims[3]*dilationDims[0]) { + for (std::size_t ky = 0, ix = 0; ky < kernelDims[0]; + ++ky, ix += inputDims[3] * dilationDims[0]) { // loop over the entire line - for (std::size_t oy = 0, iy = 0; oy < oySize; ++oy, iy+=strideDims[1]) { - const std::size_t iIndex = iIndex_channel + ix + iy; - // loop over elements assosicated with one output - for (std::size_t kx = 0; kx < kernelDims[0]; ++kx) { - output[oIndex + oy] += weights[wIndex+kernelDims[0]*ky+kx]*input[iIndex+kx*dilationDims[1]]; + for (std::size_t oy = 0, iy = 0; oy < oySize; + ++oy, iy += strideDims[1]) { + const std::size_t iIndex = + iIndex_channel + ix + iy; + // loop over elements assosicated with one + // output + for (std::size_t kx = 0; kx < kernelDims[0]; + ++kx) { + output[oIndex + oy] += + weights[wIndex + kernelDims[0] * ky + + kx] * + input[iIndex + kx * dilationDims[1]]; } } } @@ -256,21 +640,380 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, } } +/** + * @brief perform backpropagation for the input + * @note INPUT & OUTPUT convention is the same as in the + * forward function + * @note formula : + * for i in 0..input_size: + * for n in 0..weight_size: + * dL dYn dL + * ---- = ---- ---- + * dXi dXi Yn + * with : dYn / dXi = w_k + * for each input value + * for each weight + * for each output + * multiply the weight with the associated value + * @note kernel & stride are passed as single integers as they are just arrays + * of length 1 + * @note reminder that kernel dimensions are + * {outChannels, inChannels, {kernelDims}} + * <=> {oDims[1], iDims[1], kernelDim} + * @tparam I Input data type. + * @tparam W Weight data type. + * @tparam O Output data type. 
+ * @param[in] stride stride parameter of the convolution operator + * @param[in] dilation dilation parameter of the convolution operator + * @param[in] kDims dimension of the kernel + * @param[in] kStrides nb of elements contained per dimension of the kernel + * @param[in] weights weights values + * @param[in] oDims dimensions of the output + * @param[in] oStrides nb of elements contained per dimension of the output + * @param[in] oGrad output gradient + * @param[in] iDims input dimensions + * @param[in] iStrides nb of elements contained per dimension of the input + * @param[inout] iGrad gradients of the input to update + */ +template <class I, class W, class O> +void conv2DBackwardInput(const array<DimSize_t, 2> &stride, + const array<DimSize_t, 2> &dilation, + const array<DimSize_t, 2> &kDims, + const array<DimSize_t, 3> &kStrides, + const W *weights, + const array<DimSize_t, 4> &oDims, + const array<DimSize_t, 3> &oStrides, + const O *oGrad, + const array<DimSize_t, 4> &iDims, + const array<DimSize_t, 3> &iStrides, + I *iGrad) { + // records index offsets for each dimension that have a stride (== all + // dimension except the last) for every parsed tensor + array<DimSize_t, 3> kOffset{}; + array<DimSize_t, 3> iOffset{}; + array<DimSize_t, 3> oOffset{}; + + for (std::size_t batch = 0; batch < iDims[0]; ++batch) { + iOffset[0] = batch * iStrides[0]; + oOffset[0] = batch * oStrides[0]; + + for (DimSize_t oChannel = 0; oChannel < oDims[1]; oChannel++) { + oOffset[1] = (oChannel * oStrides[1]) + oOffset[0]; + kOffset[0] = (oChannel * kStrides[0]); + + for (std::size_t iChannel = 0; iChannel < iDims[1]; ++iChannel) { + iOffset[1] = (iChannel * iStrides[1]) + iOffset[0]; + kOffset[1] = iChannel * kStrides[1] + kOffset[0]; + + for (DimSize_t oX = 0; oX < oDims[2]; ++oX) { + oOffset[2] = (oX * oStrides[2]) + oOffset[1]; + + auto iX = oX * stride[0]; + iOffset[2] = (iX * iStrides[2]) + iOffset[1]; + for (DimSize_t oY = 0; oY < oDims[3]; ++oY) { + auto oIdx = oOffset[2] + oY; + + auto iY = oY * stride[1]; + auto iIdx = iOffset[2] + iY; + + for (DimSize_t kX = 0; kX < kDims[0]; ++kX) { + auto kDilX = kX * dilation[0]; + auto iDilKXOffset = kDilX * iStrides[2]; + + kOffset[2] = (kX * kStrides[2]) + kOffset[1]; + + for (DimSize_t kY = 0; kY < kDims[1]; ++kY) { + auto kDilY = kY * dilation[1]; + + iGrad[iIdx + iDilKXOffset + kDilY] += + weights[kOffset[2] + kY] * oGrad[oIdx]; + } + } + } + } + } + } + } +} + +/** + * @brief computes weight backpropagation for conv2D operation + * @note INPUT & OUTPUT convention is the same as in the + * forward function + * weight grad + * for i in 0..weight_size: + * for n in 0..output_size: + * dL dYn dL + * ---- = ---- ---- + * dwi dwi Yn + * with : dYn / dwi = x_k + * @tparam I input dtype + * @tparam W weight dtype + * @tparam O output dtype + * @param[in] iDims input data dimensions + * @param[in] iBatchStride nb element in each input data batch + * @param[in] iChannelStride nb element in each input data channel + * @param[in] input input data + * @param[in] oDims output data dimmensions + * @param[in] oBatchStride nb element in each output data batch + * @param[in] oChannelStride nb element in each output data channel + * @param[in] oGrad gradients of output data + * @param[in] stride + * @param[in] kernelDims + * @param[inout] weightsGrad gradients of the kernel weights + */ +template <class I, class W, class O> +void conv2DBackwardWeights(const array<DimSize_t, 4> &iDims, + const array<DimSize_t, 3> &iStrides, + const I *input, + const array<DimSize_t, 4> 
&oDims, + const array<DimSize_t, 3> &oStrides, + const O *oGrad, + const array<DimSize_t, 2> &kDim, + const array<DimSize_t, 3> &kStrides, + const array<DimSize_t, 2> &stride, + const array<DimSize_t, 2> &dilation, + W *weightsGrad) { + // records index offsets for each dimension that have a stride (== all + // dimension except the last) for every parsed tensor + array<DimSize_t, 3> iOffsets{0, 0, 0}; + array<DimSize_t, 3> oOffsets{0, 0, 0}; + array<DimSize_t, 3> kOffsets{0, 0, 0}; + + for (DimSize_t batchIdx = 0; batchIdx < oDims[0]; ++batchIdx) { + iOffsets[0] = batchIdx * iStrides[0]; + oOffsets[0] = batchIdx * oStrides[0]; + + for (DimSize_t iChannel = 0; iChannel < iDims[1]; ++iChannel) { + iOffsets[1] = iChannel * iStrides[1] + iOffsets[0]; + kOffsets[0] = iChannel * kStrides[0]; + + for (DimSize_t oChannel = 0; oChannel < oDims[1]; ++oChannel) { + oOffsets[1] = oChannel * oStrides[1] + oOffsets[0]; + kOffsets[1] = oChannel * kStrides[1] + kOffsets[0]; + + for (DimSize_t kX = 0; kX < kDim[0]; ++kX) { + kOffsets[2] = kX * kStrides[2] + kOffsets[1]; + for (DimSize_t kY = 0; kY < kDim[1]; ++kY) { + + for (DimSize_t oX = 0; oX < oDims[2]; ++oX) { + const DimSize_t iX = + oX * stride[0] + kX * dilation[0]; + + oOffsets[2] = oX * oStrides[2] + oOffsets[1]; + iOffsets[2] = iX * iStrides[2] + iOffsets[1]; + + for (DimSize_t oY = 0; oY < oDims[3]; ++oY) { + const DimSize_t iY = + oY * stride[1] + kY * dilation[1]; + + weightsGrad[kOffsets[2] + kY] += + input[iOffsets[2] + iY] * + oGrad[oOffsets[2] + oY]; + } + } + } + } + } + } + } +} + +/** + * @brief computes bias backpropagation for conv2D operation + * @note INPUT & OUTPUT convention is the same as in the + * forward function + * @note formula : + * Bias grad: + * for i in 0..bias_size: + * for n in 0..output_size: + * dL dYn dL + * ---- = ---- ---- + * dbi dbi Yn + * with : dYn / dbi = 1 + * + * Hence the partial derivative of the loss wrt bias is the + * output loss Hence the bias grad is just the sum of the + * loss values over the batch + * @tparam I Input data type. + * @tparam W Weight data type. + * @tparam B Bias data type. + * @tparam O Output data type. + * @param[in] oDims output tensor dimensions + * @param[in] oStrides nb of elements contained per dimension of the + * output + * @param[in] oGrad output tensor gradients + * @param[inout] biasesGrad biases gradients + */ +template <class B, class O> +static void conv2DBackwardBias(const array<DimSize_t, 4> &oDims, + const array<DimSize_t, 3> &oStrides, + const O *oGrad, + B *biasesGrad) { + // records all index offsets for output tensor + array<DimSize_t, 3> oOffsets{}; + for (DimSize_t batchIdx = 0; batchIdx < oDims[0]; ++batchIdx) { + oOffsets[0] = batchIdx * oStrides[0]; + + for (DimSize_t oChannel = 0; oChannel < oDims[1]; ++oChannel) { + oOffsets[1] = oChannel * oStrides[1] + oOffsets[0]; + + for (DimSize_t oX = 0; oX < oDims[2]; ++oX) { + oOffsets[2] = oX * oStrides[2] + oOffsets[1]; + + for (DimSize_t oY = 0; oY < oDims[3]; ++oY) { + biasesGrad[oChannel] += oGrad[oOffsets[2] + oY]; + } + } + } + } +} + +/** + * @brief Backward kernel for 2D Convolution on CPU backend. + * @note INPUT & OUTPUT convention is the same as in the + * forward function + * + * @tparam I Input data type. + * @tparam W Weight data type. + * @tparam B Bias data type. + * @tparam O Output data type. 
+ * @param[in] const stride attribute of conv operator + * @param[in] const dilation attribute of conv operator + * @param[in] const kernelDims + * @param[in] const iDims input data dimensions + * @param[in] const oDims output data dimmensions + * @param[in] const input_ input tensor. + * @param[in] const weights_ kernel tensor. + * @param[in] const oGrad_ output tensor gradient. + * @param[inout] iGrad_ input tensor gradient. + * @param[inout] weightsGrad_ kernel weights tensor gradients + * @param[inout] biasesGrad_ kernel biases tensor gradients + */ +template <class I, class W, class B, class O> +void ConvImpl2D_cpu_backward_kernel(const array<DimSize_t, 2> &stride, + const array<DimSize_t, 2> &dilation, + const array<DimSize_t, 2> &kernelDims, + const array<DimSize_t, 4> &inputDims, + const array<DimSize_t, 4> &outputDims, + const void *input_, + const void *weights_, + const void *oGrad_, + void *iGrad_, + void *weightsGrad_, + void *biasesGrad_) { + + const I *input = static_cast<const I *>(input_); + I *iGrad = static_cast<I *>(iGrad_); + const I *outputGrad = static_cast<const I *>(oGrad_); + const W *weights = static_cast<const W *>(weights_); + W *weightsGrad = static_cast<W *>(weightsGrad_); + + ////////////////////////////// + // COMPUTING STRIDES + ////////////////////////////// + // NOTE: The ...Stride var represent the number of values contained in + // each dimension they will be used to compute the index offset of + // values while iterating on each tensor + // NOTE: They are 1 item shorter than their corresponding tensor as the + // number of total elements is not used except for gradient initialization + + // {batch_stride, channel_stride, dim0_stride, dim1_stride} + const array<DimSize_t, 3> inputStrides{ + inputDims[1] * inputDims[2] * inputDims[3], + inputDims[2] * inputDims[3], + inputDims[3]}; + const DimSize_t nbEltsInput = inputDims[0] * inputStrides[0]; + + // {batch_stride, channel_stride, dim0_stride, dim1_stride} + const array<DimSize_t, 3> outputStrides{ + outputDims[1] * outputDims[2] * outputDims[3], + outputDims[2] * outputDims[3], + outputDims[3]}; + + // NOTE: kernel dims = {iChannel, oChannel, kernelDim0, kernelDim1} + // kernel_strides = {iChannel, oChannel, kernelDim0} + const array<DimSize_t, 3> kernelStrides{ + inputDims[1] * kernelDims[0] * kernelDims[1], + kernelDims[0] * kernelDims[1], + kernelDims[1]}; + + const DimSize_t nbEltsKernel = outputDims[1] * kernelStrides[0]; + + //////////////////////////// + // prepping gradient arrays + std::fill(iGrad, iGrad + nbEltsInput, I(0)); + std::fill(weightsGrad, weightsGrad + nbEltsKernel, W(0)); + + conv2DBackwardInput(stride, + dilation, + kernelDims, + kernelStrides, + weights, + outputDims, + outputStrides, + outputGrad, + inputDims, + inputStrides, + iGrad); + + conv2DBackwardWeights(inputDims, + inputStrides, + input, + outputDims, + outputStrides, + outputGrad, + kernelDims, + kernelStrides, + stride, + dilation, + weightsGrad); + + if (biasesGrad_ != nullptr) { + B *biasesGrad = static_cast<B *>(biasesGrad_); + std::fill(biasesGrad, biasesGrad + outputDims[1], B(0)); + conv2DBackwardBias(outputDims, outputStrides, outputGrad, biasesGrad); + } +} // Kernels registration to implementation entry point REGISTRAR(ConvImpl2D_cpu, - {{DataType::Any, DataFormat::NCHW}, {DataType::Float32, DataFormat::NCHW}}, - {ProdConso::inPlaceModel, Aidge::ConvImpl2D_cpu_forward_kernel<float, float, float, float>, nullptr}); -REGISTRAR(ConvImpl2D_cpu, - {{DataType::Any, DataFormat::NCHW}, {DataType::Float16, 
DataFormat::NCHW}}, - {ProdConso::inPlaceModel, Aidge::ConvImpl2D_cpu_forward_kernel<half_float::half, half_float::half, half_float::half, half_float::half>, nullptr}); -REGISTRAR(ConvImpl2D_cpu, - {{DataType::Any, DataFormat::NCHW}, {DataType::Int32, DataFormat::NCHW}}, - {ProdConso::inPlaceModel, Aidge::ConvImpl2D_cpu_forward_kernel<int32_t, int32_t, int32_t, int32_t>, nullptr}); + {{DataType::Any, DataFormat::NCHW}, + {DataType::Float32, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + Aidge::ConvImpl2D_cpu_forward_kernel<float, float, float, float>, + Aidge::ConvImpl2D_cpu_backward_kernel<float, float, float, float>}); REGISTRAR(ConvImpl2D_cpu, + {{DataType::Any, DataFormat::NCHW}, + {DataType::Float16, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + Aidge::ConvImpl2D_cpu_forward_kernel<half_float::half, + half_float::half, + half_float::half, + half_float::half>, + Aidge::ConvImpl2D_cpu_backward_kernel<half_float::half, + half_float::half, + half_float::half, + half_float::half>}); +REGISTRAR( + ConvImpl2D_cpu, {{DataType::Any, DataFormat::NCHW}, {DataType::Float64, DataFormat::NCHW}}, - {ProdConso::inPlaceModel, Aidge::ConvImpl2D_cpu_forward_kernel<double, double, double, double>, nullptr}); -} // namespace Aidge + {ProdConso::inPlaceModel, + Aidge::ConvImpl2D_cpu_forward_kernel<double, double, double, double>, + Aidge::ConvImpl2D_cpu_backward_kernel<double, double, double, double>}); +REGISTRAR(ConvImpl2D_cpu, + {{DataType::Any, DataFormat::NCHW}, + {DataType::Int32, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvImpl2D_cpu_forward_kernel<std::int32_t, + std::int32_t, + std::int32_t, + std::int32_t>, + ConvImpl2D_cpu_backward_kernel<std::int32_t, + std::int32_t, + std::int32_t, + std::int32_t>}); +} // namespace Aidge #endif /* AIDGE_CPU_OPERATOR_CONVIMPL_KERNELS_H_ */ diff --git a/include/aidge/backend/cpu/operator/ConvTransposeImpl.hpp b/include/aidge/backend/cpu/operator/ConvTransposeImpl.hpp new file mode 100644 index 0000000000000000000000000000000000000000..7604a96a18e7be44f4c2e8970a0b60b1c4ad918b --- /dev/null +++ b/include/aidge/backend/cpu/operator/ConvTransposeImpl.hpp @@ -0,0 +1,59 @@ + +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. 
+ * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_H_ +#define AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_H_ + +#include <array> + +#include "aidge/backend/cpu/operator/OperatorImpl.hpp" +#include "aidge/operator/ConvTranspose.hpp" +#include "aidge/utils/Registrar.hpp" +#include "aidge/utils/Types.h" + +namespace Aidge { + +using std::array; + +// Operator implementation entry point for the backend +using ConvTranspose1D_Op = ConvTranspose_Op<1>; +using ConvTransposeImpl1D_cpu = + OperatorImpl_cpu<ConvTranspose1D_Op, + void(const array<DimSize_t,1> &, + const array<DimSize_t,1> &, + const array<DimSize_t,1> &, + const array<DimSize_t, 3> &, + const array<DimSize_t, 3> &, + const void *, + const void *, + const void *, + void *)>; + +using ConvTranspose2D_Op = ConvTranspose_Op<2>; +using ConvTransposeImpl2D_cpu = + OperatorImpl_cpu<ConvTranspose2D_Op, + void(const array<DimSize_t, 2> &, + const array<DimSize_t, 2> &, + const array<DimSize_t, 2> &, + const array<DimSize_t, 4> &, + const array<DimSize_t, 4> &, + const void *, + const void *, + const void *, + void *)>; + +// Implementation entry point registration to Operator +REGISTRAR(ConvTranspose1D_Op, "cpu", ConvTransposeImpl1D_cpu::create); +REGISTRAR(ConvTranspose2D_Op, "cpu", ConvTransposeImpl2D_cpu::create); +} // namespace Aidge + +#endif /* AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_H_ */ diff --git a/include/aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp new file mode 100644 index 0000000000000000000000000000000000000000..e11dd2625ae1645a8e7c5482b1635b85fb475b06 --- /dev/null +++ b/include/aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp @@ -0,0 +1,305 @@ +/******************************************************************************** + * Copyright (c) 2025 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#ifndef AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_KERNELS_H_ +#define AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_KERNELS_H_ + +#include <array> + +#include "aidge/backend/cpu/operator/ConvTransposeImpl.hpp" +#include "aidge/utils/Registrar.hpp" +#include <aidge/backend/cpu/operator/ConvImpl_kernels.hpp> +#include <aidge/data/Data.hpp> +#include <aidge/data/half.hpp> +#include <aidge/scheduler/ProdConso.hpp> +#include <aidge/utils/Types.h> + +namespace Aidge { + +using std::array; + +//////////////////////////////////////////////////////// +//////////////////////////////////////////////////////// +// 1D +//////////////////////////////////////////////////////// +//////////////////////////////////////////////////////// + +/** + * @brief performs forward bias operation for convtranspose operator + * + * @tparam B Bias data type. + * @tparam O Output data type. 
+ * @param[in] bias bias values + * @param[in] oDims dimensions of the output + * @param[in] oStrides nb of elements contained per dimension of the output + * @param[out] output + */ +template <class B, class O> +static void convTranspose1DForwardBias(const B *biases, + const array<DimSize_t, 3> &oDims, + const array<DimSize_t, 2> &oStrides, + O *output) { + array<DimSize_t, 2> outOffsets{0, 0}; + for (DimSize_t batch = 0; batch < oDims[0]; ++batch) { + outOffsets[0] = batch * oStrides[0]; + for (DimSize_t outCh = 0; outCh < oDims[1]; ++outCh) { + outOffsets[1] = outCh * oStrides[1] + outOffsets[0]; + // If bias = nullptr, set B(0) + B biasVal = (biases != nullptr) ? biases[outCh] : B(0); + std::fill(output + outOffsets[1], + output + (outOffsets[1] + oDims[2]), + biasVal); + } + } +} + +/** + * @brief forward kernel for convtranspose + * @note ConvTranspose forward is simply convolution backward kernel. + * Check convolution functions for more in-depth details on how the + subfunctions are built. + * @tparam I Input data type. + * @tparam W Weight data type. + * @tparam B Bias data type. + * @tparam O Output data type. + * @param[in] stride stride parameter of the convTranspose operator + * @param[in] dilation dilation parameter of the convTranspose operator + * @param[in] inputDims input dimensions + * @param[in] outputDims output tensor dimensions + * @param[in] oStrides nb of elements contained per dimension of the output + * @param[in] input_ values + * @param[in] weight_ values + * @param[in] biases_ values + * @param[out] output + */ +template <class I, class W, class B, class O> +void ConvTransposeImpl1D_cpu_forward_kernel( + const array<DimSize_t, 1> &stride, + const array<DimSize_t, 1> &dilation, + const array<DimSize_t, 1> &kernelDim, + const array<DimSize_t, 3> &inputDims, + const array<DimSize_t, 3> &outputDims, + const void *input_, + const void *weights_, + const void *biases_, + void *output_) { + + const I *input = static_cast<const I *>(input_); + const W *weights = static_cast<const W *>(weights_); + O *output = static_cast<O *>(output_); + + // {batch_stride, channel_stride, dim0_stride} + const array<DimSize_t, 2> inputStrides{inputDims[1] * inputDims[2], + inputDims[2]}; + + // {batch_stride, channel_stride, dim0_stride} + const array<DimSize_t, 2> outputStrides{outputDims[1] * outputDims[2], + outputDims[2]}; + + // NOTE: kernel dims = {inChannels, outChannels, kernelDims[0]} + const array<DimSize_t, 2> kernelStrides{ + outputDims[1] * kernelDim[0], + kernelDim[0], + }; + + if (biases_ != nullptr) { + const B *biases = static_cast<const B *>(biases_); + convTranspose1DForwardBias(biases, outputDims, outputStrides, output); + } + + conv1DBackwardInput(stride, + dilation, + kernelDim, + kernelStrides, + weights, + inputDims, + inputStrides, + input, + outputDims, + outputStrides, + output); +} + +REGISTRAR(ConvTransposeImpl1D_cpu, + {{DataType::Any, DataFormat::NCHW}, + {DataType::Int32, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvTransposeImpl1D_cpu_forward_kernel<std::int32_t, + std::int32_t, + std::int32_t, + std::int32_t>, + nullptr}); +REGISTRAR(ConvTransposeImpl1D_cpu, + {{DataType::Any, DataFormat::NCHW}, + {DataType::Float32, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvTransposeImpl1D_cpu_forward_kernel<float, float, float, float>, + nullptr}); +REGISTRAR(ConvTransposeImpl1D_cpu, + {{DataType::Any, DataFormat::NCHW}, + {DataType::Float16, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + 
ConvTransposeImpl1D_cpu_forward_kernel<half_float::half, + half_float::half, + half_float::half, + half_float::half>, + nullptr}); +REGISTRAR( + ConvTransposeImpl1D_cpu, + {{DataType::Any, DataFormat::NCHW}, {DataType::Float64, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvTransposeImpl1D_cpu_forward_kernel<double, double, double, double>, + nullptr}); + +//////////////////////////////////////////////////////// +//////////////////////////////////////////////////////// +// 2D +//////////////////////////////////////////////////////// +//////////////////////////////////////////////////////// + +/** + * @brief performs forward bias operation for convtranspose operator + * + * @tparam B Bias data type. + * @tparam O Output data type. + * @param[in] bias bias values + * @param[in] oDims dimensions of the output + * @param[in] oStrides nb of elements contained per dimension of the output + * @param[out] output + */ +template <class B, class O> +static void convTranspose2DForwardBias(const B *biases, + const array<DimSize_t, 4> &oDims, + const array<DimSize_t, 3> &oStrides, + O *output) { + array<DimSize_t, 2> outOffsets{0, 0}; + + for (DimSize_t batch = 0; batch < oDims[0]; ++batch) { + outOffsets[0] = batch * oStrides[0]; + + for (DimSize_t outCh = 0; outCh < oDims[1]; ++outCh) { + outOffsets[1] = outCh * oStrides[1] + outOffsets[0]; + // If bias = nullptr, set B(0) + B biasVal = (biases != nullptr) ? biases[outCh] : B(0); + std::fill(output + outOffsets[1], + (output + outOffsets[1]) + oStrides[1], + biasVal); + } + } +} + +/** + * @brief forward kernel for convtranspose + * @note ConvTranspose forward is simply convolution backward kernel. + * Check convolution functions for more in-depth details on how the + subfunctions are built. + * @tparam I Input data type. + * @tparam W Weight data type. + * @tparam B Bias data type. + * @tparam O Output data type. 
+ * @param[in] stride stride parameter of the convTranspose operator + * @param[in] dilation dilation parameter of the convTranspose operator + * @param[in] inputDims input dimensions + * @param[in] outputDims output tensor dimensions + * @param[in] oStrides nb of elements contained per dimension of the output + * @param[in] input_ values + * @param[in] weight_ values + * @param[in] biases_ values + * @param[out] output + */ +template <class I, class W, class B, class O> +void ConvTransposeImpl2D_cpu_forward_kernel( + const array<DimSize_t, 2> &stride, + const array<DimSize_t, 2> &dilation, + const array<DimSize_t, 2> &kernelDims, + const array<DimSize_t, 4> &inputDims, + const array<DimSize_t, 4> &outputDims, + const void *input_, + const void *weights_, + const void *biases_, + void *output_) { + + auto input = static_cast<const I *>(input_); + auto weights = static_cast<const W *>(weights_); + auto output = static_cast<O *>(output_); + + // {channel_stride, dim0_stride, dim1_stride} + const array<DimSize_t, 3> inputStrides{ + inputDims[1] * inputDims[2] * inputDims[3], + inputDims[2] * inputDims[3], + inputDims[3]}; + + // {channel_stride, dim0_stride, dim1_stride} + const array<DimSize_t, 3> outputStrides{ + outputDims[1] * outputDims[2] * outputDims[3], + outputDims[2] * outputDims[3], + outputDims[3]}; + + // NOTE: kernel dims = {inChannels, outChannels, kernelDims[0], + // kernelDims[1]} + const array<DimSize_t, 3> kernelStrides{ + outputDims[1] * kernelDims[0] * kernelDims[1], + kernelDims[0] * kernelDims[1], + kernelDims[1], + }; + + if (biases_ != nullptr) { + auto biases = static_cast<const B *>(biases_); + convTranspose2DForwardBias(biases, outputDims, outputStrides, output); + } + + conv2DBackwardInput(stride, + dilation, + kernelDims, + kernelStrides, + weights, + inputDims, + inputStrides, + input, + outputDims, + outputStrides, + output); +} + +REGISTRAR(ConvTransposeImpl2D_cpu, + {{DataType::Any, DataFormat::NCHW}, + {DataType::Int32, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvTransposeImpl2D_cpu_forward_kernel<std::int32_t, + std::int32_t, + std::int32_t, + std::int32_t>, + nullptr}); +REGISTRAR(ConvTransposeImpl2D_cpu, + {{DataType::Any, DataFormat::NCHW}, + {DataType::Float16, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvTransposeImpl2D_cpu_forward_kernel<half_float::half, + half_float::half, + half_float::half, + half_float::half>, + nullptr}); +REGISTRAR(ConvTransposeImpl2D_cpu, + {{DataType::Any, DataFormat::NCHW}, + {DataType::Float32, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvTransposeImpl2D_cpu_forward_kernel<float, float, float, float>, + nullptr}); +REGISTRAR( + ConvTransposeImpl2D_cpu, + {{DataType::Any, DataFormat::NCHW}, {DataType::Float64, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvTransposeImpl2D_cpu_forward_kernel<double, double, double, double>, + nullptr}); + +} // namespace Aidge + +#endif /* AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_KERNELS_H_ */ diff --git a/src/operator/ConvImpl.cpp b/src/operator/ConvImpl.cpp index fdfe19fbf4bf3e71c86aa28b966cfb21a1b5ba40..d23a9968ffb424b4639e0fcd2629a3a1cc2e11c3 100644 --- a/src/operator/ConvImpl.cpp +++ b/src/operator/ConvImpl.cpp @@ -13,14 +13,11 @@ #include "aidge/backend/cpu/operator/ConvImpl_kernels.hpp" #include <cassert> -#include <chrono> // std::chrono::milliseconds -#include <numeric> // std::accumulate -#include <thread> // std::this_thread::sleep_for -#include <vector> #include "aidge/backend/cpu/data/GetCPUPtr.h" #include "aidge/operator/Conv.hpp" -#include 
"aidge/utils/Types.h" + +namespace Aidge { template <> void Aidge::ConvImpl1D_cpu::forward() { @@ -43,21 +40,60 @@ void Aidge::ConvImpl1D_cpu::forward() { const auto& input2 = (op_.getInput(2)) ? op_.getInput(2)->refCastFrom(input2Fallback, *op_.getOutput(0)) : Tensor(); // Call kernel - impl.forward(op_.strideDims(), - op_.dilationDims(), - op_.kernelDims(), - op_.getInput(0)->template dims<3>(), // input dimensions - dynamic_cast<const Conv_Op<1>&>(mOp).outChannels(), // outChannels - input0.getImpl()->rawPtr(), // input - input1.getImpl()->rawPtr(), // weight - op_.getInput(2) ? input2.getImpl()->rawPtr() : nullptr, // bias - getCPUPtr(mOp.getRawOutput(0)) // output - ); + impl.forward( + op_.strideDims(), + op_.dilationDims(), + op_.kernelDims(), + op_.getInput(0)->template dims<3>(), // input dimensions + dynamic_cast<const Conv_Op<1> &>(mOp).outChannels(), // outChannels + input0.getImpl()->rawPtr(), // input + input1.getImpl()->rawPtr(), // weight + op_.getInput(2) ? input2.getImpl()->rawPtr() : nullptr, // bias + getCPUPtr(mOp.getRawOutput(0)) // output + ); } -template <> -void Aidge::ConvImpl1D_cpu::backward() { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Backward not yet implemented for Conv_Op<1> on backend cpu"); +template <> void ConvImpl1D_cpu::backward() { + const auto &op = dynamic_cast<const Conv1D_Op &>(mOp); + const auto &outputGrad = op.getOutput(0)->grad(); + AIDGE_ASSERT(outputGrad, "{}: missing ouput #0 gradient", op.type()); + AIDGE_ASSERT(op.getInput(0)->grad(), + "{}: missing data input(#0) gradient", + op.type()); + AIDGE_ASSERT(op.getInput(1)->grad(), + "{}: missing weight input(#1) gradient", + op.type()); + + std::shared_ptr<Tensor> inputDataGradFallback, inputWeightGradFallback, + inputBiasGradFallback; + const auto &inputDataGrad = + op.getInput(0)->grad()->refCastFrom(inputDataGradFallback, + *(op.getOutput(0))); + const auto &inputWeightGrad = + op.getInput(1)->grad()->refCastFrom(inputWeightGradFallback, + *(op.getOutput(0))); + const auto &inputBiasGrad = + (op.getInput(2) && op.getInput(2)->grad()) + ? op.getInput(2)->grad()->refCastFrom(inputBiasGradFallback, + *(op.getOutput(0))) + : Tensor(); + + // Call kernel + const auto impl = + Registrar<ConvImpl1D_cpu>::create(getBestMatch(getRequiredSpec())); + impl.backward( + op.strideDims(), + op.dilationDims(), + op.kernelDims(), + op.getInput(0)->template dims<3>(), + op.getOutput(0)->template dims<3>(), + + getCPUPtr(op.getInput(0)), + getCPUPtr(op.getInput(1)), + getCPUPtr(outputGrad), + inputDataGrad.getImpl()->rawPtr(), + inputWeightGrad.getImpl()->rawPtr(), + op.getInput(2) ? 
inputBiasGrad.getImpl()->rawPtr() : nullptr);
 }
 
 template <>
@@ -93,7 +129,48 @@ void Aidge::ConvImpl2D_cpu::forward() {
     );
 }
 
-template <>
-void Aidge::ConvImpl2D_cpu::backward() {
-    AIDGE_THROW_OR_ABORT(std::runtime_error, "Backward not yet implemented for Conv_Op<2> on backend cpu");
+
+template <> void ConvImpl2D_cpu::backward() {
+    const auto &op = dynamic_cast<const Conv2D_Op &>(mOp);
+    const auto &outputGrad = op.getOutput(0)->grad();
+    AIDGE_ASSERT(outputGrad, "{}: missing output #0 gradient", op.type());
+    AIDGE_ASSERT(op.getInput(0)->grad(),
+                 "{}: missing data input(#0) gradient",
+                 op.type());
+    AIDGE_ASSERT(op.getInput(1)->grad(),
+                 "{}: missing weight input(#1) gradient",
+                 op.type());
+
+    std::shared_ptr<Tensor> inputDataGradFallback, inputWeightGradFallback,
+        inputBiasGradFallback;
+    const auto &inputDataGrad =
+        op.getInput(0)->grad()->refCastFrom(inputDataGradFallback,
+                                            *(op.getOutput(0)));
+    const auto &inputWeightGrad =
+        op.getInput(1)->grad()->refCastFrom(inputWeightGradFallback,
+                                            *(op.getOutput(0)));
+    const auto &inputBiasGrad =
+        (op.getInput(2) && op.getInput(2)->grad())
+            ? op.getInput(2)->grad()->refCastFrom(inputBiasGradFallback,
+                                                  *(op.getOutput(0)))
+            : Tensor();
+
+    // Call kernel
+    const auto impl =
+        Registrar<ConvImpl2D_cpu>::create(getBestMatch(getRequiredSpec()));
+    impl.backward(
+        op.strideDims(),
+        op.dilationDims(),
+        op.kernelDims(),
+        op.getInput(0)->template dims<4>(),
+        op.getOutput(0)->template dims<4>(),
+
+        getCPUPtr(op.getInput(0)),
+        getCPUPtr(op.getInput(1)),
+        getCPUPtr(outputGrad),
+        inputDataGrad.getImpl()->rawPtr(),
+        inputWeightGrad.getImpl()->rawPtr(),
+        op.getInput(2) ? inputBiasGrad.getImpl()->rawPtr() : nullptr);
 }
+
+} // namespace Aidge
diff --git a/src/operator/ConvTransposeImpl.cpp b/src/operator/ConvTransposeImpl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..d1135cc92dd3c68746b9dcf80739f4f65acdad2e
--- /dev/null
+++ b/src/operator/ConvTransposeImpl.cpp
@@ -0,0 +1,91 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include "aidge/backend/cpu/operator/ConvTransposeImpl.hpp"
+#include "aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp"
+
+template <> void Aidge::ConvTransposeImpl1D_cpu::forward() {
+    const auto &op = static_cast<const ConvTranspose_Op<1> &>(mOp);
+
+    AIDGE_ASSERT(op.getInput(0), "{}: missing data input (#0).", op.type());
+    AIDGE_ASSERT(op.getInput(1), "{}: missing weight input (#1).", op.type());
+    AIDGE_ASSERT(op.getInput(2), "{}: missing bias input (#2).", op.type());
+
+    std::shared_ptr<Tensor> inputDataFallback, inputWeightFallback,
+        inputBiasFallback;
+    const auto &inputData =
+        op.getInput(0)->refCastFrom(inputDataFallback, *op.getOutput(0));
+    const auto &inputWeight =
+        op.getInput(1)->refCastFrom(inputWeightFallback, *op.getOutput(0));
+    const auto &inputBias =
+        (op.getInput(2))
+            ? op.getInput(2)->refCastFrom(inputBiasFallback, *op.getOutput(0))
+            : Tensor();
+
+    // Call kernel
+    const auto impl = Registrar<ConvTransposeImpl1D_cpu>::create(
+        getBestMatch(getRequiredSpec()));
+    impl.forward(op.strideDims(),
+                 op.dilationDims(),
+                 op.kernelDims(),
+                 op.getInput(0)->template dims<3>(),
+                 op.getOutput(0)->template dims<3>(),
+                 inputData.getImpl()->hostPtr(),
+                 inputWeight.getImpl()->hostPtr(),
+                 op.getInput(2) ? inputBias.getImpl()->hostPtr() : nullptr,
+                 op.getOutput(0)->getImpl()->rawPtr());
+}
+
+template <> void Aidge::ConvTransposeImpl1D_cpu::backward() {
+    AIDGE_THROW_OR_ABORT(
+        std::runtime_error,
+        "Backward not yet implemented for ConvTranspose_Op<1> on backend cpu");
+}
+
+template <> void Aidge::ConvTransposeImpl2D_cpu::forward() {
+    const auto &op = static_cast<const ConvTranspose_Op<2> &>(mOp);
+
+    AIDGE_ASSERT(op.getInput(0), "{}: missing data input (#0).", op.type());
+    AIDGE_ASSERT(op.getInput(1), "{}: missing weight input (#1).", op.type());
+    AIDGE_ASSERT(op.getInput(2), "{}: missing bias input (#2).", op.type());
+
+    std::shared_ptr<Tensor> inputDataFallback, inputWeightFallback,
+        inputBiasFallback;
+    const auto &inputData =
+        op.getInput(0)->refCastFrom(inputDataFallback, *op.getOutput(0));
+    const auto &inputWeight =
+        op.getInput(1)->refCastFrom(inputWeightFallback, *op.getOutput(0));
+    const auto &inputBias =
+        (op.getInput(2))
+            ? op.getInput(2)->refCastFrom(inputBiasFallback, *op.getOutput(0))
+            : Tensor();
+
+    // Call kernel
+    const auto impl = Registrar<ConvTransposeImpl2D_cpu>::create(
+        getBestMatch(getRequiredSpec()));
+
+    impl.forward(op.strideDims(),
+                 op.dilationDims(),
+                 op.kernelDims(),
+                 op.getInput(0)->template dims<4>(),
+                 op.getOutput(0)->template dims<4>(),
+                 inputData.getImpl()->hostPtr(),
+                 inputWeight.getImpl()->hostPtr(),
+                 op.getInput(2) ?
inputBias.getImpl()->hostPtr() : nullptr, + op.getOutput(0)->getImpl()->rawPtr()); +} + +template <> void Aidge::ConvTransposeImpl2D_cpu::backward() { + AIDGE_THROW_OR_ABORT( + std::runtime_error, + "Backward not yet implemented for Conv_Op<2> on backend cpu"); +} + diff --git a/unit_tests/operator/Test_ClipImpl.cpp b/unit_tests/operator/Test_ClipImpl.cpp index 99147ac93bd659dd91897f6b7f1f3f33e5552ef6..3d75ad78807d0e4d23ec231f5df485e8574a03ee 100644 --- a/unit_tests/operator/Test_ClipImpl.cpp +++ b/unit_tests/operator/Test_ClipImpl.cpp @@ -315,5 +315,5 @@ TEST_CASE("[cpu/operator] Clip", "[Clip][CPU]") Log::info("total time: {}\n", duration.count()); } } -} // namespace Aidge -} \ No newline at end of file +} +} // namespace Aidge diff --git a/unit_tests/operator/Test_ConvImpl.cpp b/unit_tests/operator/Test_ConvImpl.cpp index f7be338c0b9c5bb1d5af6bfa09ed7855c17fb6c0..59ec16dd80ee98c09c79d5943c503e945abf5cdb 100644 --- a/unit_tests/operator/Test_ConvImpl.cpp +++ b/unit_tests/operator/Test_ConvImpl.cpp @@ -17,6 +17,7 @@ #include "aidge/backend/cpu/operator/ConvImpl.hpp" #include "aidge/data/Data.hpp" // DataType #include "aidge/data/Tensor.hpp" +#include "aidge/filler/Filler.hpp" #include "aidge/graph/Node.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/utils/TensorUtils.hpp" @@ -1645,4 +1646,1000 @@ TEST_CASE("[cpu/operator] Conv(forward)", "[Conv][CPU]") { REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-6f)); } } -} \ No newline at end of file +} + +template <DimSize_t DIM> +std::shared_ptr<OperatorTensor> +setupTestConv(const DimSize_t batchSize, + const DimSize_t inChannels, + const DimSize_t outChannels, + const std::array<DimSize_t, DIM> kernelSize, + const std::array<DimSize_t, DIM> dataSize, + const std::array<DimSize_t, DIM> stride, + const std::array<DimSize_t, DIM> dilation, + const std::array<DimSize_t, 2 * DIM> padding, + const std::shared_ptr<Tensor> input, + const std::shared_ptr<Tensor> weights, + const std::shared_ptr<Tensor> biases) { + input->setBackend("cpu"); + weights->setBackend("cpu"); + biases->setBackend("cpu"); + std::shared_ptr<Node> convNode; + convNode = Conv(inChannels, + outChannels, + kernelSize, + "myconv", + std::array<DimSize_t, DIM>({stride}), + dilation); + auto op = + std::static_pointer_cast<OperatorTensor>(convNode->getOperator()); + + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + + op->associateInput(0, input); + op->associateInput(1, weights); + op->associateInput(2, biases); + + REQUIRE_NOTHROW(op->forwardDims(true)); + + return op; +} + +TEST_CASE("[cpu/operator] Conv(backward)", "[Conv][CPU]") { + SECTION("1D") { + const std::size_t DIM = 1; + SECTION("no stride & no dilation, outChannels > inChannels") { + + const DimSize_t batchSize = 1; + const DimSize_t inChannels = 2; + const DimSize_t outChannels = 3; + const DimSize_t kernelSize = 4; + const DimSize_t inDataSize = 12; + + const DimSize_t stride = 1; + const DimSize_t dilation = 1; + const std::array<DimSize_t, 2 * DIM> padding({0, 0}); + + auto inputSize = + std::vector<DimSize_t>({batchSize, inChannels, inDataSize}); + + auto input = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize>( + {{{{1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000}, + {1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000, + 1.000000}}}})); + + auto weights = 
std::make_shared<Tensor>( + Array3D<float, outChannels, inChannels, kernelSize>( + {{{{0.100000, 0.100000, 0.100000, 0.100000}, + {0.100000, 0.100000, 0.100000, 0.100000}}, + {{0.100000, 0.100000, 0.100000, 0.100000}, + {0.100000, 0.100000, 0.100000, 0.100000}}, + {{0.100000, 0.100000, 0.100000, 0.100000}, + {0.100000, 0.100000, 0.100000, 0.100000}}} + + })); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({0.010000, 0.010000, 0.010000})); + + auto op = setupTestConv<DIM>( + batchSize, + inChannels, + outChannels, + std::array<DimSize_t, DIM>({kernelSize}), + std::array<DimSize_t, DIM>({inDataSize}), + std::array<DimSize_t, DIM>({stride}), + std::array<DimSize_t, DIM>({dilation}), + padding, + input, + weights, + biases); + + //////////////////////////////////// + // setup gradients for backward + auto outputGrad = + std::make_shared<Tensor>(op->getOutput(0)->dims()); + outputGrad->setDataType(DataType::Float32); + outputGrad->setBackend("cpu"); + constantFiller(outputGrad, 1.f); + op->getOutput(0)->setGrad(outputGrad); + + //////////////////////////////////// + // setup gradients for backward + REQUIRE_NOTHROW(op->backward()); + + SECTION("Input Grad") { + auto expectedInputGrad = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize>( + {{{{0.3000, + 0.6000, + 0.9000, + 1.2000, + 1.2000, + 1.2000, + 1.2000, + 1.2000, + 1.2000, + 0.9000, + 0.6000, + 0.3000}, + {0.3000, + 0.6000, + 0.9000, + 1.2000, + 1.2000, + 1.2000, + 1.2000, + 1.2000, + 1.2000, + 0.9000, + 0.6000, + 0.3000}}}})); + CHECK(approxEq<float, float>(*op->getInput(0)->grad(), + *expectedInputGrad)); + } + SECTION("Weight grad") { + std::vector<DimSize_t> weightsSize( + {outChannels, inChannels, kernelSize}); + auto expectedWeightsGrad = + std::make_shared<Tensor>(weightsSize); + expectedWeightsGrad->setBackend("cpu"); + expectedWeightsGrad->setDataType(DataType::Float32); + constantFiller<float>(expectedWeightsGrad, 9.); + + CHECK(approxEq<float, float>(*op->getInput(1)->grad(), + *expectedWeightsGrad)); + } + SECTION("Bias Grad") { + std::vector<DimSize_t> biasesSize({outChannels}); + auto expectedBiasGrad = std::make_shared<Tensor>(biasesSize); + expectedBiasGrad->setBackend("cpu"); + expectedBiasGrad->setDataType(DataType::Float32); + constantFiller<float>(expectedBiasGrad, 9.); + CHECK(approxEq<float, float>(*op->getInput(2)->grad(), + *expectedBiasGrad)); + } + } + + SECTION("stride and no dilation, inChannel > outChannels") { + const DimSize_t batchSize = 2; + const DimSize_t inChannels = 3; + const DimSize_t outChannels = 1; + const DimSize_t kernelSize = 2; + const DimSize_t inDataSize = 8; + const DimSize_t stride = 3; + const DimSize_t dilation = 1; + const std::array<DimSize_t, 2 * DIM> padding({0, 0}); + + auto inputSize = + std::vector<DimSize_t>({batchSize, inChannels, inDataSize}); + + auto input = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize>( + {{{{1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1.}}, + + {{1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1.}}}})); + auto weights = std::make_shared<Tensor>( + Array3D<float, outChannels, inChannels, kernelSize>( + {{{{0.1000, 0.1000}, + {0.1000, 0.1000}, + {0.1000, 0.1000}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({0.060000})); + + auto op = setupTestConv<DIM>( + batchSize, + inChannels, + outChannels, + std::array<DimSize_t, 
DIM>({kernelSize}), + std::array<DimSize_t, DIM>({inDataSize}), + std::array<DimSize_t, DIM>({stride}), + std::array<DimSize_t, DIM>({dilation}), + padding, + input, + weights, + biases); + + //////////////////////////////////// + // setup gradients for backward + auto outputGrad = + std::make_shared<Tensor>(op->getOutput(0)->dims()); + outputGrad->setDataType(DataType::Float32); + outputGrad->setBackend("cpu"); + constantFiller(outputGrad, 1.f); + op->getOutput(0)->setGrad(outputGrad); + + //////////////////////////////////// + // setup gradients for backward + REQUIRE_NOTHROW(op->backward()); + + SECTION("Input Grad") { + auto expectedInputGrad = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize>( + {{{{0.1000, + 0.1000, + 0.0000, + 0.1000, + 0.1000, + 0.0000, + 0.1000, + 0.1000}, + {0.1000, + 0.1000, + 0.0000, + 0.1000, + 0.1000, + 0.0000, + 0.1000, + 0.1000}, + {0.1000, + 0.1000, + 0.0000, + 0.1000, + 0.1000, + 0.0000, + 0.1000, + 0.1000}}, + + {{0.1000, + 0.1000, + 0.0000, + 0.1000, + 0.1000, + 0.0000, + 0.1000, + 0.1000}, + {0.1000, + 0.1000, + 0.0000, + 0.1000, + 0.1000, + 0.0000, + 0.1000, + 0.1000}, + {0.1000, + 0.1000, + 0.0000, + 0.1000, + 0.1000, + 0.0000, + 0.1000, + 0.1000}}}})); + CHECK(approxEq<float, float>(*op->getInput(0)->grad(), + *expectedInputGrad)); + } + SECTION("Weight grad") { + auto expectedWeightsGrad = std::make_shared<Tensor>( + Array3D<float, outChannels, inChannels, kernelSize>( + {{{{6., 6.}, {6., 6.}, {6., 6.}}}})); + CHECK(approxEq<float, float>(*op->getInput(1)->grad(), + *expectedWeightsGrad)); + } + SECTION("Bias Grad") { + auto expectedBiasesGrad = std::make_shared<Tensor>( + Array1D<float, outChannels>({6.})); + CHECK(approxEq<float, float>(*op->getInput(2)->grad(), + *expectedBiasesGrad)); + } + } + + SECTION("dilation, no stride") { + const DimSize_t batchSize = 2; + const DimSize_t inChannels = 3; + const DimSize_t outChannels = 1; + const DimSize_t kernelSize = 2; + const DimSize_t inDataSize = 8; + + const DimSize_t stride = 1; + const DimSize_t dilation = 2; + const std::array<DimSize_t, 2 * DIM> padding({0, 0}); + + auto inputSize = + std::vector<DimSize_t>({batchSize, inChannels, inDataSize}); + + auto input = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize>( + {{{{1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1.}}, + + {{1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1.}}}})); + auto weights = std::make_shared<Tensor>( + Array3D<float, outChannels, inChannels, kernelSize>( + {{{{0.1000, 0.1000}, + {0.1000, 0.1000}, + {0.1000, 0.1000}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({0.060000})); + + auto op = setupTestConv<DIM>( + batchSize, + inChannels, + outChannels, + std::array<DimSize_t, DIM>({kernelSize}), + std::array<DimSize_t, DIM>({inDataSize}), + std::array<DimSize_t, DIM>({stride}), + std::array<DimSize_t, DIM>({dilation}), + padding, + input, + weights, + biases); + + //////////////////////////////////// + // setup gradients for backward + auto outputGrad = + std::make_shared<Tensor>(op->getOutput(0)->dims()); + outputGrad->setDataType(DataType::Float32); + outputGrad->setBackend("cpu"); + constantFiller(outputGrad, 1.f); + op->getOutput(0)->setGrad(outputGrad); + + //////////////////////////////////// + // setup gradients for backward + REQUIRE_NOTHROW(op->backward()); + + SECTION("Input Grad") { + auto expectedInputGrad = 
std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize>( + {{{{0.1000, + 0.1000, + 0.2000, + 0.2000, + 0.2000, + 0.2000, + 0.1000, + 0.1000}, + {0.1000, + 0.1000, + 0.2000, + 0.2000, + 0.2000, + 0.2000, + 0.1000, + 0.1000}, + {0.1000, + 0.1000, + 0.2000, + 0.2000, + 0.2000, + 0.2000, + 0.1000, + 0.1000}}, + + {{0.1000, + 0.1000, + 0.2000, + 0.2000, + 0.2000, + 0.2000, + 0.1000, + 0.1000}, + {0.1000, + 0.1000, + 0.2000, + 0.2000, + 0.2000, + 0.2000, + 0.1000, + 0.1000}, + {0.1000, + 0.1000, + 0.2000, + 0.2000, + 0.2000, + 0.2000, + 0.1000, + 0.1000}}}})); + CHECK(approxEq<float, float>(*op->getInput(0)->grad(), + *expectedInputGrad)); + } + SECTION("Weight grad") { + auto expectedWeightsGrad = std::make_shared<Tensor>( + Array3D<float, outChannels, inChannels, kernelSize>( + {{{{12., 12.}, {12., 12.}, {12., 12.}}}})); + CHECK(approxEq<float, float>(*op->getInput(1)->grad(), + *expectedWeightsGrad)); + } + SECTION("Bias Grad") { + auto expectedBiasesGrad = std::make_shared<Tensor>( + Array1D<float, outChannels>({12.})); + CHECK(approxEq<float, float>(*op->getInput(2)->grad(), + *expectedBiasesGrad)); + } + } + SECTION("stride & dilation") { + const DimSize_t batchSize = 1; + const DimSize_t inChannels = 4; + const DimSize_t outChannels = 4; + const DimSize_t kernelSize = 3; + const DimSize_t inDataSize = 13; + + const DimSize_t stride = 4; + const DimSize_t dilation = 3; + const std::array<DimSize_t, 2 * DIM> padding({0, 0}); + + auto inputSize = + std::vector<DimSize_t>({batchSize, inChannels, inDataSize}); + + auto input = std::make_shared< + Tensor>(Array3D<float, batchSize, inChannels, inDataSize>( + {{{{1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}, + {1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.}}}})); + auto weights = std::make_shared<Tensor>( + Array3D<float, outChannels, inChannels, kernelSize>( + {{{{0.1000, 0.1000, 0.1000}, + {0.1000, 0.1000, 0.1000}, + {0.1000, 0.1000, 0.1000}, + {0.1000, 0.1000, 0.1000}}, + + {{0.1000, 0.1000, 0.1000}, + {0.1000, 0.1000, 0.1000}, + {0.1000, 0.1000, 0.1000}, + {0.1000, 0.1000, 0.1000}}, + + {{0.1000, 0.1000, 0.1000}, + {0.1000, 0.1000, 0.1000}, + {0.1000, 0.1000, 0.1000}, + {0.1000, 0.1000, 0.1000}}, + + {{0.1000, 0.1000, 0.1000}, + {0.1000, 0.1000, 0.1000}, + {0.1000, 0.1000, 0.1000}, + {0.1000, 0.1000, 0.1000}}}})); + + auto biases = std::make_shared<Tensor>(Array1D<float, outChannels>( + {{0.0100, 0.0100, 0.0100, 0.0100}})); + + auto op = setupTestConv<DIM>( + batchSize, + inChannels, + outChannels, + std::array<DimSize_t, DIM>({kernelSize}), + std::array<DimSize_t, DIM>({inDataSize}), + std::array<DimSize_t, DIM>({stride}), + std::array<DimSize_t, DIM>({dilation}), + padding, + input, + weights, + biases); + + //////////////////////////////////// + // setup gradients for backward + auto outputGrad = + std::make_shared<Tensor>(op->getOutput(0)->dims()); + outputGrad->setDataType(DataType::Float32); + outputGrad->setBackend("cpu"); + constantFiller(outputGrad, 1.f); + op->getOutput(0)->setGrad(outputGrad); + + //////////////////////////////////// + // setup gradients for backward + REQUIRE_NOTHROW(op->backward()); + + SECTION("Input Grad") { + auto expectedInputGrad = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize>( + {{{{0.4000, + 0.0000, + 0.0000, + 0.4000, + 0.4000, + 0.0000, + 0.4000, + 0.4000, + 0.0000, + 0.0000, + 0.4000, + 0.0000, + 0.0000}, + {0.4000, + 0.0000, + 
0.0000, + 0.4000, + 0.4000, + 0.0000, + 0.4000, + 0.4000, + 0.0000, + 0.0000, + 0.4000, + 0.0000, + 0.0000}, + {0.4000, + 0.0000, + 0.0000, + 0.4000, + 0.4000, + 0.0000, + 0.4000, + 0.4000, + 0.0000, + 0.0000, + 0.4000, + 0.0000, + 0.0000}, + {0.4000, + 0.0000, + 0.0000, + 0.4000, + 0.4000, + 0.0000, + 0.4000, + 0.4000, + 0.0000, + 0.0000, + 0.4000, + 0.0000, + 0.0000}}}})); + CHECK(approxEq<float, float>(*op->getInput(0)->grad(), + *expectedInputGrad)); + } + SECTION("Weight grad") { + auto expectedWeightsGrad = std::make_shared<Tensor>( + Array3D<float, outChannels, inChannels, kernelSize>( + {{{{2., 2., 2.}, + {2., 2., 2.}, + {2., 2., 2.}, + {2., 2., 2.}}, + + {{2., 2., 2.}, + {2., 2., 2.}, + {2., 2., 2.}, + {2., 2., 2.}}, + + {{2., 2., 2.}, + {2., 2., 2.}, + {2., 2., 2.}, + {2., 2., 2.}}, + + {{2., 2., 2.}, + {2., 2., 2.}, + {2., 2., 2.}, + {2., 2., 2.}}}})); + CHECK(approxEq<float, float>(*op->getInput(1)->grad(), + *expectedWeightsGrad)); + } + SECTION("Bias Grad") { + auto expectedBiasesGrad = std::make_shared<Tensor>( + Array1D<float, outChannels>({{2., 2., 2., 2.}})); + CHECK(approxEq<float, float>(*op->getInput(2)->grad(), + *expectedBiasesGrad)); + } + } + + // Harder to read, look at previous tests in case of issue + SECTION("Sequential values") { + const DimSize_t batchSize = 1; + const DimSize_t inChannels = 2; + const DimSize_t outChannels = 2; + const DimSize_t kernelSize = 3; + const DimSize_t inDataSize = 8; + + const DimSize_t stride = 2; + const DimSize_t dilation = 2; + const std::array<DimSize_t, 2 * DIM> padding({0, 0}); + + const DimSize_t outDataSize = 2; + + auto inputSize = + std::vector<DimSize_t>({batchSize, inChannels, inDataSize}); + + auto input = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize>( + {{{{1., 2., 3., 4., 5., 6., 7., 8.}, + {9., 10., 11., 12., 13., 14., 15., 16.}}}})); + auto weights = std::make_shared<Tensor>( + Array3D<float, outChannels, inChannels, kernelSize>( + {{{{0.1000, 0.2000, 0.3000}, {0.4000, 0.5000, 0.6000}}, + + {{0.7000, 0.8000, 0.9000}, {1.0000, 1.1000, 1.2000}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.0100, 0.0200}})); + + auto outputGrad = std::make_shared<Tensor>( + Array3D<float, batchSize, outChannels, outDataSize>( + {{{{1., 2.}, {3., 4.}}}})); + + auto op = setupTestConv<DIM>( + batchSize, + inChannels, + outChannels, + std::array<DimSize_t, DIM>({kernelSize}), + std::array<DimSize_t, DIM>({inDataSize}), + std::array<DimSize_t, DIM>({stride}), + std::array<DimSize_t, DIM>({dilation}), + padding, + input, + weights, + biases); + + //////////////////////////////////// + // setup gradients for backward + op->getOutput(0)->setGrad(outputGrad); + + REQUIRE_NOTHROW(op->backward()); + + SECTION("Input Grad") { + auto expectedInputGrad = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize>( + {{{{2.2000, + 0.0000, + 5.6000, + 0.0000, + 6.6000, + 0.0000, + 4.2000, + 0.0000}, + {3.4000, + 0.0000, + 8.6000, + 0.0000, + 9.6000, + 0.0000, + 6.0000, + 0.0000}}}})); + CHECK(approxEq<float, float>(*op->getInput(0)->grad(), + *expectedInputGrad)); + } + SECTION("Weight grad") { + auto expectedWeightsGrad = std::make_shared<Tensor>( + Array3D<float, outChannels, inChannels, kernelSize>( + {{{{7., 13., 19.}, {31., 37., 43.}}, + + {{15., 29., 43.}, {71., 85., 99.}}}})); + CHECK(approxEq<float, float>(*op->getInput(1)->grad(), + *expectedWeightsGrad)); + } + SECTION("Bias Grad") { + auto expectedBiasesGrad = std::make_shared<Tensor>( + 
Array1D<float, outChannels>({{3., 7.}})); + CHECK(approxEq<float, float>(*op->getInput(2)->grad(), + *expectedBiasesGrad)); + } + } + SECTION("random values testing") { + const DimSize_t batchSize = 1; + const DimSize_t inChannels = 4; + const DimSize_t outChannels = 4; + const DimSize_t kernelSize = 3; + const DimSize_t inDataSize = 13; + const DimSize_t outDataSize = 2; + + const DimSize_t stride = 4; + const DimSize_t dilation = 3; + const std::array<DimSize_t, 2 * DIM> padding({0, 0}); + + auto inputSize = + std::vector<DimSize_t>({batchSize, inChannels, inDataSize}); + + auto input = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize>( + {{{{0.180772, + -0.069988, + -0.359623, + -0.915204, + 0.625765, + 0.025510, + 0.954514, + 0.064349, + 0.361151, + 1.167878, + -1.349893, + -0.510177, + 0.235958}, + {-0.239778, + -0.921115, + 1.543297, + 1.348826, + -0.139642, + 0.285797, + 0.965120, + -2.037150, + 0.493136, + 1.486999, + 0.591033, + 0.126030, + -1.562687}, + {-1.160103, + -0.334841, + 0.447772, + -0.801645, + 1.523611, + 2.508587, + -0.663096, + -0.251275, + 1.010145, + 0.121547, + -1.510835, + 2.104773, + 2.762959}, + {-1.746529, + 0.410919, + -0.242185, + 0.420812, + 0.277596, + 0.778898, + 1.533269, + 1.609736, + -0.403228, + -0.274928, + 1.473840, + 0.068826, + 1.332708}}}})); + auto weights = std::make_shared<Tensor>( + Array3D<float, outChannels, inChannels, kernelSize>( + {{{{0.587285, 0.286069, 0.008287}, + {-0.252325, -1.324722, 0.189178}, + {0.021100, 0.940420, -0.557690}, + {-0.693927, -0.325247, 1.243933}}, + + {{-1.167186, -0.409124, 1.260062}, + {-1.563006, 1.134614, -0.082384}, + {0.289316, 0.835773, -0.244991}, + {0.271223, 0.093636, -0.883432}}, + + {{-0.327417, 0.078394, -0.380766}, + {0.377508, 0.111912, 2.314279}, + {-0.798906, -0.564303, -1.134660}, + {0.170527, 0.994665, 1.262572}}, + + {{1.621816, 1.077471, 0.594781}, + {-1.529087, 2.043707, -0.165627}, + {0.087070, -0.527656, -0.100288}, + {1.053922, -0.623074, -1.590572}}}})); + + auto biases = std::make_shared<Tensor>(Array1D<float, outChannels>( + {{1.285940, -0.051787, -0.968103, -0.586324}})); + + auto op = setupTestConv<DIM>( + batchSize, + inChannels, + outChannels, + std::array<DimSize_t, DIM>({kernelSize}), + std::array<DimSize_t, DIM>({inDataSize}), + std::array<DimSize_t, DIM>({stride}), + std::array<DimSize_t, DIM>({dilation}), + padding, + input, + weights, + biases); + + //////////////////////////////////// + // setup gradients for backward + auto outputGrad = std::make_shared<Tensor>( + Array3D<float, batchSize, outChannels, outDataSize>( + {{{{0.053156, 1.189073}, + {0.100228, 1.042344}, + {-1.468991, 0.581337}, + {1.330418, 0.487802}}}})); + op->getOutput(0)->setGrad(outputGrad); + + //////////////////////////////////// + // setup gradients for backward + REQUIRE_NOTHROW(op->backward()); + + SECTION("Input Grad") { + auto expectedInputGrad = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize>( + {{{{2.552898, + 0.000000, + 0.000000, + 1.292528, + 0.082501, + 0.000000, + 1.477383, + 0.484875, + 0.000000, + 0.000000, + 1.392054, + 0.000000, + 0.000000}, + {-2.758950, + 0.000000, + 0.000000, + 2.597889, + -2.455656, + 0.000000, + -3.618210, + 0.669449, + 0.000000, + 0.000000, + 1.403657, + 0.000000, + 0.000000}, + {1.319545, + 0.000000, + 0.000000, + 0.260710, + -0.095303, + 0.000000, + 1.479181, + 1.403949, + 0.000000, + 0.000000, + -1.627040, + 0.000000, + 0.000000}, + {1.141951, + 0.000000, + 0.000000, + -2.298007, + 0.070817, + 0.000000, + 
-3.993255, + -0.014843, + 0.000000, + 0.000000, + 0.516383, + 0.000000, + 0.000000}}}})); + CHECK(approxEq<float, float>(*op->getInput(0)->grad(), + *expectedInputGrad, + 1e-5, + 1e-6)); + } + SECTION("Weight grad") { + auto expectedWeightsGrad = std::make_shared<Tensor>( + Array3D<float, outChannels, inChannels, kernelSize>( + {{{{0.753690, 0.027866, -1.554383}, + {-0.178790, -2.350622, 0.754084}, + {1.750019, -0.341397, -1.831741}, + {0.237243, 1.936463, 1.834007}}, + + {{0.670381, -0.024656, -1.311384}, + {-0.169587, -1.988220, 0.712792}, + {1.471852, -0.342263, -1.641270}, + {0.114300, 1.720076, 1.689925}}, + + {{0.098228, 1.381835, -2.186914}, + {0.271054, -3.165683, -1.074165}, + {2.589912, 1.031534, 0.095779}, + {2.727013, 0.317630, -1.395561}}, + + {{0.545751, -1.186215, 0.611421}, + {-0.387123, 0.800776, 1.572321}, + {-0.800201, -1.189095, -1.619183}, + {-2.188202, 1.345088, 2.758830}}} + + })); + CHECK(approxEq<float, float>(*op->getInput(1)->grad(), + *expectedWeightsGrad, + 1e-5, + 1e-6)); + } + SECTION("Bias Grad") { + auto expectedBiasesGrad = + std::make_shared<Tensor>(Array1D<float, outChannels>( + {{1.242230, 1.142572, -0.887655, 1.818220}})); + CHECK(approxEq<float, float>(*op->getInput(2)->grad(), + *expectedBiasesGrad)); + } + } + } + SECTION("2D") { + const DimSize_t DIM = 2; + SECTION("Sequential values") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 1; + constexpr DimSize_t outChannels = 2; + constexpr std::array<DimSize_t, DIM> kernelSize = {1, 2}; + constexpr std::array<DimSize_t, DIM> inDataSize = {3, 4}; + + constexpr std::array<DimSize_t, DIM> stride = {1, 2}; + constexpr std::array<DimSize_t, DIM> dilation = {1, 2}; + constexpr std::array<DimSize_t, 2 * DIM> padding({0, 0}); + + constexpr std::array<DimSize_t, DIM> outDataSize = {3, 1}; + + auto inputSize = std::vector<DimSize_t>( + {batchSize, inChannels, inDataSize[0], inDataSize[1]}); + + auto input = std::make_shared<Tensor>( + Array4D<float, + batchSize, + inChannels, + inDataSize[0], + inDataSize[1]>({{{{{1., 2., 3., 4.}, + {5., 6., 7., 8.}, + {9., 10., 11., 12.}}}}})); + auto weights = std::make_shared<Tensor>( + Array4D<float, + outChannels, + inChannels, + kernelSize[0], + kernelSize[1]>({{{{{1., 2.}}}, {{{3., 4.}}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{1., 2.}})); + + auto outputGrad = std::make_shared<Tensor>(Array4D<float, + batchSize, + outChannels, + outDataSize[0], + outDataSize[1]>( + {{{{{1.}, {2.}, {3.}}, {{4.}, {5.}, {6.}}}}})); + + auto op = setupTestConv<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + padding, + input, + weights, + biases); + + //////////////////////////////////// + // setup gradients for backward + op->getOutput(0)->setGrad(outputGrad); + + REQUIRE_NOTHROW(op->backward()); + + SECTION("Input Grad") { + auto expectedInputGrad = std::make_shared<Tensor>( + Array4D<float, + batchSize, + inChannels, + inDataSize[0], + inDataSize[1]>({{{{{13., 0., 18., 0.}, + {17., 0., 24., 0.}, + {21., 0., 30., 0.}}}}})); + CHECK(approxEq<float, float>(*op->getInput(0)->grad(), + *expectedInputGrad)); + } + SECTION("Weight grad") { + auto expectedWeightsGrad = + std::make_shared<Tensor>(Array4D<float, + outChannels, + inChannels, + kernelSize[0], + kernelSize[1]>( + {{{{{38., 50.}}}, {{{83., 113.}}}}})); + CHECK(approxEq<float, float>(*op->getInput(1)->grad(), + *expectedWeightsGrad)); + } + SECTION("Bias Grad") { + auto expectedBiasesGrad = std::make_shared<Tensor>( + 
Array1D<float, outChannels>({{6., 15.}})); + CHECK(approxEq<float, float>(*op->getInput(2)->grad(), + *expectedBiasesGrad)); + } + } + } +} diff --git a/unit_tests/operator/Test_ConvTranspose.cpp b/unit_tests/operator/Test_ConvTranspose.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6e889e809e0a05d551829bd15fda9cc651068465 --- /dev/null +++ b/unit_tests/operator/Test_ConvTranspose.cpp @@ -0,0 +1,2298 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <aidge/utils/Types.h> +#include <memory> + +#include <catch2/catch_test_macros.hpp> +#include <fmt/core.h> + +#include "aidge/backend/cpu/operator/ConvTransposeImpl.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/ConvTranspose.hpp" +#include "aidge/utils/TensorUtils.hpp" + +namespace Aidge { + +template <DimSize_t DIM> +static std::shared_ptr<OperatorTensor> +setupTestConvTranspose(const DimSize_t batchSize, + const DimSize_t inChannels, + const DimSize_t outChannels, + const std::array<DimSize_t, DIM> kernelSize, + const std::array<DimSize_t, DIM> dataSize, + const std::array<DimSize_t, DIM> stride, + const std::array<DimSize_t, DIM> dilation, + const std::shared_ptr<Tensor> input, + const std::shared_ptr<Tensor> weights, + const std::shared_ptr<Tensor> biases) { + std::shared_ptr<Node> convTransposeNode; + convTransposeNode = ConvTranspose(inChannels, + outChannels, + kernelSize, + stride, + dilation, + false, + "myconv"); + auto op = std::static_pointer_cast<OperatorTensor>( + convTransposeNode->getOperator()); + + op->associateInput(0, input); + op->setDataType(DataType::Float32); + + input->setBackend("cpu"); + op->setBackend("cpu"); + + weights->setBackend("cpu"); + op->associateInput(1, weights); + + biases->setBackend("cpu"); + op->associateInput(2, biases); + + REQUIRE_NOTHROW(op->forwardDims(true)); + + return op; +} + +TEST_CASE("[cpu/operator] ConvTranspose(forward)", "[ConvTranspose][CPU]") { + constexpr DimSize_t DIM = 1; + SECTION("1D") { + SECTION("kernel = 2 , in/outChannels = 1") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 1; + constexpr DimSize_t outChannels = 1; + + constexpr std::array<DimSize_t, DIM> kernelSize{2}; + + constexpr std::array<DimSize_t, DIM> inDataSize{4}; + constexpr std::array<DimSize_t, DIM> outDataSize{5}; + + constexpr std::array<DimSize_t, DIM> stride{1}; + constexpr std::array<DimSize_t, DIM> dilation{1}; + + auto input = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize[0]>( + {{{{1.000000, 2.000000, 3.000000, 4.000000}}}})); + + auto weights = std::make_shared<Tensor>( + Array3D<float, inChannels, outChannels, kernelSize[0]>( + {{{{0.100000, 0.200000}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.010000}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = std::make_shared<Tensor>( + Array3D<float, batchSize, outChannels, outDataSize[0]>( + {{{{0.110000, 0.410000, 0.710000, 1.010000, 0.810000}}}})); + 
CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput)); + } + SECTION("kernel = 2, inChannel = 2, outChannels = 1") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 2; + constexpr DimSize_t outChannels = 1; + + constexpr std::array<DimSize_t, DIM> kernelSize{2}; + + constexpr std::array<DimSize_t, DIM> inDataSize{4}; + constexpr std::array<DimSize_t, DIM> outDataSize{5}; + + constexpr std::array<DimSize_t, DIM> stride{1}; + constexpr std::array<DimSize_t, DIM> dilation{1}; + + auto input = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize[0]>( + {{{{1.000000, 2.000000, 3.000000, 4.000000}, + {5.000000, 6.000000, 7.000000, 8.000000}}}})); + + auto weights = std::make_shared<Tensor>( + Array3D<float, inChannels, outChannels, kernelSize[0]>( + {{{{0.100000, 0.200000}}, {{0.300000, 0.400000}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.010000}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = std::make_shared<Tensor>( + Array3D<float, batchSize, outChannels, outDataSize[0]>( + {{{{1.610000, 4.210000, 5.210000, 6.210001, 4.010000}}}})); + CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput)); + } + SECTION("kernel = 2, inChannel = 1, outChannels = 2") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 1; + constexpr DimSize_t outChannels = 2; + + constexpr std::array<DimSize_t, DIM> kernelSize{2}; + + constexpr std::array<DimSize_t, DIM> inDataSize{4}; + constexpr std::array<DimSize_t, DIM> outDataSize{5}; + + constexpr std::array<DimSize_t, DIM> stride{1}; + constexpr std::array<DimSize_t, DIM> dilation{1}; + + auto input = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize[0]>( + {{{{1., 2., 3., 4.}}}})); + + auto weights = std::make_shared<Tensor>( + Array3D<float, inChannels, outChannels, kernelSize[0]>( + {{{{0.1, 0.2}, {0.3, 0.4}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.01, 0.02}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = std::make_shared<Tensor>( + Array3D<float, batchSize, outChannels, outDataSize[0]>( + {{{{0.11, 0.41, 0.71, 1.01, 0.81}, + {0.32, 1.02, 1.72, 2.42, 1.62}}}})); + + CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput)); + } + SECTION("kernel = 1, inChannel = 2, outChannels = 2") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 2; + constexpr DimSize_t outChannels = 2; + + constexpr std::array<DimSize_t, DIM> kernelSize{1}; + + constexpr std::array<DimSize_t, DIM> inDataSize{4}; + constexpr std::array<DimSize_t, DIM> outDataSize{4}; + + constexpr std::array<DimSize_t, DIM> stride{1}; + constexpr std::array<DimSize_t, DIM> dilation{1}; + + auto input = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize[0]>( + {{{{1.000000, 2.000000, 3.000000, 4.000000}, + {5.000000, 6.000000, 7.000000, 8.000000}}}})); + + auto weights = std::make_shared<Tensor>( + Array3D<float, inChannels, outChannels, kernelSize[0]>( + {{{{0.100000}, {0.200000}}, + + {{0.300000}, {0.400000}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.010000, 
0.020000}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = std::make_shared<Tensor>( + Array3D<float, batchSize, outChannels, outDataSize[0]>( + {{{{1.610000, 2.010000, 2.410000, 2.810000}, + {2.220000, 2.820000, 3.420000, 4.020000}}}})); + + CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput)); + } + SECTION("kernel = 2, inChannels = 2, outChannels = 3") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 2; + constexpr DimSize_t outChannels = 3; + + constexpr std::array<DimSize_t, DIM> kernelSize{2}; + + constexpr std::array<DimSize_t, DIM> inDataSize{4}; + constexpr std::array<DimSize_t, DIM> outDataSize{5}; + + constexpr std::array<DimSize_t, DIM> stride{1}; + constexpr std::array<DimSize_t, DIM> dilation{1}; + + auto input = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize[0]>( + {{{{1., 2., 3., 4.}, {5., 6., 7., 8.}}}})); + + auto weights = std::make_shared<Tensor>( + Array3D<float, inChannels, outChannels, kernelSize[0]>( + {{{{0.10, 0.20}, {0.30, 0.40}, {0.50, 0.60}}, + + {{0.70, 0.80}, {0.90, 1.}, {1.10, 1.20}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.010000, 0.020000, 0.030000}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = std::make_shared< + Tensor>(Array3D<float, batchSize, outChannels, outDataSize[0]>( + {{{{3.610000, 8.610001, 10.410000, 12.210001, 7.210001}, + {4.820000, 11.420000, 14.020000, 16.620001, 9.620001}, + {6.030000, 14.230000, 17.630001, 21.030001, 12.030000}}}})); + + CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput)); + } + + SECTION("Big test to ensure kernel capabilities") { + constexpr DimSize_t batchSize = 2; + constexpr DimSize_t inChannels = 3; + constexpr DimSize_t outChannels = 4; + + constexpr std::array<DimSize_t, DIM> kernelSize{6}; + + constexpr std::array<DimSize_t, DIM> inDataSize{6}; + constexpr std::array<DimSize_t, DIM> outDataSize{11}; + + constexpr std::array<DimSize_t, DIM> stride{1}; + constexpr std::array<DimSize_t, DIM> dilation{1}; + + auto input = std::make_shared<Tensor>( + Array3D<float, batchSize, inChannels, inDataSize[0]>( + {{{{1., 2., 3., 4., 5., 6.}, + {7., 8., 9., 10., 11., 12.}, + {13., 14., 15., 16., 17., 18.}}, + + {{19., 20., 21., 22., 23., 24.}, + {25., 26., 27., 28., 29., 30.}, + {31., 32., 33., 34., 35., 36.}}}})); + + auto weights = std::make_shared<Tensor>( + Array3D<float, inChannels, outChannels, kernelSize[0]>( + {{{{0.1, 0.2, 0.3, 0.4, 0.5, 0.6}, + {0.7, 0.8, 0.9, 1., 1.1, 1.2}, + {1.3, 1.4, 1.5, 1.6, 1.7, 1.8}, + {1.9, 2., 2.1, 2.2, 2.3, 2.4}}, + + {{2.5, 2.6, 2.7, 2.8, 2.9, 3.}, + {3.1, 3.2, 3.3, 3.4, 3.5, 3.6}, + {3.7, 3.8, 3.9, 4., 4.1, 4.2}, + {4.3, 4.4, 4.5, 4.6, 4.7, 4.8}}, + + {{4.9, 5., 5.1, 5.2, 5.3, 5.4}, + {5.5, 5.6, 5.7, 5.8, 5.9, 6.}, + {6.1, 6.2, 6.3, 6.4, 6.5, 6.6}, + {6.7, 6.8, 6.9, 7., 7.1, 7.2}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.01, 0.02, 0.03, 0.04}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = 
std::make_shared<Tensor>( + Array3D<float, batchSize, outChannels, outDataSize[0]>( + {{{{81.310005, + 172.210007, + 273.010010, + 384.010040, + 505.509979, + 637.810059, + 561.010010, + 472.809998, + 372.910004, + 261.010010, + 136.809998}, + {93.919998, + 199.220001, + 316.219971, + 445.220001, + 586.520081, + 740.420044, + 651.020020, + 548.420044, + 432.319977, + 302.420013, + 158.419998}, + {106.529999, + 226.230011, + 359.429993, + 506.430054, + 667.530090, + 843.030029, + 741.030029, + 624.030029, + 491.730042, + 343.829987, + 180.029999}, + {119.140007, + 253.240005, + 402.640045, + 567.640076, + 748.539978, + 945.639954, + 831.039978, + 699.640015, + 551.140015, + 385.239990, + 201.639999}}, + + {{216.309998, + 447.610016, + 694.210022, + 956.410034, + 1234.510132, + 1528.810059, + 1317.010010, + 1088.410034, + 842.710022, + 579.610046, + 298.810028}, + {261.319977, + 539.420044, + 834.619995, + 1147.220093, + 1477.520142, + 1825.820068, + 1569.019897, + 1293.619995, + 999.320068, + 685.820007, + 352.819977}, + {306.329987, + 631.230042, + 975.030029, + 1338.030151, + 1720.530029, + 2122.829834, + 1821.029785, + 1498.830200, + 1155.930054, + 792.030029, + 406.830017}, + {351.340027, + 723.039978, + 1115.440063, + 1528.840210, + 1963.539917, + 2419.839844, + 2073.040283, + 1704.040039, + 1312.540039, + 898.239990, + 460.840027}}}})); + CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput)); + } + } + + SECTION("2D") { + constexpr DimSize_t DIM = 2; + SECTION("inChannels = 1, outChannels = 2, kernelSize = {1,2}, " + "inDataSize = {2,3}") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 1; + constexpr DimSize_t outChannels = 2; + + constexpr std::array<DimSize_t, DIM> kernelSize{1, 2}; + + constexpr std::array<DimSize_t, DIM> inDataSize{2, 3}; + constexpr std::array<DimSize_t, DIM> outDataSize{2, 4}; + + constexpr std::array<DimSize_t, DIM> stride{1, 1}; + constexpr std::array<DimSize_t, DIM> dilation{1, 1}; + + auto input = std::make_shared<Tensor>(Array4D<float, + batchSize, + inChannels, + inDataSize[0], + inDataSize[1]>( + {{{{{1.000000, 2.000000, 3.000000}, + {4.000000, 5.000000, 6.000000}}}}})); + + auto weights = std::make_shared<Tensor>( + Array4D<float, + inChannels, + outChannels, + kernelSize[0], + kernelSize[1]>({{{{{0.100000, 0.200000}}, + + {{0.300000, 0.400000}}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.010000}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = + std::make_shared<Tensor>(Array4D<float, + batchSize, + outChannels, + outDataSize[0], + outDataSize[1]>( + {{{{{0.110000, 0.410000, 0.710000, 0.610000}, + {0.410000, 1.310000, 1.610000, 1.210000}}, + + {{0.320000, 1.020000, 1.720000, 1.220000}, + {1.220000, 3.120000, 3.820000, 2.420000}}}}})); + } + SECTION("inChannels = 1, outChannels = 2, kernelSize = {2,3}, " + "inDataSize = {2,3}") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 1; + constexpr DimSize_t outChannels = 2; + + constexpr std::array<DimSize_t, DIM> kernelSize{2, 3}; + + constexpr std::array<DimSize_t, DIM> inDataSize{2, 3}; + constexpr std::array<DimSize_t, DIM> outDataSize{3, 5}; + + constexpr std::array<DimSize_t, DIM> stride{1, 1}; + constexpr std::array<DimSize_t, DIM> dilation{1, 1}; + + auto input = std::make_shared<Tensor>(Array4D<float, + batchSize, + 
inChannels, + inDataSize[0], + inDataSize[1]>( + {{{{{1.000000, 2.000000, 3.000000}, + {4.000000, 5.000000, 6.000000}}}}})); + + auto weights = std::make_shared<Tensor>(Array4D<float, + inChannels, + outChannels, + kernelSize[0], + kernelSize[1]>( + {{{{{0.100000, 0.200000, 0.300000}, + {0.400000, 0.500000, 0.600000}}, + + {{0.700000, 0.800000, 0.900000}, + {1.000000, 1.100000, 1.200000}}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.010000, 0.020000}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = std::make_shared< + Tensor>(Array4D<float, + batchSize, + outChannels, + outDataSize[0], + outDataSize[1]>( + {{{{{0.110000, 0.410000, 1.010000, 1.210000, 0.910000}, + {0.810000, 2.610000, 5.610000, 5.410000, 3.610000}, + {1.610000, 4.010000, 7.310000, 6.010000, 3.610000}}, + + {{0.720000, 2.220000, 4.620000, 4.220000, 2.720000}, + {3.820000, 9.820001, 18.220001, 15.020000, 9.020000}, + {4.020000, 9.420000, 16.320000, 12.620001, 7.220000}}}}})); + } + SECTION("inChannels = 1, outChannels = 2, kernelSize = {2,3}, " + "inDataSize = {6,6}, stride = {2, 2}, dilation = {2, 2}") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 1; + constexpr DimSize_t outChannels = 2; + + constexpr std::array<DimSize_t, DIM> kernelSize{2, 3}; + + constexpr std::array<DimSize_t, DIM> inDataSize{4, 4}; + constexpr std::array<DimSize_t, DIM> outDataSize{9, 11}; + + constexpr std::array<DimSize_t, DIM> stride{2, 2}; + constexpr std::array<DimSize_t, DIM> dilation{2, 2}; + + auto input = std::make_shared<Tensor>(Array4D<float, + batchSize, + inChannels, + inDataSize[0], + inDataSize[1]>( + {{{{{1.00, 2.00, 3.00, 4.000000}, + {5.00, 6.00, 7.00, 8.000000}, + {9.00, 10.00, 11.00, 12.000000}, + {13.00, 14.00, 15.00, 16.000000}}}}})); + + auto weights = std::make_shared<Tensor>(Array4D<float, + inChannels, + outChannels, + kernelSize[0], + kernelSize[1]>( + {{{{{0.10, 0.20, 0.300000}, {0.40, 0.50, 0.600000}}, + + {{0.70, 0.80, 0.900000}, {1.00, 1.10, 1.200000}}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.01, 0.020000}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = std::make_shared<Tensor>( + Array4D<float, + batchSize, + outChannels, + outDataSize[0], + outDataSize[1]>({{{{{0.11, + 0.01, + 0.41, + 0.01, + 1.01, + 0.01, + 1.61, + 0.01, + 1.71, + 0.01, + 1.210000}, + {0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.010000}, + {0.91, + 0.01, + 2.91, + 0.01, + 6.210001, + 0.01, + 8.31, + 0.01, + 7.510001, + 0.01, + 4.810000}, + {0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.010000}, + {2.91, + 0.01, + 7.710001, + 0.01, + 14.610001, + 0.01, + 16.710001, + 0.01, + 13.910002, + 0.01, + 8.410001}, + {0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.010000}, + {4.91, + 0.01, + 12.51, + 0.01, + 23.01, + 0.01, + 25.110001, + 0.01, + 20.309999, + 0.01, + 12.010000}, + {0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.01, + 0.010000}, + {5.210001, + 0.01, + 12.110001, + 0.01, + 20.809999, + 0.01, + 22.309999, + 0.01, + 17.01, + 0.01, + 9.610001}}, + + 
{{0.72, + 0.02, + 2.22, + 0.02, + 4.62, + 0.02, + 7.02, + 0.02, + 5.92, + 0.02, + 3.620000}, + {0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.020000}, + {4.52, + 0.02, + 11.320001, + 0.02, + 20.620003, + 0.02, + 26.320002, + 0.02, + 20.720001, + 0.02, + 12.020000}, + {0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.020000}, + {11.32, + 0.02, + 25.720001, + 0.02, + 43.420002, + 0.02, + 49.120003, + 0.02, + 36.720001, + 0.02, + 20.420002}, + {0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.020000}, + {18.119999, + 0.02, + 40.120003, + 0.02, + 66.220001, + 0.02, + 71.919998, + 0.02, + 52.720001, + 0.02, + 28.820002}, + {0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.02, + 0.020000}, + {13.02, + 0.02, + 28.32, + 0.02, + 46.02, + 0.02, + 49.320004, + 0.02, + 35.619999, + 0.02, + 19.220001}}}}})); + } + SECTION("inChannels = 4, outChannels = 3, kernelSize = {2,2}, " + "inDataSize = {3,3}, stride = {2, 2}, dilation = {2, 2}") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 4; + constexpr DimSize_t outChannels = 3; + + constexpr std::array<DimSize_t, DIM> kernelSize{2, 2}; + + constexpr std::array<DimSize_t, DIM> inDataSize{4, 4}; + constexpr std::array<DimSize_t, DIM> outDataSize{7, 7}; + + constexpr std::array<DimSize_t, DIM> stride{2, 2}; + constexpr std::array<DimSize_t, DIM> dilation{2, 2}; + + auto input = std::make_shared<Tensor>(Array4D<float, + batchSize, + inChannels, + inDataSize[0], + inDataSize[1]>( + {{{{{1.0, 2.0, 3.0}, {4.0, 5.0, 6.0}, {7.0, 8.0, 9.0}}, + + {{10.0, 11.0, 12.0}, + {13.0, 14.0, 15.0}, + {16.0, 17.0, 18.0}}, + + {{19.0, 20.0, 21.0}, + {22.0, 23.0, 24.0}, + {25.0, 26.0, 27.0}}, + + {{28.0, 29.0, 30.0}, + {31.0, 32.0, 33.0}, + {34.0, 35.0, 36.0}}}}})); + + auto weights = std::make_shared<Tensor>( + Array4D<float, + inChannels, + outChannels, + kernelSize[0], + kernelSize[1]>({{{{{0.1, 0.2}, {0.3, 0.4}}, + + {{0.5, 0.6}, {0.7, 0.8}}, + + {{0.9, 1.0}, {1.1, 1.2}}}, + + {{{1.3, 1.4}, {1.5, 1.6}}, + + {{1.7, 1.8}, {1.9, 2.0}}, + + {{2.1, 2.2}, {2.3, 2.4}}}, + + {{{2.5, 2.6}, {2.7, 2.8}}, + + {{2.9, 3.0}, {3.1, 3.2}}, + + {{3.3, 3.4}, {3.5, 3.6}}}, + + {{{3.7, 3.8}, {3.9, 4.0}}, + + {{4.1, 4.2}, {4.3, 4.4}}, + + {{4.5, 4.6}, {4.7, 4.8}}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.010000, 0.020000, 0.030000}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = std::make_shared<Tensor>( + Array4D<float, + batchSize, + outChannels, + outDataSize[0], + outDataSize[1]>({{{{{164.209991, + 0.010000, + 341.809998, + 0.010000, + 357.410034, + 0.010000, + 186.009995}, + {0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000}, + {362.809998, + 0.010000, + 754.410034, + 0.010000, + 787.210083, + 0.010000, + 409.210022}, + {0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000}, + {410.809998, + 0.010000, + 852.810059, + 0.010000, + 885.609985, + 0.010000, + 459.610016}, + {0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000}, + {226.209991, + 0.010000, + 469.010010, + 0.010000, + 486.210022, + 0.010000, + 252.009995}}, + + {{187.419998, + 0.020000, + 389.820007, + 0.020000, + 408.619995, + 0.020000, + 212.420013}, + {0.020000, + 0.020000, 
+ 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000}, + {414.019989, + 0.020000, + 860.020020, + 0.020000, + 899.220032, + 0.020000, + 466.820007}, + {0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000}, + {471.620026, + 0.020000, + 977.619995, + 0.020000, + 1016.820068, + 0.020000, + 526.820007}, + {0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000}, + {259.019989, + 0.020000, + 536.220032, + 0.020000, + 556.619995, + 0.020000, + 288.019989}}, + + {{210.630005, + 0.030000, + 437.829987, + 0.030000, + 459.829987, + 0.030000, + 238.830002}, + {0.030000, + 0.030000, + 0.030000, + 0.030000, + 0.030000, + 0.030000, + 0.030000}, + {465.230011, + 0.030000, + 965.630005, + 0.030000, + 1011.230103, + 0.030000, + 524.430054}, + {0.030000, + 0.030000, + 0.030000, + 0.030000, + 0.030000, + 0.030000, + 0.030000}, + {532.430054, + 0.030000, + 1102.430054, + 0.030000, + 1148.030029, + 0.030000, + 594.030029}, + {0.030000, + 0.030000, + 0.030000, + 0.030000, + 0.030000, + 0.030000, + 0.030000}, + {291.830017, + 0.030000, + 603.430054, + 0.030000, + 627.030029, + 0.030000, + 324.029999}}}}})); + } + SECTION("Big test to ensure kernel capabilities 1") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 3; + constexpr DimSize_t outChannels = 4; + + constexpr std::array<DimSize_t, DIM> kernelSize{2, 2}; + + constexpr std::array<DimSize_t, DIM> inDataSize{6, 5}; + constexpr std::array<DimSize_t, DIM> outDataSize{8, 17}; + + constexpr std::array<DimSize_t, DIM> stride{1, 3}; + constexpr std::array<DimSize_t, DIM> dilation{2, 4}; + + auto input = std::make_shared<Tensor>( + Array4D<float, + batchSize, + inChannels, + inDataSize[0], + inDataSize[1]>({{{{{1., 2., 3., 4., 5.}, + {6., 7., 8., 9., 10.}, + {11., 12., 13., 14., 15.}, + {16., 17., 18., 19., 20.}, + {21., 22., 23., 24., 25.}, + {26., 27., 28., 29., 30.}}, + + {{31., 32., 33., 34., 35.}, + {36., 37., 38., 39., 40.}, + {41., 42., 43., 44., 45.}, + {46., 47., 48., 49., 50.}, + {51., 52., 53., 54., 55.}, + {56., 57., 58., 59., 60.}}, + + {{61., 62., 63., 64., 65.}, + {66., 67., 68., 69., 70.}, + {71., 72., 73., 74., 75.}, + {76., 77., 78., 79., 80.}, + {81., 82., 83., 84., 85.}, + {86., 87., 88., 89., 90.}}}}})); + + auto weights = std::make_shared<Tensor>(Array4D<float, + inChannels, + outChannels, + kernelSize[0], + kernelSize[1]>( + {{{{{0.100000, 0.200000}, {0.300000, 0.400000}}, + + {{0.500000, 0.600000}, {0.700000, 0.800000}}, + + {{0.900000, 1.000000}, {1.100000, 1.200000}}, + + {{1.300000, 1.400000}, {1.500000, 1.600000}}}, + + {{{1.700000, 1.800000}, {1.900000, 2.000000}}, + + {{2.100000, 2.200000}, {2.300000, 2.400000}}, + + {{2.500000, 2.600000}, {2.700000, 2.800000}}, + + {{2.900000, 3.000000}, {3.100000, 3.200000}}}, + + {{{3.300000, 3.400000}, {3.500000, 3.600000}}, + + {{3.700000, 3.800000}, {3.900000, 4.000000}}, + + {{4.100000, 4.200000}, {4.300000, 4.400000}}, + + {{4.500000, 4.600000}, {4.700000, 4.800000}}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.01, 0.02, 0.03, 0.04}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = std::make_shared<Tensor>( + Array4D<float, + batchSize, + outChannels, + outDataSize[0], + outDataSize[1]>({{{{{254.110001, + 0.010000, + 0.010000, + 259.210022, + 263.410034, + 0.010000, + 264.309998, + 268.810028, + 
0.010000, + 269.410004, + 274.210022, + 0.010000, + 274.510010, + 279.610016, + 0.010000, + 0.010000, + 285.010010}, + {279.610016, + 0.010000, + 0.010000, + 284.710022, + 290.410004, + 0.010000, + 289.809998, + 295.810028, + 0.010000, + 294.910004, + 301.210022, + 0.010000, + 300.010010, + 306.610016, + 0.010000, + 0.010000, + 312.010010}, + {577.810059, + 0.010000, + 0.010000, + 588.609985, + 599.410034, + 0.010000, + 599.410034, + 610.810059, + 0.010000, + 610.209961, + 622.210022, + 0.010000, + 621.010010, + 633.609985, + 0.010000, + 0.010000, + 645.010010}, + {631.810059, + 0.010000, + 0.010000, + 642.609985, + 656.410034, + 0.010000, + 653.410034, + 667.810059, + 0.010000, + 664.209961, + 679.210022, + 0.010000, + 675.010010, + 690.609985, + 0.010000, + 0.010000, + 702.010010}, + {685.810059, + 0.010000, + 0.010000, + 696.609985, + 713.410034, + 0.010000, + 707.410034, + 724.810059, + 0.010000, + 718.209961, + 736.210022, + 0.010000, + 729.010010, + 747.609985, + 0.010000, + 0.010000, + 759.010010}, + {739.810059, + 0.010000, + 0.010000, + 750.609985, + 770.410034, + 0.010000, + 761.410034, + 781.810059, + 0.010000, + 772.209961, + 793.210022, + 0.010000, + 783.010010, + 804.609985, + 0.010000, + 0.010000, + 816.010010}, + {386.710022, + 0.010000, + 0.010000, + 392.410004, + 402.010010, + 0.010000, + 398.110016, + 408.010010, + 0.010000, + 403.809998, + 414.010010, + 0.010000, + 409.510010, + 420.010010, + 0.010000, + 0.010000, + 426.010010}, + {415.210022, + 0.010000, + 0.010000, + 420.910004, + 432.010010, + 0.010000, + 426.610016, + 438.010040, + 0.010000, + 432.309998, + 444.010010, + 0.010000, + 438.010010, + 450.010040, + 0.010000, + 0.010000, + 456.010010}}, + + {{291.320007, + 0.020000, + 0.020000, + 297.619995, + 300.619995, + 0.020000, + 303.919983, + 307.219971, + 0.020000, + 310.220001, + 313.819977, + 0.020000, + 316.519989, + 320.419983, + 0.020000, + 0.020000, + 327.019989}, + {322.820007, + 0.020000, + 0.020000, + 329.119995, + 333.619995, + 0.020000, + 335.419983, + 340.219971, + 0.020000, + 341.720001, + 346.819977, + 0.020000, + 348.019989, + 353.419983, + 0.020000, + 0.020000, + 360.019989}, + {664.220032, + 0.020000, + 0.020000, + 677.420044, + 685.820068, + 0.020000, + 690.619995, + 699.619995, + 0.020000, + 703.820068, + 713.420044, + 0.020000, + 717.020020, + 727.219971, + 0.020000, + 0.020000, + 741.020020}, + {730.220032, + 0.020000, + 0.020000, + 743.420044, + 754.820068, + 0.020000, + 756.619995, + 768.619995, + 0.020000, + 769.820068, + 782.420044, + 0.020000, + 783.020020, + 796.219971, + 0.020000, + 0.020000, + 810.020020}, + {796.220032, + 0.020000, + 0.020000, + 809.420044, + 823.820068, + 0.020000, + 822.620056, + 837.619995, + 0.020000, + 835.820068, + 851.420044, + 0.020000, + 849.020020, + 865.219971, + 0.020000, + 0.020000, + 879.020020}, + {862.220032, + 0.020000, + 0.020000, + 875.420044, + 892.820068, + 0.020000, + 888.619995, + 906.619995, + 0.020000, + 901.820068, + 920.420044, + 0.020000, + 915.020020, + 934.219971, + 0.020000, + 0.020000, + 948.020020}, + {447.919983, + 0.020000, + 0.020000, + 454.820007, + 463.220001, + 0.020000, + 461.720001, + 470.420013, + 0.020000, + 468.619995, + 477.619995, + 0.020000, + 475.519989, + 484.819977, + 0.020000, + 0.020000, + 492.019989}, + {482.419983, + 0.020000, + 0.020000, + 489.320007, + 499.220001, + 0.020000, + 496.220001, + 506.420013, + 0.020000, + 503.119995, + 513.619995, + 0.020000, + 510.019989, + 520.820007, + 0.020000, + 0.020000, + 528.020020}}, + + {{328.529999, + 0.030000, + 0.030000, 
+ 336.029999, + 337.830017, + 0.030000, + 343.529999, + 345.630035, + 0.030000, + 351.029999, + 353.430023, + 0.030000, + 358.529999, + 361.230011, + 0.030000, + 0.030000, + 369.030029}, + {366.029999, + 0.030000, + 0.030000, + 373.529999, + 376.830017, + 0.030000, + 381.029999, + 384.630035, + 0.030000, + 388.529999, + 392.430023, + 0.030000, + 396.029999, + 400.230042, + 0.030000, + 0.030000, + 408.030029}, + {750.630005, + 0.030000, + 0.030000, + 766.230042, + 772.230042, + 0.030000, + 781.830078, + 788.430054, + 0.030000, + 797.430054, + 804.630066, + 0.030000, + 813.030029, + 820.830078, + 0.030000, + 0.030000, + 837.030029}, + {828.630005, + 0.030000, + 0.030000, + 844.230042, + 853.230042, + 0.030000, + 859.830078, + 869.430054, + 0.030000, + 875.430054, + 885.630066, + 0.030000, + 891.030029, + 901.830078, + 0.030000, + 0.030000, + 918.030029}, + {906.630005, + 0.030000, + 0.030000, + 922.230042, + 934.230042, + 0.030000, + 937.830078, + 950.430054, + 0.030000, + 953.430054, + 966.630066, + 0.030000, + 969.030029, + 982.830078, + 0.030000, + 0.030000, + 999.030090}, + {984.630005, + 0.030000, + 0.030000, + 1000.230042, + 1015.230103, + 0.030000, + 1015.830078, + 1031.430054, + 0.030000, + 1031.430054, + 1047.630127, + 0.030000, + 1047.030029, + 1063.830078, + 0.030000, + 0.030000, + 1080.030029}, + {509.130005, + 0.030000, + 0.030000, + 517.230042, + 524.430054, + 0.030000, + 525.330078, + 532.830017, + 0.030000, + 533.430054, + 541.230042, + 0.030000, + 541.530029, + 549.630066, + 0.030000, + 0.030000, + 558.030029}, + {549.630066, + 0.030000, + 0.030000, + 557.730042, + 566.430054, + 0.030000, + 565.830078, + 574.830017, + 0.030000, + 573.930054, + 583.230042, + 0.030000, + 582.030029, + 591.630066, + 0.030000, + 0.030000, + 600.030029}}, + + {{365.740021, + 0.040000, + 0.040000, + 374.440002, + 375.040009, + 0.040000, + 383.140015, + 384.040009, + 0.040000, + 391.839996, + 393.040009, + 0.040000, + 400.540009, + 402.040009, + 0.040000, + 0.040000, + 411.040009}, + {409.240021, + 0.040000, + 0.040000, + 417.940002, + 420.040009, + 0.040000, + 426.640015, + 429.040009, + 0.040000, + 435.339996, + 438.040009, + 0.040000, + 444.040009, + 447.040009, + 0.040000, + 0.040000, + 456.040009}, + {837.039978, + 0.040000, + 0.040000, + 855.040039, + 858.639954, + 0.040000, + 873.039978, + 877.239990, + 0.040000, + 891.039978, + 895.840027, + 0.040000, + 909.039978, + 914.440002, + 0.040000, + 0.040000, + 933.039978}, + {927.039978, + 0.040000, + 0.040000, + 945.040039, + 951.639954, + 0.040000, + 963.039978, + 970.239990, + 0.040000, + 981.039978, + 988.840027, + 0.040000, + 999.039978, + 1007.440002, + 0.040000, + 0.040000, + 1026.040039}, + {1017.039978, + 0.040000, + 0.040000, + 1035.040039, + 1044.640015, + 0.040000, + 1053.040039, + 1063.239990, + 0.040000, + 1071.040039, + 1081.840088, + 0.040000, + 1089.040039, + 1100.440063, + 0.040000, + 0.040000, + 1119.040039}, + {1107.040039, + 0.040000, + 0.040000, + 1125.040039, + 1137.640137, + 0.040000, + 1143.040039, + 1156.239990, + 0.040000, + 1161.040039, + 1174.840088, + 0.040000, + 1179.040039, + 1193.440063, + 0.040000, + 0.040000, + 1212.040039}, + {570.340027, + 0.040000, + 0.040000, + 579.640015, + 585.640015, + 0.040000, + 588.940002, + 595.239990, + 0.040000, + 598.239990, + 604.840027, + 0.040000, + 607.540039, + 614.440002, + 0.040000, + 0.040000, + 624.039978}, + {616.840027, + 0.040000, + 0.040000, + 626.140015, + 633.640015, + 0.040000, + 635.440002, + 643.239990, + 0.040000, + 644.739990, + 652.840027, + 0.040000, + 
654.040039, + 662.440002, + 0.040000, + 0.040000, + 672.039978}}}}})); + CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput)); + } + SECTION("Big test to ensure kernel capabilities") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 3; + constexpr DimSize_t outChannels = 4; + + constexpr std::array<DimSize_t, DIM> kernelSize{6, 4}; + + constexpr std::array<DimSize_t, DIM> inDataSize{6, 5}; + constexpr std::array<DimSize_t, DIM> outDataSize{16, 25}; + + constexpr std::array<DimSize_t, DIM> stride{1, 3}; + constexpr std::array<DimSize_t, DIM> dilation{2, 4}; + + auto input = std::make_shared<Tensor>( + Array4D<float, + batchSize, + inChannels, + inDataSize[0], + inDataSize[1]>({{{{{1., 2., 3., 4., 5.}, + {6., 7., 8., 9., 10.}, + {11., 12., 13., 14., 15.}, + {16., 17., 18., 19., 20.}, + {21., 22., 23., 24., 25.}, + {26., 27., 28., 29., 30.}}, + + {{31., 32., 33., 34., 35.}, + {36., 37., 38., 39., 40.}, + {41., 42., 43., 44., 45.}, + {46., 47., 48., 49., 50.}, + {51., 52., 53., 54., 55.}, + {56., 57., 58., 59., 60.}}, + + {{61., 62., 63., 64., 65.}, + {66., 67., 68., 69., 70.}, + {71., 72., 73., 74., 75.}, + {76., 77., 78., 79., 80.}, + {81., 82., 83., 84., 85.}, + {86., 87., 88., 89., 90.}}}}})); + + auto weights = std::make_shared<Tensor>(Array4D<float, + inChannels, + outChannels, + kernelSize[0], + kernelSize[1]>( + {{{{{0.100000, 0.200000, 0.300000, 0.400000}, + {0.500000, 0.600000, 0.700000, 0.800000}, + {0.900000, 1.000000, 1.100000, 1.200000}, + {1.300000, 1.400000, 1.500000, 1.600000}, + {1.700000, 1.800000, 1.900000, 2.000000}, + {2.100000, 2.200000, 2.300000, 2.400000}}, + + {{2.500000, 2.600000, 2.700000, 2.800000}, + {2.900000, 3.000000, 3.100000, 3.200000}, + {3.300000, 3.400000, 3.500000, 3.600000}, + {3.700000, 3.800000, 3.900000, 4.000000}, + {4.100000, 4.200000, 4.300000, 4.400000}, + {4.500000, 4.600000, 4.700000, 4.800000}}, + + {{4.900000, 5.000000, 5.100000, 5.200000}, + {5.300000, 5.400000, 5.500000, 5.600000}, + {5.700000, 5.800000, 5.900000, 6.000000}, + {6.100000, 6.200000, 6.300000, 6.400000}, + {6.500000, 6.600000, 6.700000, 6.800000}, + {6.900000, 7.000000, 7.100000, 7.200000}}, + + {{7.300000, 7.400000, 7.500000, 7.600000}, + {7.700000, 7.800000, 7.900000, 8.000000}, + {8.100000, 8.200000, 8.300000, 8.400001}, + {8.500000, 8.600000, 8.700000, 8.800000}, + {8.900001, 9.000000, 9.100000, 9.200000}, + {9.300000, 9.400001, 9.500000, 9.600000}}}, + + {{{9.700000, 9.800000, 9.900001, 10.000000}, + {10.100000, 10.200000, 10.300000, 10.400001}, + {10.500000, 10.600000, 10.700000, 10.800000}, + {10.900001, 11.000000, 11.100000, 11.200000}, + {11.300000, 11.400001, 11.500000, 11.600000}, + {11.700000, 11.800000, 11.900001, 12.000000}}, + + {{12.100000, 12.200000, 12.300000, 12.400001}, + {12.500000, 12.600000, 12.700000, 12.800000}, + {12.900001, 13.000000, 13.100000, 13.200000}, + {13.300000, 13.400001, 13.500000, 13.600000}, + {13.700000, 13.800000, 13.900001, 14.000000}, + {14.100000, 14.200000, 14.300000, 14.400001}}, + + {{14.500000, 14.600000, 14.700000, 14.800000}, + {14.900001, 15.000000, 15.100000, 15.200000}, + {15.300000, 15.400001, 15.500000, 15.600000}, + {15.700000, 15.800000, 15.900001, 16.000000}, + {16.100000, 16.200001, 16.300001, 16.400000}, + {16.500000, 16.600000, 16.700001, 16.800001}}, + + {{16.900000, 17.000000, 17.100000, 17.200001}, + {17.300001, 17.400000, 17.500000, 17.600000}, + {17.700001, 17.800001, 17.900000, 18.000000}, + {18.100000, 18.200001, 18.300001, 18.400000}, + {18.500000, 18.600000, 18.700001, 
18.800001}, + {18.900000, 19.000000, 19.100000, 19.200001}}}, + + {{{19.300001, 19.400000, 19.500000, 19.600000}, + {19.700001, 19.800001, 19.900000, 20.000000}, + {20.100000, 20.200001, 20.300001, 20.400000}, + {20.500000, 20.600000, 20.700001, 20.800001}, + {20.900000, 21.000000, 21.100000, 21.200001}, + {21.300001, 21.400000, 21.500000, 21.600000}}, + + {{21.700001, 21.800001, 21.900000, 22.000000}, + {22.100000, 22.200001, 22.300001, 22.400000}, + {22.500000, 22.600000, 22.700001, 22.800001}, + {22.900000, 23.000000, 23.100000, 23.200001}, + {23.300001, 23.400000, 23.500000, 23.600000}, + {23.700001, 23.800001, 23.900000, 24.000000}}, + + {{24.100000, 24.200001, 24.300001, 24.400000}, + {24.500000, 24.600000, 24.700001, 24.800001}, + {24.900000, 25.000000, 25.100000, 25.200001}, + {25.300001, 25.400000, 25.500000, 25.600000}, + {25.700001, 25.800001, 25.900000, 26.000000}, + {26.100000, 26.200001, 26.300001, 26.400000}}, + + {{26.500000, 26.600000, 26.700001, 26.800001}, + {26.900000, 27.000000, 27.100000, 27.200001}, + {27.300001, 27.400000, 27.500000, 27.600000}, + {27.700001, 27.800001, 27.900000, 28.000000}, + {28.100000, 28.200001, 28.300001, 28.400000}, + {28.500000, 28.600000, 28.700001, 28.800001}}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.01, 0.02, 0.03, 0.04}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = + std::make_shared<Tensor>(Array4D<float, + batchSize, + outChannels, + outDataSize[0], + outDataSize[1]>( + {{{{{1478.110107, 0.010000, 0.010000, 1507.210083, + 1487.410034, 0.010000, 1536.310059, 1516.809937, + 1496.709961, 1565.410034, 1546.209961, 1526.410034, + 3100.510010, 1575.609985, 1556.109985, 1536.010010, + 1605.010010, 1585.810059, 1566.010010, 0.010000, + 1615.510010, 1596.010010, 0.010000, 0.010000, + 1626.010010}, + {1623.610107, 0.010000, 0.010000, 1652.710083, + 1634.410034, 0.010000, 1681.810059, 1663.809937, + 1645.209961, 1710.910034, 1693.209961, 1674.910034, + 3396.010010, 1722.609985, 1704.610107, 1686.010010, + 1752.010010, 1734.310059, 1716.010010, 0.010000, + 1764.010010, 1746.010010, 0.010000, 0.010000, + 1776.010010}, + {3284.410156, 0.010000, 0.010000, 3343.810303, + 3306.010010, 0.010000, 3403.210205, 3366.010010, + 3327.610107, 3462.610107, 3426.010010, 3388.209961, + 6871.209961, 3486.010010, 3448.810059, 3410.409912, + 3546.010010, 3509.409912, 3471.610107, 0.010000, + 3570.010010, 3532.810059, 0.010000, 0.010000, + 3594.010010}, + {3581.410156, 0.010000, 0.010000, 3640.810303, + 3606.010010, 0.010000, 3700.210205, 3666.010010, + 3630.610107, 3759.610107, 3726.010010, 3691.209961, + 7474.209961, 3786.010010, 3751.810059, 3716.409912, + 3846.010010, 3812.409912, 3777.610107, 0.010000, + 3873.010010, 3838.810059, 0.010000, 0.010000, + 3900.010010}, + {5430.910156, 0.010000, 0.010000, 5521.809570, + 5467.809570, 0.010000, 5612.709961, 5559.609863, + 5504.709961, 5703.609863, 5651.409668, 5597.409668, + 11336.110352, 5743.209961, 5690.109863, 5635.209473, + 5835.009766, 5782.809570, 5728.809570, 0.010000, + 5875.509766, 5822.409668, 0.010000, 0.010000, + 5916.009766}, + {5885.410156, 0.010000, 0.010000, 5976.310059, + 5926.809570, 0.010000, 6067.209961, 6018.609863, + 5968.209961, 6158.110352, 6110.409668, 6060.909668, + 12258.610352, 6202.209961, 6153.609375, 6103.209473, + 6294.009766, 6246.309570, 6196.809570, 0.010000, + 
6339.009766, 6290.409668, 0.010000, 0.010000, + 6384.009766}, + {5578.509766, 0.010000, 0.010000, 5673.009766, + 5615.410156, 0.010000, 5767.510254, 5710.809570, + 5652.309570, 5862.009766, 5806.209961, 5748.609863, + 11645.710938, 5901.609863, 5844.909668, 5786.409668, + 5997.009766, 5941.209961, 5883.609375, 0.010000, + 6037.509766, 5980.809570, 0.010000, 0.010000, + 6078.009766}, + {6051.009766, 0.010000, 0.010000, 6145.509766, + 6092.410156, 0.010000, 6240.010254, 6187.810059, + 6133.809570, 6334.509766, 6283.209961, 6230.109863, + 12604.208984, 6378.610352, 6326.409668, 6272.410156, + 6474.009766, 6422.709961, 6369.609375, 0.010000, + 6519.009766, 6466.809570, 0.010000, 0.010000, + 6564.009766}, + {5726.109863, 0.010000, 0.010000, 5824.209473, + 5763.009766, 0.010000, 5922.309570, 5862.009766, + 5799.910156, 6020.409668, 5961.010254, 5899.809570, + 11955.309570, 6060.009766, 5999.709961, 5937.609863, + 6159.009766, 6099.609863, 6038.409668, 0.010000, + 6199.509766, 6139.209961, 0.010000, 0.010000, + 6240.009766}, + {6216.609863, 0.010000, 0.010000, 6314.709473, + 6258.009766, 0.010000, 6412.809570, 6357.009766, + 6299.410156, 6510.909668, 6456.010254, 6399.310059, + 12949.809570, 6555.009766, 6499.209961, 6441.609863, + 6654.009766, 6599.110352, 6542.409668, 0.010000, + 6699.009766, 6643.209961, 0.010000, 0.010000, + 6744.009766}, + {5873.709961, 0.010000, 0.010000, 5975.409668, + 5910.609863, 0.010000, 6077.109375, 6013.209473, + 5947.509766, 6178.809570, 6115.809570, 6051.009766, + 12264.910156, 6218.409668, 6154.510254, 6088.809570, + 6321.009766, 6258.009766, 6193.209961, 0.010000, + 6361.509766, 6297.610352, 0.010000, 0.010000, + 6402.009766}, + {6382.209473, 0.010000, 0.010000, 6483.910156, + 6423.609863, 0.010000, 6585.609375, 6526.209473, + 6465.009766, 6687.309570, 6628.809570, 6568.509766, + 13295.410156, 6731.409668, 6672.010254, 6610.810059, + 6834.009766, 6775.509766, 6715.209961, 0.010000, + 6879.009766, 6819.610352, 0.010000, 0.010000, + 6924.009766}, + {4320.009766, 0.010000, 0.010000, 4389.009766, + 4347.609863, 0.010000, 4458.009766, 4417.209961, + 4375.209961, 4527.009766, 4486.809570, 4445.409668, + 8998.809570, 4556.409668, 4515.609863, 4473.609863, + 4626.009766, 4585.809570, 4544.410156, 0.010000, + 4656.009766, 4615.209961, 0.010000, 0.010000, + 4686.009766}, + {4665.009766, 0.010000, 0.010000, 4734.009766, + 4695.609375, 0.010000, 4803.009766, 4765.209961, + 4726.209961, 4872.009766, 4834.809570, 4796.409668, + 9697.809570, 4904.409668, 4866.609863, 4827.609863, + 4974.009766, 4936.809570, 4898.410156, 0.010000, + 5007.009766, 4969.209961, 0.010000, 0.010000, + 5040.009766}, + {2366.110107, 0.010000, 0.010000, 2401.209961, + 2381.409912, 0.010000, 2436.310059, 2416.810059, + 2396.709961, 2471.410156, 2452.209961, 2432.409912, + 4918.509766, 2487.609863, 2468.110107, 2448.010010, + 2523.010010, 2503.810059, 2484.010010, 0.010000, + 2539.510010, 2520.010010, 0.010000, 0.010000, + 2556.010010}, + {2541.610107, 0.010000, 0.010000, 2576.710205, + 2558.409912, 0.010000, 2611.810059, 2593.810059, + 2575.209961, 2646.910156, 2629.209961, 2610.909912, + 5274.009766, 2664.609863, 2646.610107, 2628.010010, + 2700.010010, 2682.310059, 2664.010010, 0.010000, + 2718.010010, 2700.010010, 0.010000, 0.010000, + 2736.010010}}, + + {{1701.320068, 0.020000, 0.020000, 1737.620117, + 1710.620117, 0.020000, 1773.920044, 1747.220093, + 1719.920044, 1810.220093, 1783.820068, 1756.819946, + 3575.719971, 1820.420044, 1793.719971, 1766.420044, + 1857.020142, 1830.619995, 1803.619995, 
0.020000, + 1867.520020, 1840.820068, 0.020000, 0.020000, + 1878.020020}, + {1882.820068, 0.020000, 0.020000, 1919.120117, + 1893.620117, 0.020000, 1955.420044, 1930.220093, + 1904.420044, 1991.720093, 1966.820068, 1941.319946, + 3943.219971, 2003.420044, 1978.219971, 1952.420044, + 2040.020142, 2015.119995, 1989.620117, 0.020000, + 2052.020020, 2026.820068, 0.020000, 0.020000, + 2064.020020}, + {3802.820068, 0.020000, 0.020000, 3876.620117, + 3824.420166, 0.020000, 3950.420166, 3898.820068, + 3846.020020, 4024.220215, 3973.220215, 3921.020020, + 7965.620117, 4047.620117, 3996.020020, 3943.219727, + 4122.020020, 4071.020020, 4018.820068, 0.020000, + 4146.020020, 4094.419922, 0.020000, 0.020000, + 4170.020020}, + {4171.819824, 0.020000, 0.020000, 4245.620117, + 4196.420410, 0.020000, 4319.420410, 4270.819824, + 4221.020020, 4393.220215, 4345.220215, 4296.020020, + 8712.620117, 4419.620605, 4371.020020, 4321.219727, + 4494.020020, 4446.020020, 4396.819824, 0.020000, + 4521.020020, 4472.419922, 0.020000, 0.020000, + 4548.020020}, + {6316.520020, 0.020000, 0.020000, 6429.020020, + 6353.420410, 0.020000, 6541.520508, 6466.819824, + 6390.319824, 6654.020020, 6580.220215, 6504.620117, + 13193.718750, 6693.620605, 6618.919922, 6542.420410, + 6807.020020, 6733.220215, 6657.619629, 0.020000, + 6847.520020, 6772.819824, 0.020000, 0.020000, + 6888.020020}, + {6879.020020, 0.020000, 0.020000, 6991.520020, + 6920.420410, 0.020000, 7104.020508, 7033.820312, + 6961.819824, 7216.520020, 7147.220215, 7076.120117, + 14332.218750, 7260.620605, 7190.420410, 7118.420410, + 7374.020020, 7304.720215, 7233.619629, 0.020000, + 7419.020020, 7348.819824, 0.020000, 0.020000, + 7464.020020}, + {6464.120117, 0.020000, 0.020000, 6580.219727, + 6501.020020, 0.020000, 6696.319824, 6618.020020, + 6537.920410, 6812.419922, 6735.020508, 6655.819824, + 13503.319336, 6852.020020, 6773.720215, 6693.620117, + 6969.020020, 6891.620605, 6812.419922, 0.020000, + 7009.520020, 6931.220215, 0.020000, 0.020000, + 7050.020020}, + {7044.620117, 0.020000, 0.020000, 7160.720215, + 7086.020020, 0.020000, 7276.819824, 7203.020020, + 7127.420410, 7392.919434, 7320.020508, 7245.320312, + 14677.819336, 7437.020020, 7363.220215, 7287.620117, + 7554.020020, 7481.120605, 7406.420410, 0.020000, + 7599.020020, 7525.220215, 0.020000, 0.020000, + 7644.020020}, + {6611.719727, 0.020000, 0.020000, 6731.420410, + 6648.620117, 0.020000, 6851.119629, 6769.219727, + 6685.520020, 6970.819824, 6889.819824, 6807.020020, + 13812.919922, 7010.419922, 6928.520508, 6844.819824, + 7131.020020, 7050.020020, 6967.220215, 0.020000, + 7171.520020, 7089.620605, 0.020000, 0.020000, + 7212.020020}, + {7210.219727, 0.020000, 0.020000, 7329.920410, + 7251.620117, 0.020000, 7449.619629, 7372.220215, + 7293.020020, 7569.319824, 7492.819824, 7414.520020, + 15023.418945, 7613.419434, 7536.020508, 7456.820312, + 7734.020020, 7657.520020, 7579.220215, 0.020000, + 7779.020020, 7701.620605, 0.020000, 0.020000, + 7824.020020}, + {6759.319824, 0.020000, 0.020000, 6882.620117, + 6796.219727, 0.020000, 7005.919922, 6920.420410, + 6833.120117, 7129.220215, 7044.619629, 6958.219727, + 14122.519531, 7168.819824, 7083.319824, 6996.020020, + 7293.020020, 7208.419922, 7122.020508, 0.020000, + 7333.520020, 7248.020020, 0.020000, 0.020000, + 7374.020020}, + {7375.819824, 0.020000, 0.020000, 7499.120117, + 7417.219727, 0.020000, 7622.420410, 7541.420410, + 7458.620117, 7745.720215, 7665.619629, 7583.720215, + 15369.019531, 7789.819824, 7708.819824, 7626.020020, + 7914.020020, 7833.919434, 
7752.020508, 0.020000, + 7959.020020, 7878.020020, 0.020000, 0.020000, + 8004.020020}, + {4982.420410, 0.020000, 0.020000, 5065.819824, + 5010.020020, 0.020000, 5149.220215, 5094.020020, + 5037.619629, 5232.620605, 5178.020020, 5122.220215, + 10381.219727, 5262.020020, 5206.819824, 5150.419922, + 5346.020020, 5291.419922, 5235.620117, 0.020000, + 5376.020020, 5320.819824, 0.020000, 0.020000, + 5406.020020}, + {5399.420410, 0.020000, 0.020000, 5482.820312, + 5430.020020, 0.020000, 5566.220215, 5514.020020, + 5460.619629, 5649.620605, 5598.020020, 5545.220215, + 11224.219727, 5682.020020, 5629.819824, 5576.419922, + 5766.020020, 5714.419922, 5661.620117, 0.020000, + 5799.020020, 5746.819824, 0.020000, 0.020000, + 5832.020020}, + {2733.320068, 0.020000, 0.020000, 2775.620117, + 2748.620117, 0.020000, 2817.920166, 2791.219971, + 2763.919922, 2860.220215, 2833.820068, 2806.820068, + 5681.720215, 2876.420166, 2849.719971, 2822.419922, + 2919.020020, 2892.619873, 2865.620117, 0.020000, + 2935.520020, 2908.820068, 0.020000, 0.020000, + 2952.020020}, + {2944.820068, 0.020000, 0.020000, 2987.120117, + 2961.620117, 0.020000, 3029.420166, 3004.220215, + 2978.419922, 3071.720215, 3046.820068, 3021.320068, + 6109.220215, 3089.420166, 3064.219971, 3038.419922, + 3132.020020, 3107.119873, 3081.620117, 0.020000, + 3150.020020, 3124.820068, 0.020000, 0.020000, + 3168.020020}}, + + {{1924.530029, 0.030000, 0.030000, 1968.030029, + 1933.830078, 0.030000, 2011.530029, 1977.630127, + 1943.130127, 2055.030029, 2021.430054, 1987.230103, + 4050.929932, 2065.230225, 2031.330078, 1996.829956, + 2109.030029, 2075.430176, 2041.229980, 0.030000, + 2119.530029, 2085.630127, 0.030000, 0.030000, + 2130.030029}, + {2142.030029, 0.030000, 0.030000, 2185.530029, + 2152.830078, 0.030000, 2229.030029, 2196.630127, + 2163.630127, 2272.530029, 2240.430176, 2207.729980, + 4490.429688, 2284.230225, 2251.830078, 2218.830078, + 2328.030029, 2295.930176, 2263.229980, 0.030000, + 2340.030029, 2307.629883, 0.030000, 0.030000, + 2352.030029}, + {4321.229980, 0.030000, 0.030000, 4409.429688, + 4342.829590, 0.030000, 4497.629883, 4431.629883, + 4364.430176, 4585.829590, 4520.430176, 4453.829590, + 9060.030273, 4609.229980, 4543.229980, 4476.029785, + 4698.029785, 4632.630371, 4566.029785, 0.030000, + 4722.029785, 4656.029785, 0.030000, 0.030000, + 4746.029785}, + {4762.229980, 0.030000, 0.030000, 4850.429688, + 4786.829590, 0.030000, 4938.629883, 4875.629883, + 4811.430176, 5026.829590, 4964.430176, 4900.829590, + 9951.030273, 5053.229980, 4990.229980, 4926.029785, + 5142.029785, 5079.630371, 5016.029785, 0.030000, + 5169.029785, 5106.029785, 0.030000, 0.030000, + 5196.029785}, + {7202.129883, 0.030000, 0.030000, 7336.229492, + 7239.029785, 0.030000, 7470.329590, 7374.029785, + 7275.930176, 7604.429688, 7509.030273, 7411.829590, + 15051.330078, 7644.029785, 7547.729980, 7449.629883, + 7779.029785, 7683.630371, 7586.430176, 0.030000, + 7819.529785, 7723.229980, 0.030000, 0.030000, + 7860.029785}, + {7872.629883, 0.030000, 0.030000, 8006.729980, + 7914.029785, 0.030000, 8140.829590, 8049.029785, + 7955.430176, 8274.929688, 8184.030273, 8091.330078, + 16405.830078, 8319.030273, 8227.230469, 8133.629883, + 8454.030273, 8363.130859, 8270.430664, 0.030000, + 8499.030273, 8407.230469, 0.030000, 0.030000, + 8544.030273}, + {7349.729492, 0.030000, 0.030000, 7487.430176, + 7386.629883, 0.030000, 7625.129395, 7525.229980, + 7423.529785, 7762.829590, 7663.829590, 7563.029785, + 15360.929688, 7802.429688, 7702.530273, 7600.829590, + 7941.029785, 
7842.029785, 7741.229980, 0.030000, + 7981.529785, 7881.630371, 0.030000, 0.030000, + 8022.029785}, + {8038.229492, 0.030000, 0.030000, 8175.930176, + 8079.629883, 0.030000, 8313.629883, 8218.230469, + 8121.029785, 8451.330078, 8356.830078, 8260.530273, + 16751.427734, 8495.429688, 8400.030273, 8302.831055, + 8634.030273, 8539.530273, 8443.230469, 0.030000, + 8679.030273, 8583.630859, 0.030000, 0.030000, + 8724.030273}, + {7497.329590, 0.030000, 0.030000, 7638.629883, + 7534.229492, 0.030000, 7779.930176, 7676.430176, + 7571.130371, 7921.229980, 7818.629395, 7714.229980, + 15670.530273, 7960.829590, 7857.329590, 7752.029785, + 8103.029785, 8000.429688, 7896.030273, 0.030000, + 8143.529785, 8040.029785, 0.030000, 0.030000, + 8184.029785}, + {8203.830078, 0.030000, 0.030000, 8345.129883, + 8245.229492, 0.030000, 8486.430664, 8387.430664, + 8286.630859, 8627.730469, 8529.629883, 8429.730469, + 17097.029297, 8671.830078, 8572.830078, 8472.030273, + 8814.030273, 8715.930664, 8616.030273, 0.030000, + 8859.030273, 8760.030273, 0.030000, 0.030000, + 8904.030273}, + {7644.930176, 0.030000, 0.030000, 7789.829590, + 7681.829590, 0.030000, 7934.729980, 7827.629883, + 7718.729980, 8079.630371, 7973.430176, 7865.430176, + 15980.130859, 8119.229980, 8012.129395, 7903.229980, + 8265.030273, 8158.830566, 8050.829590, 0.030000, + 8305.530273, 8198.430664, 0.030000, 0.030000, + 8346.030273}, + {8369.430664, 0.030000, 0.030000, 8514.331055, + 8410.830078, 0.030000, 8659.230469, 8556.629883, + 8452.231445, 8804.130859, 8702.430664, 8598.930664, + 17442.628906, 8848.230469, 8745.629883, 8641.230469, + 8994.030273, 8892.331055, 8788.830078, 0.030000, + 9039.030273, 8936.430664, 0.030000, 0.030000, + 9084.030273}, + {5644.829590, 0.030000, 0.030000, 5742.629883, + 5672.430176, 0.030000, 5840.430176, 5770.830078, + 5700.029785, 5938.229980, 5869.229980, 5799.029785, + 11763.630859, 5967.630371, 5898.029785, 5827.229980, + 6066.029785, 5997.029785, 5926.829590, 0.030000, + 6096.029785, 6026.430176, 0.030000, 0.030000, + 6126.029785}, + {6133.829590, 0.030000, 0.030000, 6231.629883, + 6164.430176, 0.030000, 6329.430176, 6262.830078, + 6195.029785, 6427.229980, 6361.229980, 6294.029785, + 12750.630859, 6459.630371, 6393.029785, 6325.229980, + 6558.029785, 6492.029785, 6424.829590, 0.030000, + 6591.029785, 6524.430176, 0.030000, 0.030000, + 6624.029785}, + {3100.530029, 0.030000, 0.030000, 3150.030029, + 3115.830078, 0.030000, 3199.530029, 3165.630127, + 3131.130127, 3249.030029, 3215.430176, 3181.230225, + 6444.930176, 3265.230225, 3231.330078, 3196.830078, + 3315.030029, 3281.430176, 3247.230225, 0.030000, + 3331.530029, 3297.630127, 0.030000, 0.030000, + 3348.030029}, + {3348.030029, 0.030000, 0.030000, 3397.530029, + 3364.830078, 0.030000, 3447.030029, 3414.630127, + 3381.630127, 3496.530029, 3464.430176, 3431.730225, + 6944.430176, 3514.230225, 3481.830078, 3448.830078, + 3564.030029, 3531.930176, 3499.230225, 0.030000, + 3582.030029, 3549.630127, 0.030000, 0.030000, + 3600.030029}}, + + {{2147.739990, 0.040000, 0.040000, 2198.439941, + 2157.040039, 0.040000, 2249.140137, 2208.040039, + 2166.340088, 2299.840088, 2259.040039, 2217.640137, + 4526.140137, 2310.040039, 2268.940186, 2227.240234, + 2361.040039, 2320.240234, 2278.840088, 0.040000, + 2371.540039, 2330.440186, 0.040000, 0.040000, + 2382.040039}, + {2401.239990, 0.040000, 0.040000, 2451.939941, + 2412.040039, 0.040000, 2502.640137, 2463.040039, + 2422.840088, 2553.340088, 2514.040039, 2474.140137, + 5037.640137, 2565.040039, 2525.440186, 2485.240234, + 
2616.040039, 2576.740234, 2536.840088, 0.040000, + 2628.040039, 2588.440186, 0.040000, 0.040000, + 2640.040039}, + {4839.640137, 0.040000, 0.040000, 4942.240234, + 4861.240234, 0.040000, 5044.839844, 4964.439941, + 4882.839844, 5147.440430, 5067.640137, 4986.640137, + 10154.440430, 5170.839844, 5090.440430, 5008.840332, + 5274.040039, 5194.240234, 5113.240234, 0.040000, + 5298.040039, 5217.640625, 0.040000, 0.040000, + 5322.040039}, + {5352.640137, 0.040000, 0.040000, 5455.240234, + 5377.240234, 0.040000, 5557.839844, 5480.439941, + 5401.839844, 5660.440430, 5583.640137, 5505.640137, + 11189.439453, 5686.839844, 5609.440430, 5530.840332, + 5790.040039, 5713.240234, 5635.240234, 0.040000, + 5817.040039, 5739.640625, 0.040000, 0.040000, + 5844.040039}, + {8087.740234, 0.040000, 0.040000, 8243.440430, + 8124.640625, 0.040000, 8399.139648, 8281.240234, + 8161.540039, 8554.840820, 8437.839844, 8319.040039, + 16908.937500, 8594.440430, 8476.540039, 8356.840820, + 8751.040039, 8634.040039, 8515.240234, 0.040000, + 8791.540039, 8673.640625, 0.040000, 0.040000, + 8832.040039}, + {8866.240234, 0.040000, 0.040000, 9021.940430, + 8907.640625, 0.040000, 9177.639648, 9064.240234, + 8949.040039, 9333.340820, 9220.839844, 9106.540039, + 18479.437500, 9377.440430, 9264.040039, 9148.840820, + 9534.040039, 9421.540039, 9307.240234, 0.040000, + 9579.040039, 9465.640625, 0.040000, 0.040000, + 9624.040039}, + {8235.339844, 0.040000, 0.040000, 8394.639648, + 8272.240234, 0.040000, 8553.940430, 8432.440430, + 8309.140625, 8713.240234, 8592.639648, 8470.240234, + 17218.539062, 8752.840820, 8631.339844, 8508.040039, + 8913.040039, 8792.440430, 8670.040039, 0.040000, + 8953.540039, 8832.040039, 0.040000, 0.040000, + 8994.040039}, + {9031.839844, 0.040000, 0.040000, 9191.139648, + 9073.240234, 0.040000, 9350.440430, 9233.440430, + 9114.640625, 9509.740234, 9393.639648, 9275.740234, + 18825.039062, 9553.839844, 9436.839844, 9318.040039, + 9714.040039, 9597.940430, 9480.040039, 0.040000, + 9759.040039, 9642.040039, 0.040000, 0.040000, + 9804.040039}, + {8382.940430, 0.040000, 0.040000, 8545.840820, + 8419.839844, 0.040000, 8708.740234, 8583.639648, + 8456.740234, 8871.640625, 8747.440430, 8621.440430, + 17528.138672, 8911.240234, 8786.139648, 8659.240234, + 9075.040039, 8950.840820, 8824.839844, 0.040000, + 9115.540039, 8990.440430, 0.040000, 0.040000, + 9156.040039}, + {9197.440430, 0.040000, 0.040000, 9360.340820, + 9238.839844, 0.040000, 9523.240234, 9402.639648, + 9280.240234, 9686.140625, 9566.440430, 9444.940430, + 19170.638672, 9730.240234, 9609.639648, 9487.240234, + 9894.040039, 9774.339844, 9652.839844, 0.040000, + 9939.040039, 9818.440430, 0.040000, 0.040000, + 9984.040039}, + {8530.540039, 0.040000, 0.040000, 8697.040039, + 8567.440430, 0.040000, 8863.540039, 8734.840820, + 8604.339844, 9030.040039, 8902.240234, 8772.639648, + 17837.740234, 9069.640625, 8940.940430, 8810.440430, + 9237.040039, 9109.240234, 8979.639648, 0.040000, + 9277.540039, 9148.840820, 0.040000, 0.040000, + 9318.040039}, + {9363.040039, 0.040000, 0.040000, 9529.540039, + 9404.440430, 0.040000, 9696.040039, 9571.840820, + 9445.839844, 9862.540039, 9739.240234, 9614.139648, + 19516.240234, 9906.640625, 9782.440430, 9656.440430, + 10074.040039, 9950.740234, 9825.639648, 0.040000, + 10119.040039, 9994.839844, 0.040000, 0.040000, + 10164.040039}, + {6307.240234, 0.040000, 0.040000, 6419.439941, + 6334.839844, 0.040000, 6531.640137, 6447.640137, + 6362.440430, 6643.839844, 6560.440430, 6475.840332, + 13146.040039, 6673.240234, 6589.240234, 
6504.040039, + 6786.040039, 6702.640625, 6618.040039, 0.040000, + 6816.040039, 6732.040039, 0.040000, 0.040000, + 6846.040039}, + {6868.240234, 0.040000, 0.040000, 6980.439941, + 6898.839844, 0.040000, 7092.640137, 7011.640137, + 6929.440430, 7204.839844, 7124.440430, 7042.840332, + 14277.040039, 7237.240234, 7156.240234, 7074.040039, + 7350.040039, 7269.640625, 7188.040039, 0.040000, + 7383.040039, 7302.040039, 0.040000, 0.040000, + 7416.040039}, + {3467.739990, 0.040000, 0.040000, 3524.439941, + 3483.040039, 0.040000, 3581.140137, 3540.040039, + 3498.340088, 3637.840088, 3597.040039, 3555.640137, + 7208.140137, 3654.040039, 3612.940186, 3571.240234, + 3711.040039, 3670.240234, 3628.840088, 0.040000, + 3727.540039, 3686.440186, 0.040000, 0.040000, + 3744.040039}, + {3751.239990, 0.040000, 0.040000, 3807.939941, + 3768.040039, 0.040000, 3864.640137, 3825.040039, + 3784.840088, 3921.340088, 3882.040039, 3842.140137, + 7779.640137, 3939.040039, 3899.440186, 3859.240234, + 3996.040039, 3956.740234, 3916.840088, 0.040000, + 4014.040039, 3974.440186, 0.040000, 0.040000, + 4032.040039}}}}})); + CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput)); + } + } +} + +} // namespace Aidge diff --git a/unit_tests/operator/Test_ExpandImpl.cpp b/unit_tests/operator/Test_ExpandImpl.cpp index 878c608110eabb824d8a6c0d1ceb0853b3c1449d..ad30457d33307ca595ecddfd3b06d58e118a02d0 100644 --- a/unit_tests/operator/Test_ExpandImpl.cpp +++ b/unit_tests/operator/Test_ExpandImpl.cpp @@ -13,20 +13,20 @@ #include <catch2/catch_test_macros.hpp> -#include "aidge/backend/cpu/data/TensorImpl.hpp" -#include "aidge/backend/cpu/operator/ExpandImpl.hpp" #include "aidge/data/DataType.hpp" #include "aidge/data/Tensor.hpp" #include "aidge/operator/Expand.hpp" #include "aidge/utils/ArrayHelpers.hpp" -using std::shared_ptr; -using namespace Aidge; +namespace Aidge { + +using std::shared_ptr; -void setupTestExpand(shared_ptr<Tensor> inputData, - shared_ptr<Tensor> inputShape, - shared_ptr<Expand_Op> &op) { +static void setupTestExpand(shared_ptr<Tensor> inputData, + shared_ptr<Tensor> inputShape, + shared_ptr<Expand_Op> &op, + Tensor &expectedOutput) { op->getOutput(0)->setDataType(inputData->dataType()); @@ -35,6 +35,9 @@ void setupTestExpand(shared_ptr<Tensor> inputData, inputShape->setBackend("cpu"); op->associateInput(1, inputShape); + + expectedOutput.setBackend("cpu"); + expectedOutput.setDataType(DataType::Int32); } TEST_CASE("[cpu/operator] Expand(forward)", "[Expand][CPU]") { @@ -49,7 +52,7 @@ TEST_CASE("[cpu/operator] Expand(forward)", "[Expand][CPU]") { Array4D<cpptype_t<DataType::Int32>, 1, 3, 4, 2>({{{{{1, 3}, {1, 3}, {1, 3}, {1, 3}}, {{1, 3}, {1, 3}, {1, 3}, {1, 3}}, {{1, 3}, {1, 3}, {1, 3}, {1, 3}}}}}); - setupTestExpand(inputData, inputShape, op); + setupTestExpand(inputData, inputShape, op, expectedOutput); // forwardDims has already been tested in core CHECK(op->forwardDims(true)); @@ -63,7 +66,7 @@ TEST_CASE("[cpu/operator] Expand(forward)", "[Expand][CPU]") { std::make_shared<Tensor>(Array1D<std::int64_t, 2>({2, 3})); Tensor expectedOutput = Array3D<cpptype_t<DataType::Int32>, 2, 2, 3>( {{{{2, 1, 3}, {2, 1, 3}}, {{2, 1, 3}, {2, 1, 3}}}}); - setupTestExpand(inputData, inputShape, op); + setupTestExpand(inputData, inputShape, op,expectedOutput); // forwardDims has already been tested in core CHECK(op->forwardDims(true)); @@ -77,7 +80,7 @@ TEST_CASE("[cpu/operator] Expand(forward)", "[Expand][CPU]") { std::make_shared<Tensor>(Array1D<std::int64_t, 1>({1})); Tensor expectedOutput = 
Array4D<cpptype_t<DataType::Int32>, 2, 1, 3, 1>({{{2, 1, 3}, {2, 1, 3}}}); - setupTestExpand(inputData, inputShape, op); + setupTestExpand(inputData, inputShape, op, expectedOutput); // forwardDims has already been tested in core CHECK(op->forwardDims(true)); @@ -91,7 +94,7 @@ TEST_CASE("[cpu/operator] Expand(forward)", "[Expand][CPU]") { std::make_shared<Tensor>(Array1D<std::int64_t, 3>({2, 1, 1})); Tensor expectedOutput = Array4D<cpptype_t<DataType::Int32>, 1, 2, 3, 1>({{{{2, 1, 3}, {2, 1, 3}}}}); - setupTestExpand(inputData, inputShape, op); + setupTestExpand(inputData, inputShape, op, expectedOutput); // forwardDims has already been tested in core CHECK(op->forwardDims(true)); @@ -101,3 +104,4 @@ TEST_CASE("[cpu/operator] Expand(forward)", "[Expand][CPU]") { SECTION("N-Dim to N-Dim") {} auto inputData = std::shared_ptr<Tensor>(); } +} // namespace Aidge
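
Note (illustrative sketch, not part of the patch): the "Big test to ensure kernel capabilities" section above declares outDataSize{16, 25} for inDataSize{6, 5}, kernelSize{6, 4}, stride{1, 3} and dilation{2, 4}. Those values follow from the usual transposed-convolution size relation with no padding and no output padding, out = (in - 1) * stride + dilation * (kernel - 1) + 1; the helper below reproduces that arithmetic (its name and structure are purely illustrative, not Aidge API).

#include <array>
#include <cstddef>
#include <iostream>

// Expected spatial size of a transposed convolution, assuming no padding and
// no output padding: out = (in - 1) * stride + dilation * (kernel - 1) + 1.
template <std::size_t DIM>
std::array<std::size_t, DIM>
convTransposeOutputDims(const std::array<std::size_t, DIM> &inDims,
                        const std::array<std::size_t, DIM> &kernel,
                        const std::array<std::size_t, DIM> &stride,
                        const std::array<std::size_t, DIM> &dilation) {
    std::array<std::size_t, DIM> out{};
    for (std::size_t d = 0; d < DIM; ++d) {
        out[d] = (inDims[d] - 1) * stride[d] + dilation[d] * (kernel[d] - 1) + 1;
    }
    return out;
}

int main() {
    // Values taken from the test: input 6x5, kernel 6x4, stride {1, 3}, dilation {2, 4}.
    const auto out = convTransposeOutputDims<2>({6, 5}, {6, 4}, {1, 3}, {2, 4});
    std::cout << out[0] << " x " << out[1] << '\n'; // prints "16 x 25", matching outDataSize
    return 0;
}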