diff --git a/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp index e764eecfd6746585a7526b8dc7c2a7295c242285..f772ed77cb8d543cfa43df35502784cb6309a5ec 100644 --- a/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp @@ -1108,13 +1108,12 @@ void ConvImpl3D_cpu_forward_kernel(const array<DimSize_t, 3> &strideDims, //////////////////////////////////////////////////////////////////////// // COMPUTATION - for (DimSize_t batch = 0; batch < iDims[0]; - ++batch, oOffset[0] += oStride[0], iOffset[0] += iStride[0]) { - - oOffset[1] = oOffset[0]; - kOffset[0] = 0; - for (DimSize_t oChannel = 0; oChannel < oDims[1]; - ++oChannel, oOffset[1] += oStride[1], kOffset[0] += kStride[0]) { + for (DimSize_t batch = 0; batch < iDims[0]; ++batch) { + oOffset[0] = batch * oStride[0]; + iOffset[0] = batch * iStride[0]; + for (DimSize_t oChannel = 0; oChannel < oDims[1]; ++oChannel) { + oOffset[1] = oChannel * oStride[1] + oOffset[0]; + kOffset[0] = oChannel * kStride[0]; // Filling given channel with corresponding bias value if (biases != nullptr) { @@ -1124,43 +1123,33 @@ void ConvImpl3D_cpu_forward_kernel(const array<DimSize_t, 3> &strideDims, biasVal); } - iOffset[1] = iOffset[0]; - kOffset[1] = kOffset[0]; - for (DimSize_t iChannel = 0; iChannel < iDims[1]; ++iChannel, - iOffset[1] += iStride[1], - kOffset[1] += kStride[1]) { - - iOffset[2] = iOffset[1]; - oOffset[2] = oOffset[1]; - for (DimSize_t oX = 0; oX < oDims[2]; ++oX, - iOffset[2] += strideDims[0] * iStride[2], - oOffset[2] += oStride[2]) { - - iOffset[3] = iOffset[2]; - oOffset[3] = oOffset[2]; - for (DimSize_t oY = 0; oY < oDims[3]; ++oY, - iOffset[3] += strideDims[1] * iStride[3], - oOffset[3] += oStride[3]) { - - for (DimSize_t iZ = 0, oZ = 0; oZ < oDims[4]; - ++oZ, iZ += strideDims[2]) { + for (DimSize_t iChannel = 0; iChannel < iDims[1]; ++iChannel) { + iOffset[1] = iChannel * iStride[1] + iOffset[0]; + kOffset[1] = iChannel * kStride[1] + kOffset[0]; + + for (DimSize_t oX = 0; oX < oDims[2]; ++oX) { + iOffset[2] = oX * strideDims[0] * iStride[2] + iOffset[1]; + oOffset[2] = oX * oStride[2] + oOffset[1]; + + for (DimSize_t oY = 0; oY < oDims[3]; ++oY) { + iOffset[3] = + oY * strideDims[1] * iStride[3] + iOffset[2]; + oOffset[3] = oY * oStride[3] + oOffset[2]; + + for (DimSize_t oZ = 0; oZ < oDims[4]; ++oZ) { auto oIdx = oOffset[3] + oZ; - auto iIdx = iOffset[3] + iZ; - - kOffset[2] = kOffset[1]; - kDilOffset[0] = 0; - for (DimSize_t kX = 0; kX < kDims[0]; ++kX, - kOffset[2] += kStride[2], - kDilOffset[0] += dilationDims[0] * - iStride[2]) { - - kOffset[3] = kOffset[2]; - kDilOffset[1] = kDilOffset[0]; - for (DimSize_t kY = 0; kY < kDims[1]; - ++kY, - kOffset[3] += kStride[3], - kDilOffset[1] += - dilationDims[1] * iStride[3]) { + auto iIdx = iOffset[3] + oZ * strideDims[2]; + + for (DimSize_t kX = 0; kX < kDims[0]; ++kX) { + kOffset[2] = kX * kStride[2] + kOffset[1]; + kDilOffset[0] = + kX * dilationDims[0] * iStride[2]; + + for (DimSize_t kY = 0; kY < kDims[1]; ++kY) { + kOffset[3] = kY * kStride[3] + kOffset[2]; + kDilOffset[1] = + kY * dilationDims[1] * iStride[3] + + kDilOffset[0]; for (DimSize_t kZ = 0; kZ < kDims[2]; ++kZ) { @@ -1194,8 +1183,8 @@ void ConvImpl3D_cpu_forward_kernel(const array<DimSize_t, 3> &strideDims, * for each weight * for each output * multiply the weight with the associated value - * @note kernel & stride are passed as single integers as they are just arrays - * of length 1 + * @note kernel & stride are passed as single integers as they are just + * arrays of length 1 * @note reminder that kernel dimensions are * {outChannels, inChannels, {kernelDims}} * <=> {oDims[1], iDims[1], kernelDim} @@ -1238,54 +1227,36 @@ void conv3DBackwardInput(const array<DimSize_t, 3> &stride, for (DimSize_t batch = 0; batch < iDims[0]; ++batch, iOffset[0] += iStrides[0], oOffset[0] += oStrides[0]) { - kOffset[0] = 0; - oOffset[1] = oOffset[0]; - for (DimSize_t oChannel = 0; oChannel < oDims[1]; oChannel++, - oOffset[1] += oStrides[1], - kOffset[0] += kStrides[0]) { - - iOffset[1] = iOffset[0]; - kOffset[1] = kOffset[0]; - for (DimSize_t iChannel = 0; iChannel < iDims[1]; ++iChannel, - iOffset[1] += iStrides[1], - kOffset[1] += kStrides[1]) { - - oOffset[2] = oOffset[1]; - iOffset[2] = iOffset[1]; - DimSize_t iX = 0; - for (DimSize_t oX = 0; oX < oDims[2]; ++oX, - iX += stride[0], - oOffset[2] += oStrides[2], - iOffset[2] += stride[0] * iStrides[2]) { - - DimSize_t iY = 0; - oOffset[3] = oOffset[2]; - iOffset[3] = iOffset[2]; - for (DimSize_t oY = 0; oY < oDims[3]; ++oY, - iY += stride[1], - oOffset[3] += oStrides[3], - iOffset[3] += stride[1] * iStrides[3]) { - - DimSize_t iZ = 0; - for (DimSize_t oZ = 0; oZ < oDims[4]; - ++oZ, iZ += stride[2]) { + for (DimSize_t oChannel = 0; oChannel < oDims[1]; oChannel++) { + oOffset[1] = oChannel * oStrides[1] + oOffset[0]; + kOffset[0] = oChannel * kStrides[0]; + + for (DimSize_t iChannel = 0; iChannel < iDims[1]; ++iChannel) { + iOffset[1] = iChannel * iStrides[1] + iOffset[0]; + kOffset[1] = iChannel * kStrides[1] + kOffset[0]; + + for (DimSize_t oX = 0; oX < oDims[2]; ++oX) { + oOffset[2] = oX * oStrides[2] + oOffset[1]; + iOffset[2] = oX * stride[0] * iStrides[2] + iOffset[1]; + + for (DimSize_t oY = 0; oY < oDims[3]; ++oY) { + oOffset[3] = oY * oStrides[3] + oOffset[2]; + iOffset[3] = oY * stride[1] * iStrides[3] + iOffset[2]; + + for (DimSize_t oZ = 0; oZ < oDims[4]; ++oZ) { auto oIdx = oOffset[3] + oZ; - auto iIdx = iOffset[3] + iZ; - - iDilkernelOffset[0] = 0; - kOffset[2] = kOffset[1]; - for (DimSize_t kX = 0; kX < kDims[0]; ++kX, - iDilkernelOffset[0] += dilation[0] * - iStrides[2], - kOffset[2] += kStrides[2]) { - - kOffset[3] = kOffset[2]; - iDilkernelOffset[1] = iDilkernelOffset[0]; - for (DimSize_t kY = 0; kY < kDims[1]; - ++kY, - kOffset[3] += kStrides[3], - iDilkernelOffset[1] += - dilation[1] * iStrides[3]) { + auto iIdx = iOffset[3] + oZ * stride[2]; + + for (DimSize_t kX = 0; kX < kDims[0]; ++kX) { + kOffset[2] = kX * kStrides[2] + kOffset[1]; + iDilkernelOffset[0] = + kX * dilation[0] * iStrides[2]; + + for (DimSize_t kY = 0; kY < kDims[1]; ++kY) { + kOffset[3] = kY * kStrides[3] + kOffset[2]; + iDilkernelOffset[1] = + kY * dilation[1] * iStrides[3] + + iDilkernelOffset[0]; for (DimSize_t kZ = 0; kZ < kDims[2]; ++kZ) { @@ -1325,9 +1296,10 @@ void conv3DBackwardInput(const array<DimSize_t, 3> &stride, * @param[in] oDims output data dimmensions * @param[in] oStrides nb element in each dimension of output tensor * @param[in] oGrad gradients of output data - * @param[in] kDim dimensions of kernel (not taking in count In/OutChannels) - * @param[in] kStrides nb element in each dimension of kernel tensor (taking in - * count In/OutChannels) + * @param[in] kDim dimensions of kernel (not taking in count + * In/OutChannels) + * @param[in] kStrides nb element in each dimension of kernel tensor + * (taking in count In/OutChannels) * @param[in] stride attribute of the convolution operator * @param[in] dilation attribute of the convolution operator * @param[inout] weightsGrad gradients of the kernel weights @@ -1444,17 +1416,14 @@ static void conv3DBackwardBias(const array<DimSize_t, 5> &oDims, for (DimSize_t batchIdx = 0; batchIdx < oDims[0]; ++batchIdx) { oOffsets[0] = batchIdx * oStrides[0]; - oOffsets[1] = oOffsets[0]; - for (DimSize_t oChannel = 0; oChannel < oDims[1]; - ++oChannel, oOffsets[1] += oStrides[1]) { + for (DimSize_t oChannel = 0; oChannel < oDims[1]; ++oChannel) { + oOffsets[1] = oChannel * oStrides[1] + oOffsets[0]; - oOffsets[2] = oOffsets[1]; - for (DimSize_t oX = 0; oX < oDims[2]; - ++oX, oOffsets[2] += oStrides[2]) { + for (DimSize_t oX = 0; oX < oDims[2]; ++oX) { + oOffsets[2] = oX * oStrides[2] + oOffsets[1]; - oOffsets[3] = oOffsets[2]; - for (DimSize_t oY = 0; oY < oDims[3]; - ++oY, oOffsets[3] += oStrides[3]) { + for (DimSize_t oY = 0; oY < oDims[3]; ++oY) { + oOffsets[3] = oY * oStrides[3] + oOffsets[2]; for (DimSize_t oZ = 0; oZ < oDims[4]; ++oZ) { biasesGrad[oChannel] += oGrad[oOffsets[3] + oZ]; } diff --git a/include/aidge/backend/cpu/operator/ConvTransposeImpl.hpp b/include/aidge/backend/cpu/operator/ConvTransposeImpl.hpp index 7604a96a18e7be44f4c2e8970a0b60b1c4ad918b..d47636ef8112f2905583e92be4ccbd9710102bde 100644 --- a/include/aidge/backend/cpu/operator/ConvTransposeImpl.hpp +++ b/include/aidge/backend/cpu/operator/ConvTransposeImpl.hpp @@ -51,9 +51,24 @@ using ConvTransposeImpl2D_cpu = const void *, void *)>; +using ConvTranspose3D_Op = ConvTranspose_Op<3>; +using ConvTransposeImpl3D_cpu = + OperatorImpl_cpu<ConvTranspose3D_Op, + void(const array<DimSize_t, 3> &, + const array<DimSize_t, 3> &, + const array<DimSize_t, 3> &, + const array<DimSize_t, 5> &, + const array<DimSize_t, 5> &, + const void *, + const void *, + const void *, + void *)>; + // Implementation entry point registration to Operator REGISTRAR(ConvTranspose1D_Op, "cpu", ConvTransposeImpl1D_cpu::create); REGISTRAR(ConvTranspose2D_Op, "cpu", ConvTransposeImpl2D_cpu::create); +REGISTRAR(ConvTranspose3D_Op, "cpu", ConvTransposeImpl3D_cpu::create); + } // namespace Aidge #endif /* AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_H_ */ diff --git a/include/aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp index e11dd2625ae1645a8e7c5482b1635b85fb475b06..a734add2acb612d86ceaa1d09514da5a727c7ce4 100644 --- a/include/aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/ConvTransposeImpl_kernels.hpp @@ -300,6 +300,149 @@ REGISTRAR( ConvTransposeImpl2D_cpu_forward_kernel<double, double, double, double>, nullptr}); +//////////////////////////////////////////////////////// +//////////////////////////////////////////////////////// +// 3D +//////////////////////////////////////////////////////// +//////////////////////////////////////////////////////// + +/** + * @brief performs forward bias operation for convtranspose operator + * + * @tparam B Bias data type. + * @tparam O Output data type. + * @param[in] bias bias values + * @param[in] oDims dimensions of the output + * @param[in] oStrides nb of elements contained per dimension of the output + * @param[out] output + */ +template <class B, class O> +static void convTranspose3DForwardBias(const B *biases, + const array<DimSize_t, 5> &oDims, + const array<DimSize_t, 4> &oStrides, + O *output) { + array<DimSize_t, 2> outOffsets{0, 0}; + + for (DimSize_t batch = 0; batch < oDims[0]; ++batch) { + outOffsets[0] = batch * oStrides[0]; + + for (DimSize_t outCh = 0; outCh < oDims[1]; ++outCh) { + outOffsets[1] = outCh * oStrides[1] + outOffsets[0]; + // If bias = nullptr, set B(0) + B biasVal = (biases != nullptr) ? biases[outCh] : B(0); + std::fill(output + outOffsets[1], + (output + outOffsets[1]) + oStrides[1], + biasVal); + } + } +} + +/** + * @brief forward kernel for convtranspose + * @note ConvTranspose forward is simply convolution backward kernel. + * Check convolution functions for more in-depth details on how the + subfunctions are built. + * @tparam I Input data type. + * @tparam W Weight data type. + * @tparam B Bias data type. + * @tparam O Output data type. + * @param[in] stride stride parameter of the convTranspose operator + * @param[in] dilation dilation parameter of the convTranspose operator + * @param[in] inputDims input dimensions + * @param[in] outputDims output tensor dimensions + * @param[in] oStrides nb of elements contained per dimension of the output + * @param[in] input_ values + * @param[in] weight_ values + * @param[in] biases_ values + * @param[out] output + */ +template <class I, class W, class B, class O> +void ConvTransposeImpl3D_cpu_forward_kernel( + const array<DimSize_t, 3> &stride, + const array<DimSize_t, 3> &dilation, + const array<DimSize_t, 3> &kernelDims, + const array<DimSize_t, 5> &inputDims, + const array<DimSize_t, 5> &outputDims, + const void *input_, + const void *weights_, + const void *biases_, + void *output_) { + + auto input = static_cast<const I *>(input_); + auto weights = static_cast<const W *>(weights_); + auto output = static_cast<O *>(output_); + + // {channel_stride, dim0_stride, dim1_stride} + const array<DimSize_t, 4> inputStrides{ + inputDims[1] * inputDims[2] * inputDims[3] * inputDims[4], + inputDims[2] * inputDims[3] * inputDims[4], + inputDims[3] * inputDims[4], + inputDims[4]}; + + // {channel_stride, dim0_stride, dim1_stride} + const array<DimSize_t, 4> outputStrides{ + outputDims[1] * outputDims[2] * outputDims[3] * outputDims[4], + outputDims[2] * outputDims[3] * outputDims[4], + outputDims[3] * outputDims[4], + outputDims[4]}; + + // NOTE: kernel dims = {inChannels, outChannels, kernelDims[0], + // kernelDims[1]} + const array<DimSize_t, 4> kernelStrides{ + outputDims[1] * kernelDims[0] * kernelDims[1] * kernelDims[2], + kernelDims[0] * kernelDims[1] * kernelDims[2], + kernelDims[1] * kernelDims[2], + kernelDims[2]}; + + if (biases_ != nullptr) { + auto biases = static_cast<const B *>(biases_); + convTranspose3DForwardBias(biases, outputDims, outputStrides, output); + } + + conv3DBackwardInput(stride, + dilation, + kernelDims, + kernelStrides, + weights, + inputDims, + inputStrides, + input, + outputDims, + outputStrides, + output); +} + +REGISTRAR(ConvTransposeImpl3D_cpu, + {{DataType::Any, DataFormat::NCHW}, + {DataType::Int32, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvTransposeImpl3D_cpu_forward_kernel<std::int32_t, + std::int32_t, + std::int32_t, + std::int32_t>, + nullptr}); +REGISTRAR(ConvTransposeImpl3D_cpu, + {{DataType::Any, DataFormat::NCHW}, + {DataType::Float16, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvTransposeImpl3D_cpu_forward_kernel<half_float::half, + half_float::half, + half_float::half, + half_float::half>, + nullptr}); +REGISTRAR(ConvTransposeImpl3D_cpu, + {{DataType::Any, DataFormat::NCHW}, + {DataType::Float32, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvTransposeImpl3D_cpu_forward_kernel<float, float, float, float>, + nullptr}); +REGISTRAR( + ConvTransposeImpl3D_cpu, + {{DataType::Any, DataFormat::NCHW}, {DataType::Float64, DataFormat::NCHW}}, + {ProdConso::inPlaceModel, + ConvTransposeImpl3D_cpu_forward_kernel<double, double, double, double>, + nullptr}); + } // namespace Aidge #endif /* AIDGE_CPU_OPERATOR_CONVTRANSPOSEIMPL_KERNELS_H_ */ diff --git a/src/operator/ConvTransposeImpl.cpp b/src/operator/ConvTransposeImpl.cpp index d1135cc92dd3c68746b9dcf80739f4f65acdad2e..4f6a8f62be6cc14303419ce6cfa89b3065b01569 100644 --- a/src/operator/ConvTransposeImpl.cpp +++ b/src/operator/ConvTransposeImpl.cpp @@ -89,3 +89,42 @@ template <> void Aidge::ConvTransposeImpl2D_cpu::backward() { "Backward not yet implemented for Conv_Op<2> on backend cpu"); } +template <> void Aidge::ConvTransposeImpl3D_cpu::forward() { + const auto &op = static_cast<const ConvTranspose_Op<3> &>(mOp); + + AIDGE_ASSERT(op.getInput(0), "{}: missing data input (#0).", op.type()); + AIDGE_ASSERT(op.getInput(1), "{}: missing bias input (#1).", op.type()); + AIDGE_ASSERT(op.getInput(2), "{}: missing weight input (#1).", op.type()); + + std::shared_ptr<Tensor> inputDataFallback, inputWeightFallback, + inputBiasFallback; + const auto &inputData = + op.getInput(0)->refCastFrom(inputDataFallback, *op.getOutput(0)); + const auto &inputWeight = + op.getInput(1)->refCastFrom(inputWeightFallback, *op.getOutput(0)); + const auto &inputBias = + (op.getInput(2)) + ? op.getInput(2)->refCastFrom(inputBiasFallback, *op.getOutput(0)) + : Tensor(); + + // Call kernel + const auto impl = Registrar<ConvTransposeImpl3D_cpu>::create( + getBestMatch(getRequiredSpec())); + + impl.forward(op.strideDims(), + op.dilationDims(), + op.kernelDims(), + op.getInput(0)->template dims<5>(), + op.getOutput(0)->template dims<5>(), + inputData.getImpl()->hostPtr(), + inputWeight.getImpl()->hostPtr(), + op.getInput(2) ? inputBias.getImpl()->hostPtr() : nullptr, + op.getOutput(0)->getImpl()->rawPtr()); +} + +template <> void Aidge::ConvTransposeImpl3D_cpu::backward() { + AIDGE_THROW_OR_ABORT( + std::runtime_error, + "Backward not yet implemented for Conv_Op<2> on backend cpu"); +} + diff --git a/unit_tests/operator/Test_ConvImpl.cpp b/unit_tests/operator/Test_ConvImpl.cpp index a47b315bd89ee7e9f054311dc88c04767e518c0a..7b14546351e1eb61af1063a5b98aa16fcebdd029 100644 --- a/unit_tests/operator/Test_ConvImpl.cpp +++ b/unit_tests/operator/Test_ConvImpl.cpp @@ -21,8 +21,8 @@ #include "aidge/filler/Filler.hpp" #include "aidge/graph/Node.hpp" #include "aidge/operator/Conv.hpp" -#include "aidge/utils/TensorUtils.hpp" #include "aidge/operator/Pad.hpp" +#include "aidge/utils/TensorUtils.hpp" namespace Aidge { diff --git a/unit_tests/operator/Test_ConvTranspose.cpp b/unit_tests/operator/Test_ConvTranspose.cpp index 6e889e809e0a05d551829bd15fda9cc651068465..7bb87835a3d9210b7f2f6bce682df60657d049a7 100644 --- a/unit_tests/operator/Test_ConvTranspose.cpp +++ b/unit_tests/operator/Test_ConvTranspose.cpp @@ -2293,6 +2293,1332 @@ TEST_CASE("[cpu/operator] ConvTranspose(forward)", "[ConvTranspose][CPU]") { CHECK(approxEq<float, float>(*op->getOutput(0), *expectedOutput)); } } + SECTION("3D") { + constexpr DimSize_t DIM = 3; + SECTION("Big test to ensure kernel capabilities") { + constexpr DimSize_t batchSize = 1; + constexpr DimSize_t inChannels = 3; + constexpr DimSize_t outChannels = 2; + + constexpr std::array<DimSize_t, DIM> kernelSize{1, 2, 3}; + + constexpr std::array<DimSize_t, DIM> inDataSize{4, 4, 5}; + constexpr std::array<DimSize_t, DIM> outDataSize{4, 10, 15}; + + constexpr std::array<DimSize_t, DIM> stride{1, 2, 3}; + constexpr std::array<DimSize_t, DIM> dilation{2, 3, 1}; + + auto input = std::make_shared<Tensor>(Array5D<float, + batchSize, + inChannels, + inDataSize[0], + inDataSize[1], + inDataSize[2]>( + {{{{{{1., 2., 3., 4., 5.}, + {6., 7., 8., 9., 10.}, + {11., 12., 13., 14., 15.}, + {16., 17., 18., 19., 20.}}, + + {{21., 22., 23., 24., 25.}, + {26., 27., 28., 29., 30.}, + {31., 32., 33., 34., 35.}, + {36., 37., 38., 39., 40.}}, + + {{41., 42., 43., 44., 45.}, + {46., 47., 48., 49., 50.}, + {51., 52., 53., 54., 55.}, + {56., 57., 58., 59., 60.}}, + + {{61., 62., 63., 64., 65.}, + {66., 67., 68., 69., 70.}, + {71., 72., 73., 74., 75.}, + {76., 77., 78., 79., 80.}}}, + + {{{81., 82., 83., 84., 85.}, + {86., 87., 88., 89., 90.}, + {91., 92., 93., 94., 95.}, + {96., 97., 98., 99., 100.}}, + + {{101., 102., 103., 104., 105.}, + {106., 107., 108., 109., 110.}, + {111., 112., 113., 114., 115.}, + {116., 117., 118., 119., 120.}}, + + {{121., 122., 123., 124., 125.}, + {126., 127., 128., 129., 130.}, + {131., 132., 133., 134., 135.}, + {136., 137., 138., 139., 140.}}, + + {{141., 142., 143., 144., 145.}, + {146., 147., 148., 149., 150.}, + {151., 152., 153., 154., 155.}, + {156., 157., 158., 159., 160.}}}, + + {{{161., 162., 163., 164., 165.}, + {166., 167., 168., 169., 170.}, + {171., 172., 173., 174., 175.}, + {176., 177., 178., 179., 180.}}, + + {{181., 182., 183., 184., 185.}, + {186., 187., 188., 189., 190.}, + {191., 192., 193., 194., 195.}, + {196., 197., 198., 199., 200.}}, + + {{201., 202., 203., 204., 205.}, + {206., 207., 208., 209., 210.}, + {211., 212., 213., 214., 215.}, + {216., 217., 218., 219., 220.}}, + + {{221., 222., 223., 224., 225.}, + {226., 227., 228., 229., 230.}, + {231., 232., 233., 234., 235.}, + {236., 237., 238., 239., 240.}}}}}})); + + auto weights = std::make_shared<Tensor>(Array5D<float, + inChannels, + outChannels, + kernelSize[0], + kernelSize[1], + kernelSize[2]>( + {{{{{{0.1, 0.2, 0.3}, {0.4, 0.5, 0.6}}}, + + {{{0.7, 0.8, 0.9}, {1.0, 1.1, 1.2}}}}, + + {{{{1.3, 1.4, 1.5}, {1.6, 1.7, 1.8}}}, + + {{{1.9, 2.0, 2.1}, {2.2, 2.3, 2.4}}}}}})); + + auto biases = std::make_shared<Tensor>( + Array1D<float, outChannels>({{0.01, 0.02}})); + + auto op = setupTestConvTranspose<DIM>(batchSize, + inChannels, + outChannels, + kernelSize, + inDataSize, + stride, + dilation, + input, + weights, + biases); + + REQUIRE_NOTHROW(op->forward()); + + auto expectedOutput = std::make_shared<Tensor>( + Array5D<float, + batchSize, + outChannels, + outDataSize[0], + outDataSize[1], + outDataSize[2]>({{{{{{507.910034, + 532.210022, + 556.510010, + 511.809998, + 536.410034, + 561.010010, + 515.710022, + 540.610046, + 565.510010, + 519.609985, + 544.810059, + 570.010010, + 523.510010, + 549.010010, + 574.510010}, + {0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000}, + {527.410034, + 553.210022, + 579.010010, + 531.309998, + 557.410034, + 583.510010, + 535.210022, + 561.610046, + 588.010010, + 539.109985, + 565.810059, + 592.510010, + 543.010010, + 570.010010, + 597.010010}, + {580.809998, + 605.110046, + 629.410034, + 585.609985, + 610.210022, + 634.809998, + 590.409973, + 615.310059, + 640.210022, + 595.210022, + 620.410034, + 645.609985, + 600.010010, + 625.510010, + 651.010010}, + {546.910034, + 574.210022, + 601.510010, + 550.809998, + 578.410034, + 606.010010, + 554.710022, + 582.610046, + 610.510010, + 558.609985, + 586.810059, + 615.010010, + 562.510010, + 591.010010, + 619.510010}, + {604.809998, + 630.610046, + 656.410034, + 609.609985, + 635.710022, + 661.809998, + 614.409973, + 640.810059, + 667.210022, + 619.210022, + 645.910034, + 672.609985, + 624.010010, + 651.010010, + 678.010010}, + {566.410034, + 595.210022, + 624.010010, + 570.309998, + 599.410034, + 628.510010, + 574.210022, + 603.610046, + 633.010010, + 578.109985, + 607.810059, + 637.510010, + 582.010010, + 612.010010, + 642.010010}, + {628.809998, + 656.110046, + 683.410034, + 633.609985, + 661.210022, + 688.809998, + 638.409973, + 666.310059, + 694.210022, + 643.210022, + 671.410034, + 699.609985, + 648.010010, + 676.510010, + 705.010010}, + {0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000}, + {652.809998, + 681.610046, + 710.410034, + 657.609985, + 686.710022, + 715.809998, + 662.409973, + 691.810059, + 721.210022, + 667.210022, + 696.910034, + 726.609985, + 672.010010, + 702.010010, + 732.010010}}, + + {{585.910034, + 616.210022, + 646.510010, + 589.809998, + 620.410034, + 651.010010, + 593.710022, + 624.610046, + 655.510010, + 597.609985, + 628.810059, + 660.010010, + 601.510010, + 633.010010, + 664.510010}, + {0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000}, + {605.410034, + 637.210022, + 669.010010, + 609.309998, + 641.410034, + 673.510010, + 613.210022, + 645.610046, + 678.010010, + 617.109985, + 649.810059, + 682.510010, + 621.010010, + 654.010010, + 687.010010}, + {676.809998, + 707.110046, + 737.410034, + 681.609985, + 712.210022, + 742.809998, + 686.409973, + 717.310059, + 748.210022, + 691.210022, + 722.410034, + 753.609985, + 696.010010, + 727.510010, + 759.010010}, + {624.910034, + 658.210022, + 691.510010, + 628.809998, + 662.410034, + 696.010010, + 632.710022, + 666.610046, + 700.510010, + 636.609985, + 670.810059, + 705.010010, + 640.510010, + 675.010010, + 709.510010}, + {700.809998, + 732.610046, + 764.410034, + 705.609985, + 737.710022, + 769.809998, + 710.409973, + 742.810059, + 775.210022, + 715.210022, + 747.910034, + 780.609985, + 720.010010, + 753.010010, + 786.010010}, + {644.410034, + 679.210022, + 714.010010, + 648.309998, + 683.410034, + 718.510010, + 652.210022, + 687.610046, + 723.010010, + 656.109985, + 691.810059, + 727.510010, + 660.010010, + 696.010010, + 732.010010}, + {724.809998, + 758.110046, + 791.410034, + 729.609985, + 763.210022, + 796.809998, + 734.409973, + 768.310059, + 802.210022, + 739.210022, + 773.410034, + 807.609985, + 744.010010, + 778.510010, + 813.010010}, + {0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000}, + {748.809998, + 783.610046, + 818.410034, + 753.609985, + 788.710022, + 823.809998, + 758.409973, + 793.810059, + 829.210022, + 763.210022, + 798.910034, + 834.609985, + 768.010010, + 804.010010, + 840.010010}}, + + {{663.910034, + 700.210022, + 736.510010, + 667.809998, + 704.410034, + 741.010010, + 671.710022, + 708.610046, + 745.510010, + 675.609985, + 712.810059, + 750.010010, + 679.510010, + 717.010010, + 754.510010}, + {0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000}, + {683.410034, + 721.210022, + 759.010010, + 687.309998, + 725.410034, + 763.510010, + 691.210022, + 729.610046, + 768.010010, + 695.109985, + 733.810059, + 772.510010, + 699.010010, + 738.010010, + 777.010010}, + {772.809998, + 809.110046, + 845.410034, + 777.609985, + 814.210022, + 850.809998, + 782.409973, + 819.310059, + 856.210022, + 787.210022, + 824.410034, + 861.609985, + 792.010010, + 829.510010, + 867.010010}, + {702.910034, + 742.210022, + 781.510010, + 706.809998, + 746.410034, + 786.010010, + 710.710022, + 750.610046, + 790.510010, + 714.609985, + 754.810059, + 795.010010, + 718.510010, + 759.010071, + 799.510010}, + {796.809998, + 834.610046, + 872.410034, + 801.609985, + 839.710022, + 877.810059, + 806.409973, + 844.810059, + 883.210022, + 811.210022, + 849.910034, + 888.609985, + 816.010010, + 855.010010, + 894.010010}, + {722.410034, + 763.210022, + 804.010010, + 726.309998, + 767.410034, + 808.510010, + 730.210022, + 771.610046, + 813.010010, + 734.109985, + 775.810059, + 817.510010, + 738.010010, + 780.010071, + 822.010010}, + {820.809998, + 860.110046, + 899.410034, + 825.609985, + 865.210022, + 904.810059, + 830.409973, + 870.310059, + 910.210022, + 835.210022, + 875.410034, + 915.609985, + 840.010010, + 880.510010, + 921.010010}, + {0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000}, + {844.809998, + 885.610046, + 926.410034, + 849.609985, + 890.710022, + 931.810059, + 854.409973, + 895.810059, + 937.210022, + 859.210022, + 900.910034, + 942.609985, + 864.010010, + 906.010010, + 948.010010}}, + + {{741.910034, + 784.210022, + 826.510010, + 745.809998, + 788.410034, + 831.010010, + 749.710022, + 792.610046, + 835.510010, + 753.609985, + 796.810059, + 840.010010, + 757.510010, + 801.010071, + 844.510010}, + {0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000}, + {761.410034, + 805.210022, + 849.010010, + 765.310059, + 809.410034, + 853.510010, + 769.210022, + 813.610046, + 858.010010, + 773.109985, + 817.810059, + 862.510010, + 777.010010, + 822.010071, + 867.010010}, + {868.809998, + 911.110046, + 953.410034, + 873.609985, + 916.210022, + 958.810059, + 878.409973, + 921.310059, + 964.210022, + 883.210022, + 926.410034, + 969.609985, + 888.010010, + 931.510010, + 975.010010}, + {780.910034, + 826.210022, + 871.510010, + 784.810059, + 830.410034, + 876.010010, + 788.710022, + 834.610046, + 880.510010, + 792.609985, + 838.810059, + 885.010010, + 796.510010, + 843.010071, + 889.510010}, + {892.809998, + 936.610046, + 980.410034, + 897.609985, + 941.710022, + 985.810059, + 902.409973, + 946.810059, + 991.210022, + 907.210022, + 951.910034, + 996.609985, + 912.010010, + 957.010010, + 1002.010010}, + {800.410034, + 847.210022, + 894.010010, + 804.310059, + 851.410034, + 898.510010, + 808.210022, + 855.610046, + 903.010010, + 812.109985, + 859.810059, + 907.510010, + 816.010010, + 864.010071, + 912.010010}, + {916.809998, + 962.110046, + 1007.410034, + 921.609985, + 967.210022, + 1012.810059, + 926.409973, + 972.310059, + 1018.210022, + 931.210022, + 977.410034, + 1023.609985, + 936.010010, + 982.510010, + 1029.010010}, + {0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000, + 0.010000}, + {940.809998, + 987.610046, + 1034.410034, + 945.609985, + 992.710022, + 1039.810059, + 950.409973, + 997.810059, + 1045.209961, + 955.210022, + 1002.910034, + 1050.609985, + 960.010010, + 1008.010010, + 1056.010010}}}, + + {{{653.720032, + 678.020020, + 702.320007, + 659.420044, + 684.020020, + 708.620056, + 665.120056, + 690.020020, + 714.920044, + 670.820007, + 696.020020, + 721.220032, + 676.520020, + 702.020020, + 727.520020}, + {0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000}, + {682.220032, + 708.020020, + 733.820007, + 687.920044, + 714.020020, + 740.120056, + 693.620056, + 720.020020, + 746.420044, + 699.320007, + 726.020020, + 752.720032, + 705.020020, + 732.020020, + 759.020020}, + {726.620056, + 750.920044, + 775.220032, + 733.220032, + 757.820007, + 782.420044, + 739.820068, + 764.720032, + 789.620056, + 746.420044, + 771.619995, + 796.820068, + 753.020020, + 778.520020, + 804.020081}, + {710.720032, + 738.020020, + 765.320007, + 716.420044, + 744.020020, + 771.620056, + 722.120056, + 750.020020, + 777.920044, + 727.820068, + 756.020020, + 784.220032, + 733.520020, + 762.020020, + 790.520020}, + {759.620056, + 785.420044, + 811.220032, + 766.220032, + 792.320007, + 818.420044, + 772.820068, + 799.220032, + 825.620056, + 779.420044, + 806.119995, + 832.820068, + 786.020020, + 813.020020, + 840.020081}, + {739.220032, + 768.020020, + 796.820007, + 744.920044, + 774.020020, + 803.120056, + 750.620056, + 780.020020, + 809.420044, + 756.320068, + 786.020020, + 815.720032, + 762.020020, + 792.020020, + 822.020020}, + {792.620056, + 819.920044, + 847.220032, + 799.220032, + 826.820007, + 854.420044, + 805.820068, + 833.720032, + 861.620056, + 812.420044, + 840.619995, + 868.820068, + 819.020020, + 847.520020, + 876.020081}, + {0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000}, + {825.620056, + 854.420044, + 883.220032, + 832.220032, + 861.320007, + 890.420044, + 838.820068, + 868.220032, + 897.620056, + 845.420044, + 875.119995, + 904.820068, + 852.020020, + 882.020020, + 912.020020}}, + + {{767.720032, + 798.020020, + 828.320007, + 773.420044, + 804.020020, + 834.620056, + 779.120056, + 810.020020, + 840.920044, + 784.820068, + 816.020020, + 847.220032, + 790.520020, + 822.020020, + 853.520020}, + {0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000}, + {796.220032, + 828.020020, + 859.820007, + 801.920044, + 834.020020, + 866.120056, + 807.620056, + 840.020020, + 872.420044, + 813.320068, + 846.020020, + 878.720032, + 819.020020, + 852.020020, + 885.020020}, + {858.620056, + 888.920044, + 919.220032, + 865.220032, + 895.820007, + 926.420044, + 871.820068, + 902.719971, + 933.620056, + 878.420044, + 909.619995, + 940.820068, + 885.020020, + 916.520020, + 948.020020}, + {824.720032, + 858.020020, + 891.320068, + 830.420044, + 864.020020, + 897.620056, + 836.120056, + 870.020020, + 903.920044, + 841.820068, + 876.020020, + 910.220032, + 847.520020, + 882.020020, + 916.520020}, + {891.620056, + 923.420044, + 955.220032, + 898.220032, + 930.320007, + 962.420044, + 904.820068, + 937.219971, + 969.620056, + 911.420044, + 944.119995, + 976.820068, + 918.020020, + 951.020020, + 984.020020}, + {853.220032, + 888.020020, + 922.820068, + 858.920044, + 894.020020, + 929.120056, + 864.620056, + 900.020020, + 935.420044, + 870.320068, + 906.020020, + 941.720032, + 876.020020, + 912.020020, + 948.020020}, + {924.620056, + 957.920044, + 991.220032, + 931.220032, + 964.820007, + 998.420044, + 937.820068, + 971.719971, + 1005.620056, + 944.420044, + 978.619995, + 1012.820068, + 951.020020, + 985.520020, + 1020.020020}, + {0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000}, + {957.620056, + 992.420044, + 1027.220093, + 964.220032, + 999.320007, + 1034.420044, + 970.820068, + 1006.219971, + 1041.620117, + 977.420044, + 1013.119995, + 1048.820068, + 984.020020, + 1020.020020, + 1056.020020}}, + + {{881.720032, + 918.020020, + 954.320068, + 887.420044, + 924.020020, + 960.620056, + 893.120056, + 930.020020, + 966.920044, + 898.820068, + 936.020020, + 973.220032, + 904.520020, + 942.020020, + 979.520020}, + {0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000}, + {910.220032, + 948.020020, + 985.820068, + 915.920044, + 954.020020, + 992.120056, + 921.620056, + 960.020020, + 998.420044, + 927.320068, + 966.020020, + 1004.720032, + 933.020020, + 972.020020, + 1011.020020}, + {990.620056, + 1026.920044, + 1063.220093, + 997.220032, + 1033.820068, + 1070.420044, + 1003.820068, + 1040.719971, + 1077.620117, + 1010.420044, + 1047.619995, + 1084.820068, + 1017.020020, + 1054.520020, + 1092.020020}, + {938.720032, + 978.020020, + 1017.320068, + 944.420044, + 984.020020, + 1023.620056, + 950.120056, + 990.020020, + 1029.920044, + 955.820068, + 996.020020, + 1036.220093, + 961.520081, + 1002.020020, + 1042.520020}, + {1023.620056, + 1061.420044, + 1099.220093, + 1030.220093, + 1068.320068, + 1106.420044, + 1036.820068, + 1075.219971, + 1113.620117, + 1043.420044, + 1082.119995, + 1120.820068, + 1050.020020, + 1089.020020, + 1128.020020}, + {967.220032, + 1008.020020, + 1048.820068, + 972.920044, + 1014.020020, + 1055.119995, + 978.620056, + 1020.020020, + 1061.420044, + 984.320068, + 1026.020020, + 1067.720093, + 990.020081, + 1032.020020, + 1074.020020}, + {1056.619995, + 1095.920044, + 1135.220093, + 1063.220093, + 1102.820068, + 1142.420044, + 1069.820068, + 1109.719971, + 1149.620117, + 1076.420044, + 1116.619995, + 1156.820068, + 1083.020020, + 1123.520020, + 1164.020020}, + {0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000}, + {1089.619995, + 1130.420044, + 1171.220093, + 1096.220093, + 1137.320068, + 1178.420044, + 1102.820068, + 1144.219971, + 1185.620117, + 1109.420044, + 1151.119995, + 1192.820068, + 1116.020020, + 1158.020020, + 1200.020020}}, + + {{995.720032, + 1038.020020, + 1080.320068, + 1001.420044, + 1044.020020, + 1086.619995, + 1007.120056, + 1050.020020, + 1092.920044, + 1012.820068, + 1056.020020, + 1099.220093, + 1018.520081, + 1062.020020, + 1105.520020}, + {0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000}, + {1024.220093, + 1068.020020, + 1111.820068, + 1029.920044, + 1074.020020, + 1118.119995, + 1035.619995, + 1080.020020, + 1124.420044, + 1041.320068, + 1086.020020, + 1130.720093, + 1047.020020, + 1092.020020, + 1137.020020}, + {1122.619995, + 1164.920044, + 1207.220093, + 1129.220093, + 1171.820068, + 1214.420044, + 1135.820068, + 1178.719971, + 1221.620117, + 1142.420044, + 1185.619995, + 1228.820068, + 1149.020020, + 1192.520020, + 1236.020020}, + {1052.720093, + 1098.020020, + 1143.320068, + 1058.420044, + 1104.020020, + 1149.619995, + 1064.119995, + 1110.020020, + 1155.920044, + 1069.820068, + 1116.020020, + 1162.220093, + 1075.520020, + 1122.020020, + 1168.520020}, + {1155.619995, + 1199.420044, + 1243.220093, + 1162.220093, + 1206.320068, + 1250.420044, + 1168.820068, + 1213.219971, + 1257.620117, + 1175.420044, + 1220.119995, + 1264.820068, + 1182.020020, + 1227.020020, + 1272.020020}, + {1081.220093, + 1128.020020, + 1174.820068, + 1086.920044, + 1134.020020, + 1181.119995, + 1092.619995, + 1140.020020, + 1187.420044, + 1098.320068, + 1146.020020, + 1193.720093, + 1104.020020, + 1152.020020, + 1200.020020}, + {1188.619995, + 1233.920044, + 1279.220093, + 1195.220093, + 1240.820068, + 1286.420044, + 1201.820068, + 1247.719971, + 1293.620117, + 1208.420044, + 1254.619995, + 1300.820068, + 1215.020020, + 1261.520020, + 1308.020020}, + {0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000, + 0.020000}, + {1221.619995, + 1268.420044, + 1315.220093, + 1228.220093, + 1275.320068, + 1322.420044, + 1234.820068, + 1282.219971, + 1329.620117, + 1241.420044, + 1289.119995, + 1336.820068, + 1248.020020, + 1296.020020, + 1344.020020}}}}}})); + } + } } } // namespace Aidge