diff --git a/include/aidge/backend/cpu/operator/ConvDepthWiseImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ConvDepthWiseImpl_kernels.hpp index 46ae59877bee1b87a9a17be242434d3caca7aae2..906ea1adf744353372c844fd3e16b9dbd13e7f7d 100644 --- a/include/aidge/backend/cpu/operator/ConvDepthWiseImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/ConvDepthWiseImpl_kernels.hpp @@ -175,17 +175,17 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& stri } } } else { - for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex-=strideDims[0]*inputDims[3]) { + for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=(strideDims[0]-2)*inputDims[3]) { for (std::size_t oy = 0; oy < oySize; ++oy) { - output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy]+weights[wIndex+1]*input[iIndex+oy+strideDims[0]]+weights[wIndex+2]*input[iIndex+oy+strideDims[0]*2]; + output[oIndex + oy] = biasVal + weights[wIndex+0]*input[iIndex+oy*strideDims[1]]+weights[wIndex+1]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+2]*input[iIndex+oy*strideDims[1]+2]; } - iIndex+=strideDims[0]*inputDims[3]; + iIndex+=inputDims[3]; for (std::size_t oy = 0; oy < oySize; ++oy) { - output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy]+weights[wIndex+4]*input[iIndex+oy+strideDims[0]]+weights[wIndex+5]*input[iIndex+oy+strideDims[0]*2]; + output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy*strideDims[1]]+weights[wIndex+4]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+5]*input[iIndex+oy*strideDims[1]+2]; } - iIndex+=strideDims[0]*inputDims[3]; + iIndex+=inputDims[3]; for (std::size_t oy = 0; oy < oySize; ++oy) { - output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy]+weights[wIndex+7]*input[iIndex+oy+strideDims[0]]+weights[wIndex+8]*input[iIndex+oy+strideDims[0]*2]; + output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy*strideDims[1]]+weights[wIndex+7]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+8]*input[iIndex+oy*strideDims[1]+2]; } } } @@ -193,25 +193,23 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& stri } } } else if (dilated_kernel_x == 1 && dilated_kernel_y == 1) { - std::size_t index = 0; for (std::size_t batch = 0; batch < inputDims[0]; ++batch) { for (std::size_t ch = 0; ch < inputDims[1]; ++ch) { B biasVal = (biases != nullptr) ? biases[ch] : B(0); - const std::size_t iIndex = (ch + batch*inputDims[1]) * inputDims[2] * inputDims[3]; + std::size_t iIndex = (ch + batch*inputDims[1]) * inputDims[2] * inputDims[3]; const std::size_t wIndex = ch; if (strideDims[0] == 1 && strideDims[1] == 1) { - for (; index < iIndex + oxSize*oySize; ++index) { - output[index] = biasVal + weights[wIndex] * input[index]; + for (std::size_t i = iIndex; i < iIndex + oxSize*oySize; ++i) { + output[i] = biasVal + weights[wIndex] * input[i]; } } else { std::size_t oIndex = (ch + batch*inputDims[1]) * oxSize * oySize; - for (std::size_t ox = 0; ox < oxSize; ++ox, oIndex+=oySize) { - index = iIndex + strideDims[0]*inputDims[3]; + for (std::size_t ox = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=strideDims[0]*inputDims[3]) { for (std::size_t oy = 0, iy = 0; oy < oySize; ++oy, iy+=strideDims[1]) { - output[oIndex + oy] += weights[wIndex]*input[index+iy]; + output[oIndex + oy] = biasVal + weights[wIndex]*input[iIndex+iy]; } } } @@ -234,16 +232,16 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& stri const std::size_t ix = ox * strideDims[0]; const std::size_t iy = oy * strideDims[1]; - for (std::size_t sx = 0; sx*dilationDims[0] < dilated_kernel_x; ++sx) { - for (std::size_t sy = 0; sy*dilationDims[1] < dilated_kernel_y; ++sy) { - output[oIndexFull] += weights[wIndex + sx*kernelDims[1] + sy] * - input[iIndex + static_cast<std::size_t>(ix + sx*dilationDims[0])*inputDims[3] + static_cast<std::size_t>(iy + sy*dilationDims[1])]; + for (std::size_t kx = 0; kx*dilationDims[0] < dilated_kernel_x; ++kx) { + for (std::size_t ky = 0; ky*dilationDims[1] < dilated_kernel_y; ++ky) { + output[oIndexFull] += weights[wIndex + kx*kernelDims[1] + ky] * + input[iIndex + (ix + kx*dilationDims[0])*inputDims[3] + (iy + ky*dilationDims[1])]; } } } } + output += outChannels_s; } - output += outChannels_s; } } } diff --git a/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp b/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp index e3b709bf308288a93fd72865a2fdef0e58908134..1229d5714e6b0cbae4e42ece9130c2c2305f133e 100644 --- a/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/ConvImpl_kernels.hpp @@ -183,17 +183,17 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, } } } else { - for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex-=strideDims[0]*inputDims[3]) { + for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=(strideDims[0]-2)*inputDims[3]) { for (std::size_t oy = 0; oy < oySize; ++oy) { - output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy]+weights[wIndex+1]*input[iIndex+oy+strideDims[0]]+weights[wIndex+2]*input[iIndex+oy+strideDims[0]*2]; + output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy*strideDims[1]]+weights[wIndex+1]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+2]*input[iIndex+oy*strideDims[1]+2]; } - iIndex+=strideDims[0]*inputDims[3]; + iIndex+=inputDims[3]; for (std::size_t oy = 0; oy < oySize; ++oy) { - output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy]+weights[wIndex+4]*input[iIndex+oy+strideDims[0]]+weights[wIndex+5]*input[iIndex+oy+strideDims[0]*2]; + output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy*strideDims[1]]+weights[wIndex+4]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+5]*input[iIndex+oy*strideDims[1]+2]; } - iIndex+=strideDims[0]*inputDims[3]; + iIndex+=inputDims[3]; for (std::size_t oy = 0; oy < oySize; ++oy) { - output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy]+weights[wIndex+7]*input[iIndex+oy+strideDims[0]]+weights[wIndex+8]*input[iIndex+oy+strideDims[0]*2]; + output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy*strideDims[1]]+weights[wIndex+7]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+8]*input[iIndex+oy*strideDims[1]+2]; } } } diff --git a/unit_tests/operator/Test_ConvDepthWiseImpl.cpp b/unit_tests/operator/Test_ConvDepthWiseImpl.cpp index f1594ef5a21070803a7b86861eac513708ec03a2..4382610148bb3bf3722cb3f3518d5efeb1ee2caf 100644 --- a/unit_tests/operator/Test_ConvDepthWiseImpl.cpp +++ b/unit_tests/operator/Test_ConvDepthWiseImpl.cpp @@ -9,221 +9,1342 @@ * ********************************************************************************/ -#include <catch2/catch_test_macros.hpp> #include <memory> -#include <vector> + +#include <catch2/catch_test_macros.hpp> +#include <fmt/core.h> #include "aidge/backend/cpu/operator/ConvDepthWiseImpl.hpp" -#include "aidge/backend/cpu/operator/ConvDepthWiseImpl_kernels.hpp" +#include "aidge/data/Data.hpp" // DataType #include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" #include "aidge/operator/ConvDepthWise.hpp" #include "aidge/utils/TensorUtils.hpp" using namespace Aidge; +/** + * @brief ConvDepthWise reference cpp backend forward implmentation tests. + * + * Summary + * ======= + * kernel [3, 3] + * no stride, no dilation + * stride [2,2], no dilation + * stride [2,2], dilation [2,2] + * kernel [1,1] + * no stride, no dilation + * stride [3,3], no dilation + * stride [3,3], dilation [2,2] + * kernel [5,5] + * no stride, no dilation + * stride [2,2], no dilation + * stride [2,2], dilation [2,2] + */ TEST_CASE("[cpu/operator] ConvDepthWise(forward)", "[ConvDepthWise][CPU]") { - SECTION("k[3,3]") { - std::shared_ptr<Node> myCDW = ConvDepthWise(4, {3,3}, "mycdw"); - auto op = std::static_pointer_cast<OperatorTensor>(myCDW -> getOperator()); - std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<int,4,1,3,3> { - { - {{ - { 0, 1, 2}, - { 3, 4, 5}, - { 6, 7, 8} - - }}, - {{ - { 27, 28, 29}, - { 30, 31, 32}, - { 33, 34, 35} - - }}, - {{ - { 54, 55, 56}, - { 57, 58, 59}, - { 60, 61, 62} - }}, - {{ - { 81, 82, 83}, - { 84, 85, 86}, - { 87, 88, 89} - }} - } - }); - std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<int,4> {{7,0,9,0}}); - std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<int,2,4,5,5> { //NCHW - { - { - {{ 0, 1, 2, 3, 4}, - { 5, 6, 7, 8, 9}, - { 10, 11, 12, 13, 14}, - { 15, 16, 17, 18, 19}, - { 20, 21, 22, 23, 24}}, - - {{ 25, 26, 27, 28, 29}, - { 30, 31, 32, 33, 34}, - { 35, 36, 37, 38, 39}, - { 40, 41, 42, 43, 44}, - { 45, 46, 47, 48, 49}}, - - {{ 50, 51, 52, 53, 54}, - { 55, 56, 57, 58, 59}, - { 60, 61, 62, 63, 64}, - { 65, 66, 67, 68, 69}, - { 70, 71, 72, 73, 74}}, - - {{ 75, 76, 77, 78, 79}, - { 80, 81, 82, 83, 84}, - { 85, 86, 87, 88, 89}, - { 90, 91, 92, 93, 94}, - { 95, 96, 97, 98, 99}} - }, + SECTION("ConvDepthWise kernel [3,3]") { + SECTION("no stride, no dilation") { + ConvDepthWise_Op<2> conv_op = ConvDepthWise_Op<2>({3,3}); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,1,4,5,5> { + {{{{-0.1008466408, -1.5454049110, -0.9166140556, -0.6960951090, + 0.4341324568}, + { 0.0774235576, -0.5880309343, -0.7458236814, 1.1315208673, + -0.1549620479}, + { 0.0743396878, 2.1142954826, -0.1723070294, -0.1795921773, + 0.5806804895}, + { 1.3079929352, 1.0224009752, -0.1107744649, -0.9886689186, + 0.2421970069}, + {-0.5798103809, 0.3528926671, 0.0050931070, -0.7710842490, + 0.3619758785}}, + + {{ 0.0081272032, 0.9984526038, 0.0044101765, -1.6572961807, + 2.0608859062}, + {-1.1862419844, -0.4931973815, 0.7710964084, 0.8817673326, + 0.8965246081}, + { 1.8537108898, -0.0401010700, -0.4194879532, 0.3477281332, + 0.6765057445}, + {-0.1150730550, -0.1088671982, 0.1020692363, -1.0760768652, + 0.5623582602}, + {-0.8432090282, -1.9785683155, -1.0973212719, -0.4528564811, + 0.5299630761}}, + + {{-1.4599646330, -1.2320238352, 0.2687234879, 0.4537659883, + -0.5159105062}, + {-0.1159662902, -0.0771213397, 0.1781232059, 0.4988347590, + 1.6487812996}, + { 0.4867084324, -1.0319201946, -1.0943733454, -1.9665944576, + 1.4405336380}, + {-0.0458223820, 1.9759161472, 1.0542000532, -0.8943792582, + 0.0833332092}, + { 0.6894319654, 0.1574374884, -0.0074822172, 0.1254266798, + -1.0130254030}}, + + {{ 0.7434654236, 0.5090847015, 0.9238219261, -1.6246345043, + -1.7051482201}, + {-0.2401624918, 0.1479524970, 0.1278737187, 0.3838838637, + 0.5377982855}, + { 0.5112501979, -2.1439685822, -0.4956556559, -0.3001609743, + -2.1275873184}, + {-2.0808370113, 0.0635774806, -0.2659386396, -0.0834266022, + -1.3333587646}, + { 2.3835628033, 0.3105354607, -0.3466446698, -1.3037562370, + -0.2132562548}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,4> {{ 0.1543500125, 0.2259403467, -0.2096073627, 0.0794987679}}); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,1,3,3> { + {{{{-0.0184964351, 0.0452591591, -0.0370532684}, + { 0.1729838848, 0.0589027032, -0.2055611610}, + { 0.0022869906, -0.0960286856, 0.1693624258}}}, + + + {{{ 0.3332056701, -0.3317362666, 0.2975485921}, + { 0.0031725965, 0.3205705583, -0.2838594615}, + { 0.3050023913, -0.2213795185, 0.2740720510}}}, + + + {{{-0.2431102246, -0.0334809646, 0.2784897089}, + {-0.2372554243, -0.0052002668, 0.2773198783}, + {-0.2523153722, 0.2629073262, -0.2890377939}}}, + + + {{{ 0.0322953090, 0.0774075240, 0.2625359297}, + { 0.0187648144, -0.2478408515, 0.2355882376}, + { 0.2588306367, -0.0054992437, 0.0906284302}}}} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,1,4,3,3> { + {{{{ 0.0202585980, -0.2200401127, 0.2083991319}, + { 0.2128068209, 0.3275838494, 0.2010547519}, + { 0.5299543738, 0.3573902845, 0.1360719502}}, + + {{-0.0227194652, 0.2353070378, 1.4015233517}, + { 0.3528071046, -0.5158598423, 0.7987069488}, + { 0.5479428768, 0.3083744943, -0.4213459492}}, + + {{ 0.2609346807, 0.9040609002, -0.6786935925}, + {-0.3164857328, -0.3156455159, 0.3337014914}, + {-0.4356064498, -1.2700247765, 0.6287755966}}, + + {{ 0.4735998511, -0.7769540548, -0.7495914698}, + {-0.0220829751, 0.2172142118, -0.3708224297}, + { 0.2662829161, -0.0953377336, -0.9186285734}}}} + }); + conv_op.associateInput(0, myInput); + conv_op.associateInput(1, myWeights); + conv_op.associateInput(2, myBiases); + conv_op.setBackend("cpu"); + conv_op.setDataType(DataType::Float32); + conv_op.forwardDims(); + conv_op.forward(); + REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-8f));; + } + SECTION("stride [2,2] | no dilation") { + ConvDepthWise_Op<2> conv_op = ConvDepthWise_Op<2>({3,3}, {2,2}); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,3,5,5> { + {{{{-0.2681647539, -0.1401816458, -1.3439471722, 0.7686232328, + -0.2347259074}, + {-0.9053395391, -0.1154953390, 0.2734031677, 0.6977564096, + -0.2161262333}, + {-0.5152124763, 0.1068634093, 0.8755886555, 0.3106011450, + -0.2565480769}, + {-0.1076742634, -1.2246952057, -1.3814687729, -0.5840652585, + 0.4304563105}, + {-0.8992680907, 0.4047995806, 1.4492179155, -1.4900603294, + -0.7605531812}}, + + {{-2.3480441570, 0.6959431767, -0.1289412379, -0.0847483501, + 1.2968308926}, + { 0.9658490419, -0.2208128721, -0.8502574563, -1.5217782259, + 0.3917841315}, + {-0.0676943064, -0.8175349236, 0.7544731498, 0.0412569866, + 1.2587231398}, + {-0.7003766298, -0.8936796188, -0.0226393063, -1.2184852362, + 0.3866785169}, + {-1.7608956099, -0.1914640963, 0.2436290830, -0.9511274099, + 1.5242536068}}, + + {{-1.6968810558, -0.7415107489, -0.3093018830, -0.5676894784, + 0.2574917674}, + { 0.2789881527, -1.1715706587, 0.3031383455, 0.2484192103, + 0.1954772770}, + { 0.9325433969, -2.1942939758, -0.9328040481, 0.9583657384, + 1.5130572319}, + {-1.2234312296, 0.7099080086, 0.9838530421, -0.3570700288, + -1.9504652023}, + {-0.1599121839, 1.4158203602, -0.6402221918, -1.1810790300, + -0.4780270755}}}, + + + {{{ 0.2103048563, -1.6611757278, 0.4508873522, 0.8979755640, + -0.6757122874}, + { 0.7067258954, 0.3836486340, -1.5270982981, -0.2568639815, + -0.1403182000}, + { 0.3186497986, -0.0742176697, 0.2034454942, -1.4518156052, + 0.5708613396}, + {-0.9756730199, -0.1207360327, 0.5579432845, 0.2221798450, + 0.7631093264}, + {-0.6217514277, 0.0976419225, -0.3045219481, 1.8516135216, + -0.2196053267}}, + + {{ 1.2789859772, -0.6263879538, -0.2063939124, -0.2311875522, + -0.0393278264}, + {-0.8454674482, -1.0055953264, 1.1767570972, -1.2289278507, + 0.1877539605}, + { 0.8406858444, 1.5269470215, 0.9868479967, -0.3241306841, + 0.9222514033}, + { 1.3278032541, 1.0738047361, 0.2232345939, -1.9111207724, + 2.5548310280}, + { 1.1046593189, -0.6609010696, -0.1762587577, -0.4457865655, + 0.8877998590}}, + + {{-0.3908021450, 1.6496912241, 0.1737804860, 0.1961778849, + -0.4031102061}, + {-0.8883095384, 1.4801807404, -0.6092755198, -0.5375859141, + -0.3113131523}, + { 2.0676779747, 0.3772547543, -0.3045346141, 1.6700518131, + -1.8084005117}, + { 1.4025870562, 0.8708391786, 1.3200664520, -1.1006357670, + 0.3649013042}, + {-0.5854687095, -1.2669901848, -2.2839319706, -0.1692169160, + -0.2855746150}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,3> {{0.0749854296, 0.2219027281, 0.1885577142}}); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,3,1,3,3> { + {{{{ 0.3290626705, 0.3244876564, -0.3134182394}, + { 0.2142836750, 0.2250442505, 0.1981022060}, + {-0.1862878054, 0.2690648437, 0.0310798883}}}, + + + {{{ 0.0116933584, -0.3271239698, 0.1202279776}, + { 0.0051873922, -0.0847972259, -0.3323501348}, + {-0.2489729375, 0.2729874253, 0.0016002655}}}, + + + {{{-0.3015822172, -0.0256469250, -0.2859892547}, + { 0.1704760790, 0.1709747761, 0.1115431413}, + {-0.3066976070, 0.2390796840, -0.0861117467}}}} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,2,3,2,2> { + {{{{ 0.3485876918, 0.0410024002}, + {-0.5851752758, -0.4924227297}}, + + {{ 0.0524865389, 0.2238895148}, + { 1.0454657078, 0.0253930446}}, + + {{-0.0414484441, 0.7236352563}, + { 0.6955899596, -0.1431636363}}}, + + + {{{-0.6739091277, -0.0971068442}, + { 0.0989151821, 0.3607567251}}, + + {{ 0.3158496320, 0.0055956692}, + {-0.7632108927, -0.3119188249}}, + + {{-0.2696485221, 0.6642971039}, + { 0.2509680390, 1.5169239044}}}} + }); + conv_op.associateInput(0, myInput); + conv_op.associateInput(1, myWeights); + conv_op.associateInput(2, myBiases); + + conv_op.setBackend("cpu"); + conv_op.setDataType(DataType::Float32); + conv_op.forwardDims(); + + conv_op.forward(); + + REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-8f)); + } + SECTION("stride [2,2] | dilation [2,2]") { + ConvDepthWise_Op<2> conv_op = ConvDepthWise_Op<2>({3,3}, {2,2}, {2,2}); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,1,4,7,7> { + {{{{-1.2564380169e+00, 1.7001671791e+00, -6.0228500515e-02, + -7.6883906126e-01, 1.0394480228e+00, -2.1278746128e+00, + -4.3723234534e-01}, + { 1.1784723997e+00, -8.7596154213e-01, -5.6893348694e-02, + -4.8710629344e-01, 1.6934870481e+00, -8.1963825226e-01, + -9.9433082342e-01}, + { 1.6150641441e+00, 2.2550213337e-01, 1.3934186101e-01, + -6.8104225397e-01, -2.2220163047e-01, 1.1362174898e-01, + 2.0110601187e-01}, + { 6.7898714542e-01, -7.3953729868e-01, -3.0562976003e-01, + -4.8523655534e-01, 4.3536397815e-01, -2.4560281634e-01, + 6.4579688013e-02}, + { 1.3424048424e+00, 1.6474188864e-01, 7.4488747120e-01, + -5.9620857239e-01, -9.5960032940e-01, 4.5834407210e-01, + 4.9313405156e-01}, + { 1.4139299393e+00, -1.3618037701e+00, 3.3027759194e-01, + 6.8120902777e-01, 1.9601145983e+00, -4.0156817436e-01, + -7.7237486839e-02}, + { 1.3182686567e+00, 3.3211612701e-01, 1.9552657604e+00, + 1.4429262877e+00, 1.2531367540e+00, -1.5773595572e+00, + 1.0225969553e+00}}, + + {{ 4.6619251370e-01, 6.9228565693e-01, 8.8645166159e-01, + -8.3506953716e-01, -5.2961051464e-01, 1.0165786743e-01, + 5.3774905205e-01}, + {-6.0989326239e-01, 2.3970830441e-01, 2.3374938965e-01, + -4.9342277646e-01, -3.7169873714e-01, 3.4093639255e-01, + 2.3664817214e-01}, + {-7.5131034851e-01, -4.2149517685e-02, -4.3322569132e-01, + 2.1540248394e+00, -6.9199752808e-01, 4.7342130542e-01, + 1.9464567900e+00}, + { 2.4289269932e-03, -2.2609269619e+00, -3.6993902922e-01, + -2.6160044670e+00, -1.0806908607e+00, -9.2318016291e-01, + 9.6653455496e-01}, + {-4.3112087250e-01, 9.3174472451e-02, 2.1137170494e-01, + 3.6451536417e-01, 6.0560785234e-02, -1.4053032398e+00, + 2.6295976639e+00}, + { 3.3558085561e-01, 2.0609512925e-01, 8.1405574083e-01, + -1.1626043916e-01, -9.6128863096e-01, -1.0162148476e+00, + 2.2983274460e+00}, + {-5.1882678270e-01, 9.7170782089e-01, 5.9890896082e-01, + -1.2613058090e+00, -4.7689700127e-01, 3.2950636744e-01, + 2.5496333838e-01}}, + + {{ 1.2547644973e-01, -2.1516680717e+00, -4.3004885316e-01, + -1.1163233519e+00, -7.9468077421e-01, -8.5132592916e-01, + -6.7698568106e-01}, + {-2.3809320927e+00, 9.1189408302e-01, -7.9828941822e-01, + -2.1867971420e+00, -2.0300696790e-01, -6.1769866943e-01, + -5.6792998314e-01}, + { 4.5785719156e-01, -3.5315357149e-02, -5.2074277401e-01, + 1.2201535702e+00, -1.7547093630e+00, -9.1879181564e-02, + -9.0850913525e-01}, + {-1.0663042068e+00, 9.3642288446e-01, 1.0326064825e+00, + 2.7203741670e-01, 1.5793048143e+00, -7.6377516985e-01, + 6.5407752991e-01}, + {-6.6077453084e-03, 4.6359539032e-01, -7.1511381865e-01, + -2.0252700150e-01, -1.2316555977e+00, 5.3828799725e-01, + 2.9643586278e-01}, + {-2.1578962803e+00, 1.4375370741e+00, 1.4743455648e+00, + -1.8298947811e-01, -1.7145735025e+00, 1.9807685614e+00, + -7.3558908701e-01}, + {-8.6257940531e-01, 8.4425401688e-01, -1.0371112823e+00, + -4.0326038003e-01, -7.3599940538e-01, 3.4502989054e-01, + -1.2538944483e+00}}, + + {{ 1.3018572330e+00, 3.4584665298e-01, 3.7988024950e-01, + 2.6572030783e-01, -1.4982204139e-01, 7.3580324650e-01, + -1.3376350701e-01}, + {-1.1923942566e+00, -9.8181134462e-01, 2.2573418915e-01, + 8.0120041966e-03, -9.5310580730e-01, 9.6748578548e-01, + 5.4245507717e-01}, + {-4.9568879604e-01, -1.4902360439e+00, 4.9879044294e-02, + -1.5037181377e+00, 1.5938287973e-01, 8.2041674852e-01, + -1.0953415632e+00}, + {-4.9750682712e-01, 4.4327926636e-01, -2.5331549346e-02, + -9.7987723351e-01, 1.3228254318e+00, 5.7619941235e-01, + -5.6825790554e-02}, + {-1.5635057688e+00, 6.7909795046e-01, -4.3309000134e-01, + 1.8277673721e+00, -9.7239983082e-01, -1.0455882549e-01, + 5.8249467611e-01}, + {-7.6343780756e-01, -3.5443061590e-01, -5.5096411705e-01, + -1.1431121826e+00, -8.4882009029e-01, 5.1973551512e-01, + 2.6110163331e-01}, + { 2.3929489776e-02, 7.2413641214e-01, 4.2561513186e-01, + -6.4914476871e-01, -5.4483998567e-02, 3.4525018930e-01, + -3.9139431715e-01}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,4> {{-0.0600201301, 0.3112891614, 0.2262887955, 0.1428407431}}); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,1,3,3> { + {{{{-0.0483490229, 0.1010556668, -0.3016797900}, + {-0.0918154344, -0.0209818687, 0.3124184012}, + { 0.2390201986, 0.2415951192, 0.1040253267}}}, + + + {{{-0.2466100901, 0.2284083068, 0.0995316133}, + { 0.0553057604, 0.1444625109, 0.1359003782}, + {-0.2869182825, -0.0471504554, -0.0665103197}}}, + + + {{{ 0.0298744049, 0.1652455777, -0.2147967815}, + {-0.2902197242, 0.3079298437, -0.2284899652}, + {-0.1819925010, 0.2709769011, 0.1631795168}}}, + + + {{{ 0.1494711637, -0.0957429856, -0.2600780129}, + { 0.0156168938, 0.0227777958, -0.0343352184}, + {-0.1857107133, -0.2571181655, -0.2455177009}}}} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,1,4,2,2> { + {{{{-0.1385704875, 0.2320426404}, + { 0.4221621752, 0.8324003816}}, + {{ 0.2576041818, -0.0725638047}, + { 0.4960045815, 0.6652954817}}, + {{ 0.0438128375, -0.1093067303}, + { 0.3498060405, -0.3388394713}}, + {{ 0.9684044123, 0.4782117605}, + {-0.0788729265, 0.4020597339}}}} + }); + conv_op.associateInput(0, myInput); + conv_op.associateInput(1, myWeights); + conv_op.associateInput(2, myBiases); + conv_op.setBackend("cpu"); + conv_op.setDataType(DataType::Float32); + conv_op.forwardDims(); + conv_op.forward(); + REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-8f)); + } + } + SECTION("point-wise") { + SECTION("no stride, no dilation") { + ConvDepthWise_Op<2> conv_op = ConvDepthWise_Op<2>({1,1}); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,4,3,3> { + {{{{-0.2665283084, 1.0591213703, -0.5744783282}, + { 2.1924018860, -0.9184432626, -0.4519051015}, + { 0.5954509974, 0.0728591084, -0.4014659226}}, + + {{-1.0344575644, 1.3931720257, 0.3318610489}, + {-0.6563054919, 0.2039012611, 1.4156180620}, + {-0.8701118827, 1.4934384823, -2.0206739902}}, + + {{-0.5417286754, 0.0236786865, -1.1417526007}, + {-0.0592311621, -0.3561672270, 0.4465615153}, + { 0.5427954793, -0.4105411768, -1.2697076797}}, + + {{ 0.5128195882, -0.9545230865, 0.3979352117}, + { 0.8590119481, 0.2024669945, -0.0111086201}, + { 1.5435370207, -0.5318347216, 0.1749227196}}}, + + + {{{-0.6836048961, -1.2168579102, -0.8282459378}, + {-0.2170266211, -0.3614979684, -0.5755043030}, + { 0.0710424408, -0.6281879544, -0.8610697985}}, + + {{-0.3863270581, 0.6085444093, -0.0376757309}, + { 1.3019801378, -0.6326317191, -0.1477297693}, + { 0.6462736726, -1.0503417253, 0.5722824931}}, + + {{-0.0206014682, 0.2297244966, 0.7034130096}, + { 1.4367071390, 0.1368671060, 0.6463897228}, + {-2.0419392586, -2.2926816940, 1.9855005741}}, + + {{ 1.6628106833, -0.4438441694, 1.3320568800}, + { 0.3473397493, -0.5916580558, 1.3845838308}, + { 0.3625002801, 2.3297600746, -0.2896993756}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,4> {{-0.7087996006, 0.5437290668, -0.0904183388, -0.1987243891}}); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,1,1,1> { + {{{{ 0.8847143650}}}, + {{{-0.5319792032}}}, + {{{ 0.0759756565}}}, + {{{ 0.1301449537}}}} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,2,4,3,3> { + {{{{-0.9446009994, 0.2282202840, -1.2170488834}, + { 1.2308498621, -1.5213595629, -1.1086065769}, + {-0.1819955558, -0.6443400979, -1.0639822483}}, + + {{ 1.0940389633, -0.1974094808, 0.3671858907}, + { 0.8928699493, 0.4352578223, -0.2093503028}, + { 1.0066105127, -0.2507491410, 1.6186856031}}, + + {{-0.1315765232, -0.0886193365, -0.1771637350}, + {-0.0949184671, -0.1174783781, -0.0564905331}, + {-0.0491790958, -0.1216094717, -0.1868852079}}, + + {{-0.1319835037, -0.3229507506, -0.1469351351}, + {-0.0869283155, -0.1723743379, -0.2001701146}, + { 0.0021591650, -0.2679399848, -0.1759590805}}}, + + + {{{-1.3135946989, -1.7853713036, -1.4415606260}, + {-0.9008061886, -1.0286220312, -1.2179565430}, + {-0.6459473372, -1.2645665407, -1.4706003666}}, + + {{ 0.7492470145, 0.2199960947, 0.5637717843}, + {-0.1488972902, 0.8802759647, 0.6223182082}, + { 0.1999249160, 1.1024889946, 0.2392866760}}, + + {{-0.0919835493, -0.0729648694, -0.0369760729}, + { 0.0187364295, -0.0800197721, -0.0413084552}, + {-0.2455560118, -0.2646063268, 0.0604313724}}, + + {{ 0.0176820308, -0.2564884722, -0.0253639072}, + {-0.1535198689, -0.2757256925, -0.0185277909}, + {-0.1515468061, 0.1044821292, -0.2364273071}}}} + }); + conv_op.associateInput(0, myInput); + conv_op.associateInput(1, myWeights); + conv_op.associateInput(2, myBiases); + conv_op.setBackend("cpu"); + conv_op.setDataType(DataType::Float32); + conv_op.forwardDims(); + conv_op.forward(); + REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-8f)); + } + SECTION("stride [3,3], no dilation") { + ConvDepthWise_Op<2> conv_op = ConvDepthWise_Op<2>({1,1}, {3,3}); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,1,4,3,3> { + {{{{ 1.6097626686, 0.6317374706, 0.8121796846}, + {-0.7499074936, -0.3534076512, -0.3863078654}, + { 0.2124902308, -1.2060434818, -0.2350673527}}, + + {{ 1.6450011730, 0.7838846445, 0.5905120373}, + { 1.3153772354, -2.0690157413, -0.0058457609}, + {-1.5660933256, 0.0484106764, 1.1444953680}}, + + {{-0.3757407069, 0.5180828571, -0.1972250640}, + { 0.6753169894, 0.1572864950, 1.7338060141}, + {-0.3412690759, 0.1255278289, 0.7706317306}}, + + {{-0.4972190559, -0.7812244892, 0.6989694834}, + {-0.0817485675, -0.3598272502, -0.9195072055}, + {-0.9635654092, 1.1187410355, 1.2071155310}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,4> {{ 0.3924750090, 0.6698757410, 0.3069384098, -0.5433753729}}); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,1,1,1> { { - {{100, 101, 102, 103, 104}, - {105, 106, 107, 108, 109}, - {110, 111, 112, 113, 114}, - {115, 116, 117, 118, 119}, - {120, 121, 122, 123, 124}}, - - {{125, 126, 127, 128, 129}, - {130, 131, 132, 133, 134}, - {135, 136, 137, 138, 139}, - {140, 141, 142, 143, 144}, - {145, 146, 147, 148, 149}}, - - {{150, 151, 152, 153, 154}, - {155, 156, 157, 158, 159}, - {160, 161, 162, 163, 164}, - {165, 166, 167, 168, 169}, - {170, 171, 172, 173, 174}}, - - {{175, 176, 177, 178, 179}, - {180, 181, 182, 183, 184}, - {185, 186, 187, 188, 189}, - {190, 191, 192, 193, 194}, - {195, 196, 197, 198, 199}} + {{{-0.7436860800}}}, + {{{-0.0358216763}}}, + {{{ 0.0749748945}}}, + {{{-0.8600753546}}} } - } - }); - std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<int,2,4,3,3> { - { - { - {{ 319, 355, 391}, - { 499, 535, 571}, - { 679, 715, 751}}, - - {{ 8745, 9024, 9303}, - { 10140, 10419, 10698}, - { 11535, 11814, 12093}}, - - {{ 29337, 29859, 30381}, - { 31947, 32469, 32991}, - { 34557, 35079, 35601}}, - - {{ 62061, 62826, 63591}, - { 65886, 66651, 67416}, - { 69711, 70476, 71241}} - }, + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,1,4,1,1> { { - {{ 3919, 3955, 3991}, - { 4099, 4135, 4171}, - { 4279, 4315, 4351}}, + {{{-0.8046830893}}, + {{ 0.6109490395}}, + {{ 0.2787672877}}, + {{-0.1157295182}}} + } + }); + conv_op.associateInput(0, myInput); + conv_op.associateInput(1, myWeights); + conv_op.associateInput(2, myBiases); + conv_op.setBackend("cpu"); + conv_op.setDataType(DataType::Float32); + conv_op.forwardDims(); + conv_op.forward(); + REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-8f)); + } + SECTION("stride [3,3], dilation [2,2]") { // same as 'no dilation' test + ConvDepthWise_Op<2> conv_op = ConvDepthWise_Op<2>({1,1}, {3,3}, {2,2}); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,1,4,3,3> { + {{{{ 1.6097626686, 0.6317374706, 0.8121796846}, + {-0.7499074936, -0.3534076512, -0.3863078654}, + { 0.2124902308, -1.2060434818, -0.2350673527}}, - {{ 36645, 36924, 37203}, - { 38040, 38319, 38598}, - { 39435, 39714, 39993}}, + {{ 1.6450011730, 0.7838846445, 0.5905120373}, + { 1.3153772354, -2.0690157413, -0.0058457609}, + {-1.5660933256, 0.0484106764, 1.1444953680}}, - {{ 81537, 82059, 82581}, - { 84147, 84669, 85191}, - { 86757, 87279, 87801}}, + {{-0.3757407069, 0.5180828571, -0.1972250640}, + { 0.6753169894, 0.1572864950, 1.7338060141}, + {-0.3412690759, 0.1255278289, 0.7706317306}}, - {{138561, 139326, 140091}, - {142386, 143151, 143916}, - {146211, 146976, 147741}} + {{-0.4972190559, -0.7812244892, 0.6989694834}, + {-0.0817485675, -0.3598272502, -0.9195072055}, + {-0.9635654092, 1.1187410355, 1.2071155310}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,4> {{ 0.3924750090, 0.6698757410, 0.3069384098, -0.5433753729}}); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,1,1,1> { + { + {{{-0.7436860800}}}, + {{{-0.0358216763}}}, + {{{ 0.0749748945}}}, + {{{-0.8600753546}}} + } + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,1,4,1,1> { + { + {{{-0.8046830893}}, + {{ 0.6109490395}}, + {{ 0.2787672877}}, + {{-0.1157295182}}} } - } - }); - op -> associateInput(0, myInput); - op -> associateInput(1, myWeights); - op -> associateInput(2, myBias); - op->setDataType(DataType::Int32); - op->setBackend("cpu"); - myCDW -> forward(); - op -> getOutput(0) -> print(); - REQUIRE(*(op -> getOutput(0)) == *myOutput); + }); + conv_op.associateInput(0, myInput); + conv_op.associateInput(1, myWeights); + conv_op.associateInput(2, myBiases); + conv_op.setBackend("cpu"); + conv_op.setDataType(DataType::Float32); + conv_op.forwardDims(); + conv_op.forward(); + REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-8f)); + } } - SECTION("point-wise") { - ConvDepthWise_Op<2> conv_op = ConvDepthWise_Op<2>({1,1}); - std::shared_ptr<Tensor> weights = std::make_shared<Tensor>(std::vector<std::size_t>({3,1,1,1})); - weights -> setBackend("cpu"); - std::shared_ptr<Tensor> biases = std::make_shared<Tensor>(std::vector<std::size_t>({3})); - biases -> setBackend("cpu"); - std::shared_ptr<Tensor> input = std::make_shared<Tensor>(std::vector<std::size_t>({2,3,5,5})); - input -> setBackend("cpu"); - std::shared_ptr<Tensor> expected_output = std::make_shared<Tensor>(std::vector<std::size_t>({2,3,5,5})); - expected_output -> setBackend("cpu"); - - float weighst_array[3] {-0.0045, -0.4223, -0.9452}; - weights->getImpl()->setRawPtr(weighst_array, 3); - - float biases_array[3] {-0.8595, 0.7062, -0.0062}; - biases->getImpl()->setRawPtr(biases_array, 3); - - float input_array[2*3*5*5] { - 0.6581, 0.2509, 0.2660, 0.8270, 0.8040, 0.3147, 0.5028, 0.2591, 0.8585, - 0.7762, 0.9972, 0.0305, 0.1202, 0.2682, 0.9306, 0.7927, 0.1494, 0.0678, - 0.5550, 0.4132, 0.4742, 0.6199, 0.1802, 0.6350, 0.2539, 0.5594, 0.0143, - 0.8656, 0.7105, 0.1420, 0.2464, 0.7883, 0.5715, 0.7642, 0.5492, 0.6628, - 0.4922, 0.7941, 0.8421, 0.7914, 0.0237, 0.8081, 0.0174, 0.6018, 0.7402, - 0.3770, 0.8786, 0.3651, 0.5355, 0.4267, 0.4457, 0.6756, 0.9631, 0.0145, - 0.4470, 0.5202, 0.2675, 0.5815, 0.3487, 0.3457, 0.7179, 0.0518, 0.1520, - 0.0573, 0.9219, 0.3615, 0.0866, 0.5237, 0.4725, 0.2565, 0.8726, 0.6434, - 0.6875, 0.2919, 0.3355, 0.1886, 0.1749, 0.0785, 0.4091, 0.1907, 0.4664, - 0.2738, 0.4784, 0.7807, 0.0687, 0.3091, 0.4557, 0.2277, 0.2424, 0.8691, - 0.1893, 0.2918, 0.5691, 0.1926, 0.2866, 0.0097, 0.5445, 0.5085, 0.1110, - 0.7099, 0.8927, 0.6182, 0.2538, 0.8694, 0.7872, 0.3196, 0.0710, 0.2888, - 0.0403, 0.1670, 0.6840, 0.7323, 0.4861, 0.3390, 0.1096, 0.5070, 0.3872, - 0.7473, 0.6224, 0.6910, 0.7530, 0.0149, 0.0866, 0.9022, 0.5027, 0.3849, - 0.5255, 0.1977, 0.0570, 0.9581, 0.5461, 0.4623, 0.0101, 0.2362, 0.5922, - 0.8398, 0.1497, 0.5160, 0.2862, 0.5931, 0.9728, 0.1353, 0.7790, 0.9137, - 0.9351, 0.4036, 0.7638, 0.3873, 0.0494, 0.7450}; - input->getImpl()->setRawPtr(input_array, 2*3*5*5); - - float expected_output_array[2*3*5*5] { - -0.8624, -0.8606, -0.8607, -0.8632, -0.8631, -0.8609, -0.8617, -0.8606, - -0.8633, -0.8629, -0.8639, -0.8596, -0.8600, -0.8607, -0.8636, -0.8630, - -0.8601, -0.8598, -0.8620, -0.8613, -0.8616, -0.8622, -0.8603, -0.8623, - -0.8606, 0.4700, 0.7002, 0.3407, 0.4062, 0.6463, 0.6022, 0.3733, - 0.4649, 0.3835, 0.4743, 0.4263, 0.4984, 0.3709, 0.3506, 0.3720, - 0.6962, 0.3650, 0.6989, 0.4521, 0.3936, 0.5470, 0.3352, 0.5520, - 0.4801, 0.5260, -0.4274, -0.6447, -0.9165, -0.0199, -0.4287, -0.4979, - -0.2590, -0.5559, -0.3358, -0.3329, -0.6847, -0.0552, -0.1499, -0.0603, - -0.8776, -0.3479, -0.0881, -0.5011, -0.4528, -0.2486, -0.8309, -0.6143, - -0.6561, -0.2821, -0.3233, -0.8603, -0.8603, -0.8598, -0.8613, -0.8603, - -0.8616, -0.8607, -0.8616, -0.8630, -0.8598, -0.8609, -0.8615, -0.8605, - -0.8606, -0.8634, -0.8603, -0.8608, -0.8620, -0.8603, -0.8608, -0.8595, - -0.8619, -0.8617, -0.8600, -0.8626, 0.3292, 0.4451, 0.5991, 0.3390, - 0.3738, 0.5712, 0.6762, 0.5843, 0.6892, 0.6357, 0.4174, 0.3969, - 0.5009, 0.5631, 0.6599, 0.4921, 0.5427, 0.3906, 0.4434, 0.4144, - 0.3882, 0.6999, 0.6697, 0.3252, 0.4939, -0.3700, -0.5029, -0.1931, - -0.0601, -0.9118, -0.5224, -0.4432, -0.0157, -0.2294, -0.5660, -0.7999, - -0.1477, -0.4939, -0.2767, -0.5668, -0.9257, -0.1341, -0.7425, -0.8698, - -0.8900, -0.3877, -0.7282, -0.3722, -0.0529, -0.7103}; - expected_output->getImpl()->setRawPtr(expected_output_array, 2*3*5*5); - - conv_op.associateInput(0, input); - conv_op.associateInput(1, weights); - conv_op.associateInput(2, biases); - - conv_op.setBackend("cpu"); - conv_op.setDataType(DataType::Float32); - conv_op.forwardDims(); - - conv_op.forward(); - - conv_op.getOutput(0)->print(); - - REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expected_output, 1e-3f, 1e-4f)); + SECTION("kernel size [5,5]") { + SECTION("no stride, no dilation") { + ConvDepthWise_Op<2> conv_op = ConvDepthWise_Op<2>({5,5}); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,4,7,7> { + {{{{ 4.9335759878e-01, 1.1863079071e+00, 2.5131928921e-01, + 1.6316433251e-01, 2.5724807382e-01, 1.7706827819e-01, + -9.5957201719e-01}, + {-1.0026545525e+00, -8.3003485203e-01, -9.3903255463e-01, + -4.6593669057e-01, -7.7092325687e-01, -1.7429690361e+00, + 4.9265477061e-01}, + {-1.3393815756e+00, -1.5150873661e+00, -5.8329886198e-01, + 1.7911167145e+00, -1.0506145954e+00, 6.3952505589e-01, + 1.0916038752e+00}, + { 7.4818961322e-02, 1.5963518620e+00, 1.2598557770e-01, + 6.4491826296e-01, -3.0015480518e-01, -4.8072105646e-01, + -6.9521999359e-01}, + {-1.1644884348e+00, 3.7635925412e-01, -1.3692667484e+00, + -7.2627186775e-02, 5.5572849512e-01, 6.1445808411e-01, + -1.3938227892e+00}, + {-1.4636498690e+00, -9.7709447145e-01, -1.8374252319e-01, + 5.8343982697e-01, 1.1417788267e+00, -4.4667036273e-03, + 4.9295508862e-01}, + {-1.7479445040e-01, -1.0437289476e+00, -1.3344774246e+00, + -1.9067826271e+00, 1.0409342051e+00, 6.6524130106e-01, + -1.6800432205e+00}}, + + {{ 9.7197180986e-01, -1.1580578089e+00, 1.5227444172e+00, + -2.1772108078e+00, 4.8213523626e-01, -3.4500488639e-01, + -6.7194223404e-01}, + { 4.9432659149e-01, -1.0472580194e+00, -7.3332744837e-01, + 1.0557442904e+00, -9.4611018896e-01, 4.7074545175e-02, + 2.5027732849e+00}, + {-2.0324008167e-01, 6.9782984257e-01, -5.9244088829e-02, + 3.4188255668e-02, 6.4118854702e-02, -1.1833695322e-01, + 1.7038782835e+00}, + {-6.1576306820e-01, 1.1467368603e+00, 9.0839028358e-02, + 1.4732836485e+00, -3.2838854194e-01, 1.1726535559e+00, + 9.6947526932e-01}, + {-3.2122168690e-02, 1.3275359869e+00, -1.3638773561e-01, + 2.1276748180e-01, 8.0851209164e-01, -9.8784273863e-01, + 1.3375968933e+00}, + {-5.4332029819e-01, -6.0529774427e-01, -6.0545504093e-01, + 1.2644755840e+00, -8.6449778080e-01, -6.2184357643e-01, + -6.7940688133e-01}, + {-5.1089364290e-01, -1.1370127201e+00, -1.9654258490e+00, + -2.0984578133e+00, -5.3804492950e-01, -2.2316617966e+00, + 1.1619290113e+00}}, + + {{ 4.6988701820e-01, -5.1307964325e-01, -7.0698708296e-01, + 5.7957285643e-01, 5.6874805689e-01, -1.9858320951e+00, + -1.7708021402e+00}, + { 1.4547123909e+00, 3.7047669291e-01, -8.2360202074e-01, + 1.9833043814e+00, 1.5422163904e-01, -6.8875616789e-01, + 1.9319385290e+00}, + { 9.7966343164e-02, -1.3681530952e+00, 6.8940818310e-01, + -1.0752324760e-01, -6.1970126629e-01, 1.8546850979e-01, + -4.5794528723e-01}, + { 7.3246699572e-01, 4.9492576718e-01, -4.0711274743e-01, + -1.9404098857e-03, -1.5990917683e+00, -4.1567105055e-01, + 1.6714956760e+00}, + { 6.3360172510e-01, -1.1427930593e+00, -1.6082632542e-01, + -9.5651131868e-01, -1.1128952503e+00, -2.5961843133e-01, + -5.4337906837e-01}, + { 3.8337892294e-01, 7.0294111967e-01, -2.2530713081e+00, + 4.6874850988e-01, -3.7142974138e-01, -8.2423162460e-01, + 3.1144621968e-01}, + { 4.9787372351e-01, -1.0927795172e+00, -1.7619090080e+00, + 4.4252237678e-01, -7.1531772614e-01, 2.7259647846e-01, + -4.3468984962e-01}}, + + {{-4.8326885700e-01, -3.9079174399e-01, -1.0790492296e+00, + -2.4060493708e-01, -3.3320981264e-01, 2.9022085667e-01, + 1.3364650011e+00}, + { 2.8901857138e-01, -3.9582711458e-01, -4.5644235611e-01, + -7.7036660910e-01, 8.4683763981e-01, -4.4504839182e-01, + 3.8618934155e-01}, + {-1.9830763340e+00, -8.0258053541e-01, -1.6078829765e+00, + 2.1368785203e-01, 7.4899446964e-01, -5.6091493368e-01, + 2.4929441512e-02}, + {-5.3176864982e-02, 1.1627144814e+00, -5.2606719732e-01, + 3.7888059020e-01, 7.5841093063e-01, 1.5893269777e+00, + -1.1967051029e+00}, + { 1.7698400021e+00, 1.6875107288e+00, -1.2041009665e+00, + 1.0938201100e-01, -1.3895796537e+00, 2.6665708423e-01, + -1.0111159086e+00}, + {-5.7373844087e-02, -1.2338018417e+00, -2.9199585319e-01, + -2.0545010269e-01, 2.3327396810e-01, -8.3556395769e-01, + -2.0631006360e-01}, + {-1.0011457205e+00, 6.4688628912e-01, -5.2283269167e-01, + -8.6595642567e-01, 4.4378590584e-01, -2.9411891475e-02, + 2.0832476020e-01}}}, + + + {{{-2.8625172377e-01, -4.6468538046e-01, 1.3221564293e+00, + 2.1612360477e+00, -8.7142616510e-01, 1.9371863604e+00, + -1.2806377411e+00}, + {-9.1990309954e-01, -2.0963511467e+00, 6.4161384106e-01, + -6.4031910896e-01, -6.1191931367e-02, -1.5462826490e+00, + 7.9379910231e-01}, + {-5.4952275753e-01, 5.6310713291e-01, 4.3853071332e-01, + -1.2469210625e+00, -1.1873561144e+00, -6.0527914762e-01, + 5.4670602083e-01}, + {-2.9861965775e-01, 2.5091698766e-01, -5.2039808035e-01, + -5.5684101582e-01, -6.3102495670e-01, 4.6421602368e-01, + 1.5310447216e+00}, + { 8.8189703226e-01, -1.7110499144e+00, -4.5705246925e-01, + 4.8052483797e-01, -1.6791898012e+00, -1.8851631880e+00, + 6.8464434147e-01}, + {-2.1336026192e+00, 1.0871043205e+00, 4.0943421423e-02, + 1.3751971722e-01, -1.1933552027e+00, -4.8217883706e-01, + 1.6583485603e+00}, + {-2.7782380581e-01, -1.1283507943e-01, -4.0001991391e-01, + -2.4640804529e-01, -1.5624488890e-01, -2.5247385502e+00, + -4.2229643464e-01}}, + + {{-1.0902142525e+00, 7.1061784029e-01, 5.2563959360e-01, + 8.9594686031e-01, 3.3749544621e-01, 1.5508149862e+00, + 3.1452372670e-01}, + { 8.8045895100e-01, -1.0802421570e+00, 6.9660454988e-01, + 7.6869897544e-02, -6.5158027411e-01, -6.1672717333e-01, + -1.2306349277e+00}, + { 8.2995426655e-01, -3.1962795258e+00, 1.1294349432e+00, + -8.1900782883e-02, -1.1951421499e+00, -8.0654764175e-01, + -1.0077600479e+00}, + {-8.5095930099e-01, -1.5733307600e-01, 1.7131972313e-01, + -2.5118584633e+00, -2.5612711906e+00, -1.2504032254e-01, + 9.9039399624e-01}, + { 1.6889984906e-01, -1.2384599447e+00, -3.2178740948e-02, + -6.5206307173e-01, 3.1484216452e-01, -1.3365371525e-01, + -4.7578561306e-01}, + { 3.0374403000e+00, 6.3499100506e-02, -5.4018121958e-01, + 3.0082745552e+00, -1.2163658142e+00, -3.3809682727e-01, + -7.0253151655e-01}, + {-1.8768366575e+00, -1.3563711643e+00, -4.9710619450e-01, + 1.0979669094e+00, 2.4758360386e+00, -7.8894788027e-01, + -4.7926986217e-01}}, + + {{-1.4381635189e+00, 1.4026446342e+00, -3.1678429246e-01, + 7.0693075657e-01, 6.9888514280e-01, 5.8169060946e-01, + 1.0585654974e+00}, + { 3.5511282086e-01, -1.2119283527e-01, 6.5024077892e-01, + -6.7714643478e-01, -6.6425532103e-01, -2.4314621091e-01, + 2.0232704282e-01}, + {-7.3568201065e-01, -6.9552195072e-01, -9.7805447876e-02, + -3.5810253024e-01, -4.0582028031e-01, -7.9612672329e-02, + -9.9031373858e-02}, + { 1.1441185474e+00, 3.6610561609e-01, 1.7374299467e-01, + -5.1957684755e-01, -7.6053464413e-01, 2.9512745142e-01, + 2.1705639362e-01}, + { 7.9289871454e-01, -1.2500211596e-01, 2.9821291566e-01, + -3.8476973772e-01, 1.1767704487e+00, -7.4524241686e-01, + 3.7969255447e-01}, + { 7.7420151234e-01, 1.6062602997e+00, 7.1063655615e-01, + 1.2802153826e+00, -2.7737879753e-01, 7.0112681389e-01, + -1.1772309542e+00}, + { 2.1208928525e-01, -1.2882195711e+00, 3.4972077608e-01, + -8.2289344072e-01, -1.4296934009e-01, -1.3751131296e+00, + 7.3476749659e-01}}, + + {{-1.5929245949e-01, 6.7564934492e-01, 6.2561661005e-01, + 1.3304146528e+00, 7.0209020376e-01, 1.0694117844e-01, + -1.4454191923e+00}, + {-5.5573159456e-01, -2.2276285291e-01, 5.6988465786e-01, + 9.7244775295e-01, -4.3723756075e-01, 7.7416992188e-01, + 8.2548350096e-02}, + { 2.7222864628e+00, 8.0904394388e-01, -4.6484589577e-01, + -9.2518895864e-01, 1.0536781549e+00, -4.6422225237e-01, + 6.7778164148e-01}, + { 7.3793217540e-02, 7.9145863652e-02, -1.2498629093e+00, + -1.2253532559e-01, 1.5088248253e-01, 1.1032183170e+00, + -5.4131639004e-01}, + { 2.1829447746e+00, -1.5288269520e+00, -8.8960021734e-01, + 8.2131899893e-02, -1.8998783827e-01, -1.8512618542e-01, + -3.4453171492e-01}, + { 3.8420417905e-01, -1.0917562991e-01, -6.3616442680e-01, + 1.1726476997e-01, 1.0503277779e+00, -6.6429930925e-01, + 9.0642881393e-01}, + {-1.1341298819e+00, -5.9495466948e-01, -3.9783969522e-01, + -7.2580540180e-01, 6.6621124744e-02, -4.1143927723e-02, + -2.2756290436e+00}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,4> { + { 0.0059141875, -0.0648065358, -0.0974549279, -0.0668982267} + }); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,1,5,5> { + {{{{ 0.0740372464, -0.1711187363, 0.0966774225, 0.1457725614, + 0.0493013635}, + { 0.0289726499, 0.1567692012, 0.0252312180, -0.0124352938, + 0.1629220545}, + {-0.1641854495, -0.0290440563, 0.0452437662, -0.1157503128, + 0.0947008878}, + {-0.0830174908, -0.0582915321, -0.0306369308, 0.1680577993, + 0.0682382584}, + { 0.1222482920, -0.1069305465, -0.0742284581, -0.0171356685, + -0.0342458002}}}, + + + {{{ 0.1335894614, -0.1472281665, 0.0556286834, -0.1737921238, + 0.1377953589}, + {-0.1001852304, 0.1950409263, 0.0267154705, -0.1778282225, + 0.0544498451}, + {-0.1686174124, -0.0963439271, 0.0142339235, 0.1019947082, + 0.1663077176}, + {-0.0054439544, -0.0924759433, 0.0550796054, 0.0225215200, + 0.0133644817}, + {-0.0317959562, -0.1560210288, -0.0677365586, 0.1075119525, + 0.0507474206}}}, + + + {{{ 0.1699814051, -0.1997351646, 0.0536758676, 0.0581439026, + 0.0378645435}, + { 0.1802627891, 0.1987431794, 0.0738811046, -0.0955903530, + -0.1651753038}, + { 0.1358847916, 0.1417508423, -0.0334093831, 0.1283134073, + -0.1813320667}, + { 0.0582411997, -0.0489385389, -0.0596107244, 0.1527736634, + 0.0661540776}, + { 0.0447563417, -0.1846967936, 0.0630980507, 0.1160114333, + -0.0923686028}}}, + + + {{{-0.0659249350, 0.0793606564, -0.0854479820, -0.1152713820, + 0.1727648824}, + { 0.1545656770, -0.0538719185, -0.0680776834, -0.1562429518, + -0.0725132227}, + { 0.1585823596, 0.0286217686, -0.0436206348, 0.0533598661, + -0.1532071382}, + {-0.0669565201, 0.0751482025, 0.0940384641, 0.1469179243, + -0.0734879524}, + { 0.0089315651, -0.1766889840, -0.0285181757, 0.1730310023, + 0.1393895000}}}} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,2,4,3,3> { + {{{{-0.5857352614, 0.1086293235, -0.3229822814}, + {-0.7310093045, -0.3348580301, -0.0214850176}, + { 1.2433251143, 0.4028868377, -0.2706802487}}, + + {{ 0.0309218448, -0.5836250782, 0.8768596649}, + { 0.1736581773, -0.0594436042, 0.0133085661}, + {-0.0390867218, 0.2015268654, 0.4916410446}}, + + {{ 0.2147885114, -0.5134234428, -0.0374717414}, + { 0.2516125441, 0.3315618634, -1.1428399086}, + { 1.0725688934, -0.3064109087, -0.8532077670}}, + + {{-0.5997072458, -0.2242568582, 0.2218082845}, + { 0.1030889302, -0.8108748794, -0.1896360815}, + { 0.3333139122, 0.0513630882, -0.1338601708}}}, + + + {{{ 0.5120676160, -0.5367197394, 0.2702569067}, + {-0.1443247199, -0.4722655416, -0.1593769193}, + {-0.6662371159, -0.2179034948, 0.5732577443}}, + + {{-0.6737872958, 0.5619254708, -0.6692476869}, + {-0.8073410392, -0.1703055203, -0.1989304423}, + { 1.4196149111, 0.5330380797, -1.2780163288}}, + + {{-0.6440194249, 0.3993457556, -0.1229868680}, + { 0.0640285239, -0.4547872841, -0.1695011109}, + { 0.4479903877, 0.0911103636, -0.5269957185}}, + + {{ 0.1002222970, -0.1615311801, -0.4674149454}, + { 0.1204236448, 0.2234887183, -0.2007223815}, + { 0.4577222168, -0.2835268676, -0.6621670723}}}} + }); + conv_op.associateInput(0, myInput); + conv_op.associateInput(1, myWeights); + conv_op.associateInput(2, myBiases); + conv_op.setBackend("cpu"); + conv_op.setDataType(DataType::Float32); + conv_op.forwardDims(); + conv_op.forward(); + REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-8f)); + } + SECTION("stride [2,2], no dilation") { + ConvDepthWise_Op<2> conv_op = ConvDepthWise_Op<2>({5,5}, {2,2}); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,4,7,7> { + {{{{ 3.8520759344e-01, -1.5048544109e-01, -2.1790374815e-01, + -5.7652354240e-01, 1.1100100279e+00, -2.4167035520e-01, + 1.0990920067e+00}, + { 8.7922102213e-01, 1.8489408493e+00, -2.0763272047e-01, + 1.6101540625e-01, 1.4712316990e+00, 9.0872722864e-01, + 5.9666723013e-01}, + { 2.3513903618e+00, 3.5370627046e-01, 4.6655362844e-01, + 3.6453476548e-01, -1.3369771242e+00, -3.9518725872e-01, + 1.4669662714e+00}, + { 1.2112803459e+00, 3.3809375763e-01, -7.8605490923e-01, + 1.2191309929e+00, 1.8707159758e+00, -1.5493103862e-01, + -4.8493990302e-01}, + {-8.7178820372e-01, -9.3710070848e-01, -1.1901003122e+00, + -3.0594833195e-02, 9.9669891596e-01, 1.9136610031e+00, + -1.4605854750e+00}, + {-8.1902602687e-03, 3.5515683889e-01, -1.2319850922e+00, + 1.3729830980e+00, -2.1615109444e+00, 1.0895189047e+00, + -3.2240214944e-01}, + { 7.7308756113e-01, 1.0597428083e+00, 8.5827887058e-01, + -7.9133516550e-01, -1.9188613892e+00, 1.2579337358e+00, + -1.2377259731e+00}}, + + {{-6.7472612858e-01, -3.4404060245e-01, 1.2304736376e+00, + 3.4865313768e-01, -5.6978863478e-01, 3.5225439817e-02, + 1.0718777180e+00}, + { 1.2797117233e+00, 4.2904919386e-01, -8.7542390823e-01, + -1.3666222095e+00, -5.1030803472e-02, -1.2733145058e-01, + 1.1057097465e-01}, + {-7.2918736935e-01, 1.1353262663e+00, 4.0105110407e-01, + 3.2652910799e-02, -7.9244107008e-01, 5.7068455219e-01, + 2.5546798110e-01}, + {-9.1081774235e-01, -1.6454605758e-01, 3.1387217045e+00, + 5.1187878847e-01, 7.3688519001e-01, -6.5781360865e-01, + -3.6698520184e-01}, + {-3.7950250506e-01, 9.2945128679e-01, -6.6260379553e-01, + 4.8890876770e-01, 1.2294455767e+00, 1.9234917164e+00, + 8.0498385429e-01}, + {-5.4067701101e-01, -2.9110896587e-01, -5.1824927330e-01, + -1.1617597938e-01, 1.3733102083e+00, 6.9681173563e-01, + 6.2707096338e-01}, + { 5.7201117277e-01, 1.5840615332e-01, 2.0593723655e-01, + -8.1141608953e-01, 8.5261273384e-01, -1.6056976318e+00, + -5.2579122782e-01}}, + + {{ 1.6815655231e+00, 2.4704076350e-01, -1.1197048426e+00, + -4.7988933325e-01, 1.4443746209e-01, 6.8120814860e-02, + -6.0399234295e-01}, + { 2.6893126965e-01, 1.1254637241e+00, 1.2356828451e+00, + 7.0232260227e-01, -1.7059510946e+00, 1.7088449001e-01, + 3.2452866435e-01}, + { 1.7605692148e+00, 1.0952962637e+00, 2.0282413960e+00, + 7.0874804258e-01, 6.6728973389e-01, 5.7557627559e-02, + -2.2893531248e-02}, + { 7.6859766245e-01, -1.5146261454e+00, -1.3063246012e-01, + 4.9653983116e-01, -6.1272948980e-01, 6.1812108755e-01, + 2.9842779040e-01}, + {-7.9819858074e-01, -1.5216785669e+00, 1.5476153791e-01, + -3.0010658503e-01, -8.9212977886e-01, 1.3135321140e+00, + 4.5638892055e-01}, + {-4.5912411064e-02, 3.1731166840e+00, -1.1454867572e-01, + 3.9372527599e-01, -3.0301496387e-01, -9.5864409208e-01, + 6.4990806580e-01}, + {-1.5128309838e-02, -1.8659442663e+00, -1.0648315400e-01, + -2.6408338547e+00, -1.2393210828e-01, 1.3168338537e+00, + -6.2912321091e-01}}, + + {{-8.6029833183e-03, -2.5194084644e-01, 3.4589302540e-01, + 1.0429813862e+00, 1.6685616970e+00, -7.2287905216e-01, + -8.8275146484e-01}, + {-1.1796358824e+00, 7.4369359016e-01, 1.1751577854e+00, + -8.0447465181e-01, -3.3473432064e-01, -4.7937799245e-02, + -1.8288640976e+00}, + { 5.6663048267e-01, -5.3825604916e-01, 1.8124829531e+00, + 8.6427420378e-02, -1.5517222881e+00, 2.1185632050e-01, + 9.4422245026e-01}, + {-1.9948658943e+00, 1.2489651442e+00, 1.2926094532e+00, + 6.2967604399e-01, 1.0601581335e+00, -1.4124554396e+00, + -1.0759254694e+00}, + { 1.7446736097e+00, 8.5903286934e-01, 1.2479382753e-01, + 1.0916360617e+00, 3.5449057817e-01, -1.9337905645e+00, + 8.8214844465e-02}, + { 2.2092112899e-01, 1.0851465464e+00, -8.5641130805e-02, + -6.1427676678e-01, -1.0848737955e+00, 1.7291833460e-01, + -9.3117421865e-01}, + { 4.9415281415e-01, -7.1096634865e-01, -1.2350386381e+00, + 4.8438447714e-01, -8.0378723145e-01, -2.3194132373e-02, + -8.8567745686e-01}}}, + + + {{{ 9.3833208084e-01, -9.3867695332e-01, -1.9296536446e+00, + -1.9469395280e-01, -1.0064001083e+00, 6.7425054312e-01, + 5.7381069660e-01}, + {-5.8385449648e-01, 1.5392524004e+00, -7.3658037186e-01, + -4.9099606276e-01, -5.8427224867e-03, -4.9734053016e-01, + 2.0966405869e+00}, + { 4.4381022453e-02, 1.9009371996e+00, -3.1770554185e-01, + -4.5139196515e-01, 2.2562942505e+00, 1.0809175968e+00, + -7.8067958355e-01}, + { 1.2378455400e+00, -1.2802067995e+00, -1.3410314322e+00, + -1.7746871710e-01, -9.6855717897e-01, -6.6797292233e-01, + -2.7914598584e-01}, + {-2.2995612621e+00, -4.6167243272e-02, -1.1212759018e+00, + -8.9812129736e-01, -5.6873339415e-01, -6.2530684471e-01, + -1.1113914251e+00}, + { 9.4398176670e-01, 2.9730677605e-01, 4.0788778663e-01, + -6.1924046278e-01, 3.7405481935e-01, -7.1785598993e-01, + 5.2785265446e-01}, + { 1.4844427109e+00, 1.5798323154e+00, 7.9189050198e-01, + 2.1535563469e+00, -1.3852857351e+00, 1.6917630434e+00, + 2.3598566055e+00}}, + + {{-7.7599078417e-01, 4.9129545689e-01, -6.9907718897e-01, + 5.8299517632e-01, -8.1232386827e-01, -6.8906480074e-01, + 6.0145241022e-01}, + { 6.5379202366e-01, -2.5990147591e+00, 5.7937479019e-01, + -2.4067652225e+00, -5.0686568022e-02, -6.1713993549e-01, + 1.3297734261e+00}, + {-5.2623099089e-01, 5.4217547178e-01, -1.7074975967e+00, + -1.3474613428e-02, -4.3210104108e-01, -1.3601350784e+00, + 1.1019977331e+00}, + {-3.2511076331e-01, 7.4853056669e-01, 8.6941182613e-01, + 2.6319536567e-01, -2.1560071036e-03, -6.2360137701e-01, + 6.6978454590e-01}, + {-4.7999924421e-01, 1.5140286684e+00, 5.0155067444e-01, + 2.3926508427e-01, -1.4036533423e-02, 1.0717345476e+00, + 4.9181202054e-01}, + {-1.1894277334e+00, -4.2690724134e-01, -1.0564312935e+00, + 1.3100820780e-01, -1.5495970249e+00, 3.2739278674e-01, + 2.1079761982e+00}, + { 8.7008196115e-01, -4.4802580029e-02, -3.0779972672e-01, + 9.4748604298e-01, 5.0919568539e-01, -5.6785094738e-01, + 1.8958197534e-01}}, + + {{ 7.7712529898e-01, -5.4674464464e-01, -4.3654015660e-01, + -1.7690027133e-02, -1.0759859085e+00, 7.9439246655e-01, + 4.0082043409e-01}, + {-1.4334075451e+00, -1.1236054450e-01, 4.2285147309e-01, + -4.3151515722e-01, 1.8296922743e-01, -7.5662362576e-01, + 1.2256839275e+00}, + {-1.2095364332e+00, 2.7471596003e-01, -7.1152734756e-01, + 8.0925470591e-01, -1.0531398058e+00, 1.3366776705e+00, + -1.4349820614e+00}, + {-6.5796740353e-02, -1.6194750369e-01, 1.7100380361e-01, + -3.0003476143e-01, -1.7518389225e+00, 1.0243947059e-01, + 1.8502172232e+00}, + { 1.8798203468e+00, -5.6445449591e-02, 2.2965614498e-01, + -5.0052756071e-01, -3.1924626827e+00, -2.3384423554e-01, + -1.8923732042e+00}, + {-8.3112102747e-01, 1.0703138113e+00, -1.0538098812e+00, + -5.6717932224e-01, 6.6653525829e-01, 1.0731325299e-01, + -1.4036635160e+00}, + { 6.6975963116e-01, -1.0376057625e+00, 1.4635567665e+00, + 6.0049772263e-01, 2.6335725188e-01, 1.7023352385e+00, + 3.9096495509e-01}}, + + {{-5.2209907770e-01, -7.3401969671e-01, -4.9598002434e-01, + -1.1101231575e+00, 2.7940380573e-01, 3.7108734250e-01, + -5.3757792711e-01}, + {-1.7226947546e+00, -8.4843122959e-01, 3.0122289062e-01, + -8.2913970947e-01, 1.7452031374e+00, 7.5711089373e-01, + 3.6754548550e-01}, + {-3.5317954421e-01, -2.7988141403e-02, 8.1158816814e-01, + 1.1813087463e+00, 1.7010580301e+00, 3.6692687869e-01, + 7.9475116730e-01}, + { 7.1114557981e-01, -4.3524390459e-01, -1.2870337963e+00, + -3.3187520504e-01, 4.9446484447e-01, -5.0515407324e-01, + -7.1395359933e-02}, + { 1.3766362667e+00, -3.5036647320e-01, -2.2399795055e+00, + -1.8095448017e+00, 2.2540309429e+00, 1.6960443258e+00, + -6.1720281839e-01}, + {-2.4066153169e-01, -3.4883093834e-01, -1.1646213382e-01, + -5.0506269932e-01, 1.1594126225e+00, -3.2168120146e-01, + -1.5061002970e+00}, + { 1.0476481915e+00, 1.9588752985e+00, 3.6338061094e-01, + 2.0577123165e+00, -1.0460792780e+00, -3.1160068512e+00, + -1.7618523836e+00}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,4> { + {-0.0200376045, -0.0020939112, -0.1464084685, -0.0419981480} + }); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,1,5,5> { + {{{{-0.0712780729, -0.1232913956, 0.1876820773, -0.0419011600, + -0.1424695104}, + { 0.0584071651, -0.1193018928, 0.1307084858, 0.1676846296, + -0.0871880054}, + {-0.0948264599, 0.0315888636, -0.1759090424, -0.0207793470, + 0.1858711988}, + {-0.1399959624, -0.0186878927, -0.0617008917, 0.0002223492, + 0.1727847904}, + {-0.0758293867, 0.0122927902, -0.0473791622, -0.1468438208, + -0.0557731166}}}, + + {{{-0.0829651132, 0.0315707214, 0.0823683068, -0.1360922605, + -0.0236635935}, + { 0.1936927885, -0.0261211153, -0.1094252840, -0.1744011492, + 0.0887691975}, + { 0.0449672453, 0.0418232940, 0.0158894062, 0.1032823324, + 0.1539564580}, + {-0.1828922033, -0.0399240516, 0.0496491455, 0.1540371031, + 0.1999383718}, + {-0.0052583455, 0.1408428699, 0.0648615360, -0.1485663503, + -0.0073370696}}}, + + {{{-0.0690000057, 0.0339213610, -0.1460432559, -0.1896216869, + -0.0705224052}, + {-0.1869531870, -0.1661461592, -0.0214934107, 0.1090264320, + 0.0115755796}, + {-0.1388747245, -0.0404886007, 0.0298435446, -0.1655783951, + -0.1112038419}, + {-0.1958016008, 0.0370284095, 0.0006295681, 0.1554906219, + -0.1377223581}, + { 0.0179722793, 0.0870472714, -0.0224549528, -0.0847019181, + -0.1470272839}}}, + + {{{-0.1011667028, 0.1455960721, -0.0806858540, 0.1572577953, + -0.0550846569}, + { 0.0752714872, 0.0279489048, 0.1414941847, 0.0247094631, + -0.0008786917}, + {-0.0451189540, -0.0664234906, 0.1269026548, -0.1047828719, + 0.0019106150}, + {-0.1722858250, -0.1355526745, -0.0965286270, 0.1626135856, + -0.0860176086}, + {-0.0193410646, -0.1041662693, -0.0561585911, 0.1893668473, + -0.1290524751}}}} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,2,4,2,2> { + {{{{-0.7954307795, 0.6042141914}, + { 0.2831384242, -0.7216922641}}, + {{ 1.1425485611, -1.0336279869}, + { 0.3416210711, 1.3282159567}}, + {{-0.6753546596, -0.7460098267}, + { 0.0863537639, -1.0971375704}}, + {{ 0.3648000360, -1.4717553854}, + {-0.2303827107, 0.7337518334}}}, + + {{{ 0.1193560362, -0.5002098680}, + {-0.6565298438, -0.0648934469}}, + {{ 0.7361341119, 0.1737762839}, + {-0.5042620897, 1.6453049183}}, + {{ 1.0128767490, -0.1996757686}, + {-0.0394040421, 0.4287980795}}, + {{-0.9122012258, 0.8897001147}, + {-0.0642478392, -0.2671076953}}}} + }); + conv_op.associateInput(0, myInput); + conv_op.associateInput(1, myWeights); + conv_op.associateInput(2, myBiases); + conv_op.setBackend("cpu"); + conv_op.setDataType(DataType::Float32); + conv_op.forwardDims(); + conv_op.forward(); + REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-8f)); + } + SECTION("stride [2,2], dilation [2,2]") { + ConvDepthWise_Op<2> conv_op = ConvDepthWise_Op<2>({5,5}, {2,2}, {2,2}); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,1,4,11,11> { + {{{{-7.9724937677e-01, 9.1211748123e-01, 2.4363055825e-01, + 2.4279198647e+00, -1.1906795502e+00, -1.2431346178e+00, + 7.9783517122e-01, 3.3326399326e-01, -8.8273769617e-01, + 6.3607555628e-01, 1.6719125509e+00}, + {-8.0337154865e-01, 4.3899053335e-01, 2.2643094063e+00, + -3.3888676763e-01, -2.1396698952e+00, -2.4637742341e-01, + 5.8492517471e-01, 1.1872248352e-01, -4.6652513742e-01, + 1.1809904873e-01, 1.1721512079e+00}, + { 1.5571426153e+00, -6.7109817266e-01, 7.4499303102e-01, + 1.0628908873e+00, -2.0967121422e-01, -9.4654607773e-01, + 4.4076669216e-01, 7.2388553619e-01, 4.3450137973e-01, + -4.0051737428e-01, -1.0369918346e+00}, + { 1.1229771376e+00, 7.1691840887e-02, 2.1389411390e-01, + 5.7155287266e-01, 6.8389695883e-01, 5.4999643564e-01, + 9.9025702477e-01, -1.6951800883e-01, -7.1715158224e-01, + 1.1616971493e+00, 1.4596436024e+00}, + { 6.1506503820e-01, -5.5823451281e-01, -7.7102023363e-01, + 1.3960702419e+00, -2.3227377236e-01, 1.5084296465e+00, + 2.4943335354e-01, 1.1076819897e+00, -5.2056992054e-01, + -1.0660402775e+00, 1.5976787806e+00}, + { 2.1903742850e-01, 1.9679501653e-02, 3.1386128068e-01, + -4.0626621246e-01, -7.5507235527e-01, -1.0893251896e+00, + 5.7655651122e-02, 1.7972362041e-01, -9.0133118629e-01, + 7.7605450153e-01, -6.6169762611e-01}, + {-6.5132927895e-01, -8.7547087669e-01, -9.1218388081e-01, + 2.3385442793e-01, -3.1539368629e-01, -1.4544982910e+00, + -6.6423857212e-01, -9.6526098251e-01, -6.6600215435e-01, + 6.4714908600e-01, 8.3344441652e-01}, + {-1.5785104036e+00, 7.4335277081e-01, -1.1462401152e+00, + 1.0709639788e+00, 2.4840381742e-01, 5.4372686148e-01, + 5.2218210697e-01, 1.9122928381e+00, -1.7239218950e+00, + -1.0918600559e+00, -4.8291814327e-01}, + { 9.1093665361e-01, 4.8966014385e-01, -5.2526658773e-01, + -5.5381953716e-01, 2.4685945511e+00, 1.8153283596e+00, + 1.3610179424e+00, 1.5524566174e+00, -1.0007113218e+00, + 9.3928158283e-01, -4.0532606840e-01}, + { 9.2250317335e-01, -5.1795337349e-02, -9.9768131971e-01, + 1.6249650717e+00, -9.7957354784e-01, 8.2169145346e-01, + 9.0952342749e-01, -2.4217249453e-01, 5.2506601810e-01, + 1.4915968180e+00, 1.8634383678e+00}, + {-1.2066408396e+00, -2.1929471493e+00, -2.1238762140e-01, + -7.7625429630e-01, 9.4938635826e-01, -1.8003302813e+00, + 3.4820947051e-01, -6.0576725006e-01, -4.6567910910e-01, + 7.3214060068e-01, -7.3195713758e-01}}, + + {{ 5.3023105860e-01, 5.2188825607e-01, -7.3025727272e-01, + 4.6667468548e-01, -1.9152610004e-01, -3.5157033801e-01, + -5.2651941776e-02, -4.8071330786e-01, -1.4187107980e-01, + -4.0443262458e-01, -9.0072855353e-02}, + { 1.7881003022e-01, 1.0663819313e+00, -1.0256117582e+00, + 8.7569499016e-01, 7.3652589321e-01, -2.6246693730e-01, + -1.6060233116e+00, -2.0135894418e-02, -3.0415063724e-02, + 3.0278328061e-01, 1.0328743458e+00}, + { 3.2101709396e-02, -1.9900113344e-01, -5.6631088257e-01, + -7.4470794201e-01, 1.7958736420e-01, -1.4189696312e+00, + -1.8464855850e-01, -1.0358128548e+00, 1.9816583395e+00, + -9.1884893179e-01, 9.0449694544e-03}, + {-2.0810999870e+00, -2.4015277624e-01, -3.2068858147e+00, + -1.2780202925e-01, 5.5257743597e-01, 8.7900471687e-01, + 1.0303863287e+00, 1.0323342085e+00, -7.1367937326e-01, + -1.9911302328e+00, -4.7755372524e-01}, + { 4.7770789266e-01, -1.2471196651e+00, 1.4831469059e+00, + -8.7475031614e-01, 6.2958401442e-01, 2.4512550831e+00, + 1.1762837172e+00, 1.0817874968e-01, -5.5265057087e-01, + -8.7890112400e-01, -8.8465034962e-01}, + { 7.3697751760e-01, 1.0448017120e+00, 5.1342499256e-01, + -6.6379088163e-01, 1.3169301748e+00, 9.0157186985e-01, + 7.0772147179e-01, 4.2946752161e-02, 2.8955113888e-01, + 2.5413966179e-01, 5.9332638979e-01}, + { 5.0560563803e-01, 1.8920665979e+00, 3.0823582411e-01, + 1.1087694168e+00, -2.0810942352e-01, -5.2579015493e-01, + -8.4162759781e-01, -2.1426311135e-01, -2.2446355820e+00, + 3.9921768010e-02, 6.9279879332e-01}, + { 1.1803445816e+00, -8.5259515047e-01, -9.9684113264e-01, + 3.4527037144e+00, -1.5741628408e+00, -3.5193979740e-01, + -1.4004269838e+00, 4.2186367512e-01, -4.9055755138e-02, + -1.2086832523e-01, 1.7582530975e+00}, + { 5.2762091160e-01, -7.7123051882e-01, 2.0454251766e+00, + 2.2788069248e+00, -1.5510680676e+00, -1.9315464497e+00, + -9.0742290020e-01, -1.6089993715e+00, -5.3302454948e-01, + 3.5658752918e-01, 4.8080060631e-02}, + { 1.8064410686e+00, 3.6816290021e-01, -1.5741494894e+00, + 5.8802300692e-01, 4.7199991345e-01, 1.2889680862e+00, + 8.7419849634e-01, -3.9012792706e-01, 5.6346172094e-01, + -7.6179641485e-01, 1.3050407171e+00}, + {-9.1590619087e-01, -1.8752954006e+00, 1.4963175058e+00, + 3.2759961486e-01, 9.2106527090e-01, -1.1356848478e+00, + -1.3705831766e+00, -8.2606017590e-01, -1.3079185486e+00, + 2.6973825693e-01, -6.5730160475e-01}}, + + {{ 9.5533591509e-01, -9.9876197055e-04, -1.4797580242e+00, + 2.1985734999e-01, 4.8844724894e-01, 4.3145084381e-01, + -2.1500962973e-01, -7.0111054182e-01, 6.9973039627e-01, + -7.2476547956e-01, 4.8026397824e-01}, + {-8.8500595093e-01, 1.8992747068e+00, 7.7701354027e-01, + -6.0169208050e-01, 1.0013473034e+00, 8.4491990507e-02, + -9.8977982998e-01, -7.9020494223e-01, 3.2907465100e-01, + -1.2078856677e-01, 1.5202083588e+00}, + {-2.7657735348e-01, 4.6925004572e-02, 1.2281295061e+00, + -1.8579104543e-01, 1.0336677730e-01, 2.5693323612e+00, + 8.9827783406e-02, 2.2823791504e+00, -8.2009571791e-01, + 2.0414433479e+00, -9.0024584532e-01}, + { 2.6029777527e+00, 4.0565639734e-01, 1.2988950312e-01, + -1.2674516439e+00, 5.8589053154e-01, 2.4598875046e+00, + 8.9385128021e-01, 6.4068651199e-01, 1.7348393798e-02, + 1.2424468994e+00, -8.4993803501e-01}, + { 5.7889044285e-01, 3.8729000092e-01, -8.8090997934e-01, + 2.1381093562e-01, 1.4890091419e+00, 1.5105549097e+00, + 1.4098797739e-01, 1.0446792096e-01, -1.9159198999e+00, + 1.3064564764e-01, -1.6926348209e-01}, + { 1.1417644024e+00, -1.4733666182e+00, 3.2986238599e-01, + 1.8303622305e-01, 5.6586086750e-01, -5.2473092079e-01, + -7.5201815367e-01, -1.5739550814e-02, -1.5592651367e+00, + -1.4688136578e+00, -3.3142146468e-01}, + {-7.3924712837e-02, 1.8161753416e+00, -9.5422208309e-01, + 3.4323176742e-01, -2.2727070749e-01, -1.1031615734e+00, + 5.7045269012e-01, 1.6896954775e+00, 1.0372216702e+00, + -1.3280247152e-01, -1.3075873852e+00}, + {-4.0329605341e-01, -1.1308847666e+00, -5.7332462072e-01, + -4.2800852656e-01, 7.3079723120e-01, 1.4624267817e-01, + -2.4124519527e-01, 1.6443549395e+00, 2.1521264315e-01, + -3.0984909534e+00, 2.1323997974e+00}, + { 5.5337917060e-02, -7.7057784796e-01, 3.2530885935e-01, + 1.3282178640e+00, -3.2126638293e-01, 2.9032289982e-01, + -2.4100792408e-01, -1.1505941153e+00, -4.0858381987e-01, + 3.8038328290e-01, 1.0238400698e+00}, + {-1.0223561525e+00, 5.4754292965e-01, 8.9632779360e-01, + 4.0344274044e-01, -7.0289498568e-01, -1.1168614626e+00, + 3.1760175228e+00, 2.0348765850e+00, -1.0406352282e+00, + 1.0582931042e+00, 1.1740338057e-01}, + { 6.0107231140e-01, 8.4875309467e-01, -8.5171341896e-02, + -1.2264981270e+00, 1.1493313313e+00, -1.9127263129e-01, + 5.3371381760e-01, -4.7718715668e-01, 8.9841789007e-01, + -4.7041997313e-01, -1.1772131920e+00}}, + + {{ 4.0120658278e-01, 6.7281287909e-01, -4.7505354881e-01, + -3.1049102545e-02, -5.2430522442e-01, -9.7885608673e-01, + -8.1729829311e-01, 2.8434273601e-01, 2.2878241539e-01, + 7.0183002949e-01, -6.1007946730e-01}, + {-1.4632455111e+00, -9.4703143835e-01, 3.1175765395e-01, + 1.7414463758e+00, 1.2987858057e+00, 2.5278210640e+00, + -2.5223663449e-01, -7.2194322944e-02, -9.3486815691e-01, + -5.4429602623e-01, -1.5758562088e+00}, + {-1.1150578260e+00, 3.1018608809e-01, -1.0259387493e+00, + 4.9761269242e-02, 2.2564062476e-01, 2.2673048079e-01, + -1.2348350286e+00, -4.8837900162e-01, -5.5627411604e-01, + 2.3974895477e+00, 1.5627510548e+00}, + { 1.3537845612e+00, -1.4481093884e+00, -5.8862978220e-01, + -9.2907649279e-01, 1.7989814281e-01, -6.1403113604e-01, + 3.9550009370e-01, 2.0637707412e-01, 1.4092805982e-01, + 1.4354915619e+00, 4.1124477983e-01}, + { 5.9437131882e-01, 2.5175651908e-01, -1.4724839926e+00, + -7.9224598408e-01, -1.0697947443e-01, -9.3873560429e-01, + -8.3823198080e-01, 1.0682547092e+00, 1.4871965647e+00, + -1.7402729392e-01, -4.1061699390e-01}, + { 6.1316050589e-02, -5.1780641079e-01, 3.9551302791e-01, + 1.3394130766e-01, 1.0029216856e-01, 5.2646106482e-01, + -2.3723851889e-02, -7.3339444399e-01, 9.1420966387e-01, + -9.4718337059e-01, -4.3122315407e-01}, + { 3.1069964170e-01, -5.1376241446e-01, 4.0816228837e-02, + -1.0862566233e+00, -7.4995791912e-01, 1.4363372326e-01, + -2.1348357201e+00, 7.7824163437e-01, 1.4786756039e-01, + -7.7644962072e-01, 5.6383001804e-01}, + {-5.8460813761e-01, 2.1913936362e-02, 2.7573537827e+00, + 1.6296634078e-01, 4.5511564612e-01, -1.9915504456e+00, + 2.9748791456e-01, 4.2073163390e-01, -7.9228687286e-01, + 7.1524131298e-01, -2.3795824051e+00}, + { 3.2347491384e-01, -1.2526123524e+00, 1.0551978350e+00, + -7.1423876286e-01, -2.3097810149e-01, -1.5859640837e+00, + -9.0056812763e-01, 2.4794772267e-01, 1.1709301472e+00, + -9.2559927702e-01, -1.4513528347e+00}, + {-6.9863849878e-01, -5.9516018629e-01, 8.8701313734e-01, + 1.2166969776e+00, -1.2960427999e+00, 1.5974614620e+00, + -1.1376231909e+00, 3.8471198082e-01, -2.3815085888e+00, + -1.9736941159e-01, 1.2386144400e+00}, + { 1.3419134617e+00, -3.7865355611e-02, -6.7163608968e-02, + 6.4137578011e-01, -4.0734642744e-01, -3.7681248784e-01, + -4.8736298084e-01, 5.5279612541e-01, 9.2274624109e-01, + -1.1439754963e+00, 3.9059036970e-01}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,4> { + {-0.0264293905, 0.1734595597, -0.0226496458, 0.0903987661} + }); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,1,5,5> { + {{{{ 0.0299229864, -0.0070633413, 0.0387998335, -0.1524392813, + -0.1102494225}, + { 0.0792495012, -0.1920675784, -0.0624185093, 0.0321180820, + 0.0158791542}, + {-0.1553061306, -0.1615069360, 0.1939006597, -0.1880072653, + 0.1276774853}, + {-0.1049263999, -0.1676749289, 0.0386066921, 0.1836618483, + -0.1901892275}, + {-0.0461415537, -0.1792095006, -0.1822651625, 0.1757787466, + -0.1020090356}}}, + + {{{ 0.1273712367, -0.0565941110, 0.0162034985, 0.0944691673, + 0.1084327474}, + { 0.0596342087, 0.0055869580, 0.0252192747, 0.1085267588, + -0.1543913186}, + { 0.0770991370, -0.1246689111, 0.1107715368, -0.1920998394, + -0.0687034130}, + { 0.0345953479, -0.1752771586, -0.0221131798, -0.1240047216, + -0.1447991878}, + {-0.1521565169, 0.0231835600, -0.0339560769, -0.1456816643, + -0.0323682800}}}, + + {{{-0.0775303841, -0.0420084260, 0.0729985982, -0.1716768742, + 0.0249956138}, + {-0.1547612250, 0.0420673378, 0.0686141029, -0.0187673103, + 0.0606986992}, + {-0.0045782328, 0.1753131598, 0.0497865193, 0.1222290322, + -0.0149712088}, + {-0.0015272141, -0.0793256983, -0.0027205944, 0.0122313974, + 0.0382683761}, + { 0.1789925098, 0.1004394516, 0.1536173671, 0.0328379385, + -0.1461270154}}}, + + {{{-0.0308369156, 0.0518908277, -0.0238906145, 0.0774943829, + 0.1786959022}, + {-0.1896944493, 0.0252965689, 0.1969553977, 0.0170207266, + -0.1349055022}, + { 0.1111816689, -0.1292364597, 0.1484523267, 0.1492144167, + -0.0654201508}, + {-0.0588786863, 0.0246491432, 0.1550032198, 0.1314135045, + -0.0048915865}, + { 0.0985754058, -0.0685965568, -0.1517471522, -0.1095958278, + 0.1006983072}}}} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,1,4,2,2> { + {{{{-0.0801755786, -0.4114282727}, + { 0.8874194026, -0.3093988597}}, + {{ 0.2365355641, 0.5801096559}, + { 1.2374895811, 1.2885351181}}, + {{ 0.2366760373, -0.4241298735}, + {-0.2304262370, 0.8197762370}}, + {{ 0.1419754326, -0.6357300878}, + {-0.7330727577, 0.3483830392}}}} + }); + conv_op.associateInput(0, myInput); + conv_op.associateInput(1, myWeights); + conv_op.associateInput(2, myBiases); + conv_op.setBackend("cpu"); + conv_op.setDataType(DataType::Float32); + conv_op.forwardDims(); + conv_op.forward(); + REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-8f)); + } } } \ No newline at end of file diff --git a/unit_tests/operator/Test_ConvImpl.cpp b/unit_tests/operator/Test_ConvImpl.cpp index e48d69c89eb0d6d52a834b3f32a41d8621fdd42b..f7be338c0b9c5bb1d5af6bfa09ed7855c17fb6c0 100644 --- a/unit_tests/operator/Test_ConvImpl.cpp +++ b/unit_tests/operator/Test_ConvImpl.cpp @@ -9,390 +9,1640 @@ * ********************************************************************************/ -#include <catch2/catch_test_macros.hpp> -#include <cstdlib> #include <memory> +#include <catch2/catch_test_macros.hpp> +#include <fmt/core.h> + +#include "aidge/backend/cpu/operator/ConvImpl.hpp" +#include "aidge/data/Data.hpp" // DataType #include "aidge/data/Tensor.hpp" +#include "aidge/graph/Node.hpp" #include "aidge/operator/Conv.hpp" #include "aidge/utils/TensorUtils.hpp" -#include "aidge/backend/cpu.hpp" - using namespace Aidge; +/** + * @brief ConvDepthWise reference cpp backend forward implmentation tests. + * + * Summary + * ======= + * kernel [3, 3] + * no stride, no dilation + * stride [2,2], no dilation + * stride [2,2], dilation [2,2] + * kernel [1,1] + * no stride, no dilation + * stride [3,3], no dilation + * stride [3,3], dilation [2,2] + * kernel [5,5] + * no stride, no dilation + * stride [2,2], no dilation + * stride [2,2], dilation [2,2] + */ TEST_CASE("[cpu/operator] Conv(forward)", "[Conv][CPU]") { - SECTION("Classic Conv") { - std::shared_ptr<Node> myConv = Conv(3,4,{3,3}, "myconv"); - auto op = std::static_pointer_cast<OperatorTensor>(myConv -> getOperator()); - std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<int,4,3,3,3> { - { - { - {{ 0, 1, 2}, - { 3, 4, 5}, - { 6, 7, 8}}, - {{ 9, 10, 11}, - { 12, 13, 14}, - { 15, 16, 17}}, - {{ 18, 19, 20}, - { 21, 22, 23}, - { 24, 25, 26}} - }, - { - {{ 27, 28, 29}, - { 30, 31, 32}, - { 33, 34, 35}}, - {{ 36, 37, 38}, - { 39, 40, 41}, - { 42, 43, 44}}, - {{ 45, 46, 47}, - { 48, 49, 50}, - { 51, 52, 53}} - }, - { - {{ 54, 55, 56}, - { 57, 58, 59}, - { 60, 61, 62}}, - {{ 63, 64, 65}, - { 66, 67, 68}, - { 69, 70, 71}}, - {{ 72, 73, 74}, - { 75, 76, 77}, - { 78, 79, 80}} - }, + SECTION("Conv with kernel [3,3]") { + SECTION("No stride, no dilation") { + std::shared_ptr<Node> myConv = Conv(3,4,{3,3}, "myconv"); + auto op = std::static_pointer_cast<OperatorTensor>(myConv -> getOperator()); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,3,5,5> { + {{{{ 1.9589154720, -1.0110363960, -1.3467419147, -1.2994621992, + 0.1868611127}, + {-2.1160471439, 2.6068224907, -0.3956520855, 0.6124371886, + 0.5558118224}, + {-1.3752416372, -0.5381416678, 1.3515229225, 0.3483452201, + 1.2234963179}, + { 0.0059552360, -0.2397697568, 0.2965213358, 2.8708021641, + -0.1894584149}, + {-1.1417659521, -0.7091106772, 0.4265739620, 1.2625461817, + 1.3426892757}}, + + {{-1.7054084539, -0.0123512046, 1.6428576708, -0.6893027425, + -0.7225475907}, + { 0.5119231343, -0.6280173063, -1.1115980148, 0.4660657048, + -0.7594850659}, + { 0.3270559013, -0.5950503945, -0.5041811466, 1.5374211073, + -0.1167122573}, + { 0.1087838784, 0.5390511155, 1.2179180384, 0.3980854154, + -0.7828828692}, + { 0.0556977428, 0.7218912244, 1.9021573067, -1.0387030840, + -0.8821150064}}, + + {{ 1.4301366806, 1.6713913679, 2.7843642235, 2.1220099926, + -0.1225114316}, + { 0.9676079750, 0.0927167758, 1.3160738945, -1.1043174267, + 1.1129677296}, + { 1.7959889174, 0.8373055458, 0.9646789432, -1.0112720728, + -0.5838463902}, + { 2.2678804398, 1.4151918888, 1.3015384674, 1.0426887274, + 0.5917658210}, + {-0.4324578047, -0.8626666665, 1.4560189247, 0.4216258824, + 0.4797532558}}}, + + + {{{ 0.7857398987, 1.0575772524, 1.4281150103, 0.5534361601, + -0.7034347653}, + {-0.6367500424, 0.0736645982, -0.1755308807, -0.0363982692, + -0.2698654532}, + {-0.1483689547, -1.4097578526, 3.0468130112, -0.9070094824, + -0.0465935729}, + {-0.4035107195, 0.3865649998, 0.5000861287, 0.0409870148, + -0.2879518867}, + {-0.3219492733, -0.8549274206, 0.6380167007, 1.0422019958, + 0.6655231714}}, + + {{ 1.8096567392, 0.8781581521, 0.8389310837, -0.5663586259, + -0.3415665030}, + { 0.6761546135, -1.8892842531, -0.4562507868, -1.3220169544, + -0.0600548759}, + {-2.3044283390, -0.6273749471, -0.4794641733, 0.3725788891, + -0.0789731145}, + {-0.0977325812, -0.9537382126, 1.6169245243, -1.5318340063, + -0.0146348709}, + { 2.8766520023, -1.4148812294, -1.6623396873, -0.1664140671, + -0.4492034912}}, + + {{-0.1821998209, 0.2622891963, -0.1877782643, -1.6476312876, + -0.6388390660}, + { 0.3169876039, -0.8038316965, -0.0962172970, 0.9118794799, + -0.6303430200}, + {-1.1592572927, -0.7246652842, 2.0476946831, -0.0111423340, + 0.1810427308}, + {-0.9517292976, 0.8139786720, 0.2079211175, 1.1996579170, + -2.9504573345}, + {-0.0734997243, -0.1853470206, -0.7156494260, 0.0105203446, + 0.1248303726}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,4> {{-0.0953646377, 0.0252329484, 0.0736814514, -0.1786542684}}); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,3,3,3> { + {{{{-9.8095126450e-02, -1.2300920486e-01, 9.8884001374e-02}, + { 9.1700568795e-02, -1.4668887854e-01, 1.8008193374e-01}, + { 1.5175457299e-01, -1.7447699606e-01, 1.4544752240e-01}}, + + {{ 1.0927890241e-01, 1.4000165276e-02, -6.3778474927e-02}, + {-6.1189714819e-02, -1.1776087433e-01, -1.3381159306e-01}, + {-1.0581787676e-01, -2.2982399911e-02, -1.7992550135e-01}}, + + {{ 1.6829776764e-01, -1.4381179214e-01, 6.8733291700e-03}, + {-1.2539045885e-02, 5.9776338749e-03, -1.6262140870e-01}, + {-1.8150532246e-01, -5.7762067765e-02, -1.2558361888e-01}}}, + + + {{{-1.9167131186e-01, -2.0615238696e-02, -1.4749580622e-01}, + { 2.8844498098e-02, -1.4776159823e-01, -1.4355856180e-01}, + { 8.4779271856e-04, 6.2282480299e-02, -4.7553293407e-02}}, + + {{ 1.3088253140e-01, 1.5134441853e-01, -2.6965653524e-02}, + { 1.0772592388e-02, 1.2854418159e-01, 1.0366836190e-01}, + {-1.1401490122e-01, 7.7274993062e-02, 3.0564883724e-02}}, + + {{ 5.1155414432e-02, -9.1598711908e-02, -3.4925807267e-02}, + {-4.7612894326e-02, -1.3718418777e-01, -1.6633707285e-01}, + {-1.0674133152e-01, -1.2472828478e-01, -2.7257451788e-02}}}, + + + {{{-1.6760022938e-01, 7.3507070541e-02, -7.9185843468e-02}, + {-1.4954875410e-01, -1.2724696100e-01, -1.7345303297e-01}, + {-6.1098476872e-03, -1.0876300931e-01, 5.5467881262e-02}}, + + {{ 1.2978881598e-01, 1.3939674199e-01, -1.1531168967e-01}, + {-2.2687437013e-02, -1.7840284854e-02, -1.6743370891e-01}, + { 1.0837453604e-01, 1.8985052407e-01, -1.0379110277e-01}}, + + {{ 6.6287182271e-02, 6.2025051564e-02, 1.2377030216e-02}, + { 3.9242565632e-02, -6.0004377738e-03, -6.8803697824e-02}, + { 4.1780071333e-03, 2.6938719675e-02, 7.6389208436e-02}}}, + + + {{{ 1.1024503410e-01, -2.0850643516e-02, -1.8233209848e-01}, + { 6.6961102188e-02, -4.2337212712e-02, 7.4952361174e-03}, + {-9.7807966173e-02, -1.1996924877e-01, -1.0953181982e-01}}, + + {{-5.0487142056e-02, 1.6390360892e-01, 8.4205769002e-02}, + {-5.8809131384e-02, 1.2781466544e-01, -9.8992012441e-02}, + { 8.6972511781e-05, 1.8518652767e-02, 2.3319255561e-02}}, + + {{ 1.3843496144e-01, 1.2726350129e-01, 6.0541676357e-03}, + {-1.7569662631e-01, -3.2354578376e-02, -2.8981804848e-02}, + { 1.2195236981e-01, -1.5436136723e-01, 1.9030691683e-01}}}} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,2,4,3,3> { + {{{{-1.5940358639, 0.1530998051, 0.7961921692}, + {-0.6456220150, -0.9623295069, -0.3714832962}, + {-0.2949652672, -0.2190868855, -0.7510035634}}, + + {{-1.9549642801, 0.3644225001, 0.2327819020}, + {-0.6244611144, -1.4931452274, 0.0343996882}, + {-0.1242256761, -0.9374671578, -1.3280395269}}, + + {{ 0.0045107063, -0.0462101400, 0.9180620313}, + { 1.1481817961, -0.4496926665, -0.1462283134}, + { 0.2508939803, 0.0321760587, -0.2549723089}}, + + {{ 0.7890433073, 0.4426230788, -0.1316385418}, + {-0.5980579257, -0.4781580269, -0.2697018385}, + {-0.3229375780, -0.6225202084, -0.4930999875}}}, + + + {{{ 1.0800328255, -0.7354405522, 0.4513073266}, + { 0.6686266065, -0.9933773279, 0.1854345798}, + { 0.3025504351, -0.8967062235, 0.9597628117}}, + + {{-0.0881052017, -0.1263627559, -0.3096815050}, + {-0.5136051774, -0.8888333440, -0.4660321772}, + {-1.5337569714, -0.1932583302, 0.0154455183}}, + + {{ 0.5402041674, -0.0840467885, 0.1597940624}, + {-0.6559837461, 0.1566532403, -0.6488818526}, + {-0.7716163993, 0.0849530324, -0.4354790747}}, + + {{-0.3671937585, -0.3662823737, -0.3439203203}, + {-0.5602309704, -0.4117053151, -1.0589636564}, + {-1.2864066362, 0.0953547135, -0.0424274690}}}} + }); + + op->associateInput(0,myInput); + op->associateInput(1,myWeights); + op->associateInput(2,myBiases); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + op->forwardDims(); + myConv->forward(); + REQUIRE(approxEq<float>(*(op->getOutput(0)),*expectedOutput, 1e-5f, 1e-8f)); + } + SECTION("Stride [2,2] | no dilation") { + fmt::print("Stride conv\n"); + std::shared_ptr<Node> myConv = Conv(2,1,{3,3}, "myconv", {2,2}); + auto op = std::static_pointer_cast<OperatorTensor>(myConv -> getOperator()); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,1,2,5,5> { { - {{ 81, 82, 83}, - { 84, 85, 86}, - { 87, 88, 89}}, - {{ 90, 91, 92}, - { 93, 94, 95}, - { 96, 97, 98}}, - {{ 99, 100, 101}, - {102, 103, 104}, - {105, 106, 107}} + { + {{-0.9172891974, -0.4144149423, -0.0127728982, -0.6073911190, -0.0152466390}, + { 0.4086987972, 0.5984987617, -0.6368257999, -0.0744020939,0.9958203435}, + { 0.5346475244, -0.0788366571, -1.4571775198, -0.5634615421, 1.9504889250}, + { 1.1559145451, 0.4456179738, 1.5754781961, 0.0340409055, -0.2864624560}, + {-0.2880946100, -1.1225816011, 0.6820554733, -1.6727229357, 0.5375806093}}, + + {{ 1.4201650620, -1.7509239912, -1.0208708048, -2.2132670879, 0.1117813289}, + {-0.2961948812, -0.6673586369, -0.0750549659, -0.6074910164, 2.7782683372}, + {-0.5388702750, -2.9463961124, 2.1617200375, -0.5921722054, 0.5093105435}, + { 0.3627473414, 0.6647079587, 0.9116655588, -1.1410249472, -1.2326570749}, + { 2.0130438805, 0.2274843603, -0.3941729367, 1.5164465904, -1.4629846811}} + } } - } - }); - std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<int,4> {{7,0,9,0}}); - std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<int,2,3,5,5> { //NCHW - { + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,1> {{0.0510118753}}); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,1,2,3,3> { //NCHW { - {{ 0, 1, 2, 3, 4}, - { 5, 6, 7, 8, 9}, - { 10, 11, 12, 13, 14}, - { 15, 16, 17, 18, 19}, - { 20, 21, 22, 23, 24}}, - - {{ 25, 26, 27, 28, 29}, - { 30, 31, 32, 33, 34}, - { 35, 36, 37, 38, 39}, - { 40, 41, 42, 43, 44}, - { 45, 46, 47, 48, 49}}, - - {{ 50, 51, 52, 53, 54}, - { 55, 56, 57, 58, 59}, - { 60, 61, 62, 63, 64}, - { 65, 66, 67, 68, 69}, - { 70, 71, 72, 73, 74}} - }, - { - {{ 75, 76, 77, 78, 79}, - { 80, 81, 82, 83, 84}, - { 85, 86, 87, 88, 89}, - { 90, 91, 92, 93, 94}, - { 95, 96, 97, 98, 99}}, - - {{100, 101, 102, 103, 104}, - {105, 106, 107, 108, 109}, - {110, 111, 112, 113, 114}, - {115, 116, 117, 118, 119}, - {120, 121, 122, 123, 124}}, - - {{125, 126, 127, 128, 129}, - {130, 131, 132, 133, 134}, - {135, 136, 137, 138, 139}, - {140, 141, 142, 143, 144}, - {145, 146, 147, 148, 149}} + { + {{-0.0313824862, 0.0508503988, 0.0164926797}, + { 0.1006948650, -0.1172138453, -0.1695717275}, + { 0.0428596064, -0.0209989939, 0.0409581661}}, + + {{ 0.1578378379, 0.1206310838, -0.0435518846}, + {-0.1232439354, -0.0746039376, -0.1832553446}, + {-0.2099054456, -0.0485608950, -0.0263982005}} + } } - } - }); - std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<int,2,4,3,3> { - { + }); + std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<float,1,1,2,2> { + {{{{ 0.4589637220, -1.5007210970}, + {-1.3768237829, 0.8838264346}}}} + }); + op->associateInput(0,myInput); + op->associateInput(1,myWeights); + op->associateInput(2,myBiases); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + op->forwardDims(); + myConv->forward(); + REQUIRE(approxEq<float>(*(op->getOutput(0)),*myOutput, 1e-5f, 1e-8f)); + } + SECTION("Stride [2,2] | dilation [2,2]") { + std::shared_ptr<Node> myConv = Conv(3,4,{3,3}, "myconv", {2,2},{2,2}); + auto op = std::static_pointer_cast<OperatorTensor>(myConv -> getOperator()); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,3,8,8> { + {{{{-0.1188705787, 0.2816951275, -0.2984274328, -1.1976412535, + 0.8114287257, 0.4823331535, -1.0786532164, 0.9851297736}, + { 0.0904922038, -1.1438461542, -0.0669364706, 0.3452534974, + 0.7429151535, 0.2972919941, 1.0456926823, 1.4523943663}, + {-0.3802470863, -0.6462681890, -0.0720853731, -0.2438412607, + -0.4248653948, 1.1631586552, 0.1168384328, -0.6349191666}, + {-0.6761771441, 0.4402402937, 0.1597155333, 1.9242455959, + 0.5197503567, 0.2465762347, -0.4339638352, -1.8066062927}, + {-1.9127304554, 0.6646876931, -0.4783093035, -1.3473324776, + 0.0341839045, -2.2884213924, -0.1275558323, -1.7231218815}, + {-0.2966701984, -1.3017544746, -1.1662167311, 0.1658521742, + -0.6164554954, 0.0081211235, -1.0151772499, -1.5334042311}, + { 0.7240908146, -0.3402064443, 0.6997579932, -2.0220518112, + -0.3068118095, -1.0308376551, -0.2388850898, -0.7060801387}, + { 0.1785067171, 1.7402675152, -2.3868498802, -1.1331337690, + -0.4219121039, 1.8711329699, -0.6606232524, -1.6592081785}}, + + {{ 0.6190823913, 0.8092262745, -0.7060685158, 0.0304337163, + 0.4091666043, 0.0813116133, 0.4840675890, 0.2641656399}, + { 0.2542354465, 0.2307302207, 0.6869923472, 0.2362013012, + -1.1283814907, 0.8362480402, 0.8266768456, 0.0802312344}, + {-0.3555500209, -2.2554228306, -1.5869985819, 0.3333877325, + -0.4702588022, -0.6803518534, 0.6176669002, 0.4762512147}, + { 1.1077108383, 1.1330136061, 1.8400819302, -1.0991362333, + 2.1611855030, 0.1117769480, -0.6205040216, -0.2117952108}, + {-1.5193798542, -0.1550827324, -0.5265267491, -1.2790243626, + 0.3303467929, 1.0218718052, 1.1192175150, -0.7710328102}, + {-0.0476187728, 0.4485085905, -0.1582977027, 0.0724322647, + -1.1117008924, -0.8183123469, 0.5641981363, -1.6811249256}, + { 0.9987262487, -0.1189468578, 2.1283075809, 0.4363055527, + -0.3778467178, -0.0662526861, 1.8629752398, 0.0138667496}, + { 0.7676854134, 0.1937866956, 1.4422281981, -1.1290216446, + 0.7080894709, 2.5929143429, -0.7571783662, 0.2852989435}}, + + {{ 1.4924588203, -0.4435261190, 2.0567798615, -0.0626328290, + -0.1295254529, -0.7903441191, -0.8875020146, -0.0774324834}, + {-0.1324225664, 1.0868368149, -1.1364834309, -1.6390762329, + 1.9040355682, -0.4707730114, 0.0161281377, -0.6975379586}, + {-0.3238699436, 0.8842161894, 1.1445614100, -1.6232908964, + -0.8882713914, -1.8093744516, -0.1777562201, 0.5819293261}, + { 0.7295795679, 0.1291531473, -0.4221164286, -0.7543993592, + 1.7530804873, 1.2137289047, 0.1712447554, 0.7797858715}, + {-0.9555802941, 0.5811409950, -0.8670540452, 0.2635410130, + -0.7124243975, -1.1684633493, -0.9052340388, 1.9412840605}, + { 1.9100213051, -1.0305522680, 0.5274596214, -1.2627660036, + -0.2148997933, 0.2035689354, -1.7192043066, -0.2544563115}, + { 0.2017714083, 0.4016665220, 0.3390722871, -0.3035760522, + -0.3663327694, 0.0600613877, -0.0339251719, -0.8913670778}, + { 1.7665598392, 0.9260235429, 0.3003851175, -0.9005023837, + 1.1928522587, 0.4629152417, -0.5554659963, 0.7210552692}}}, + + + {{{-1.4669480324, -1.9388267994, 1.5483425856, 1.1181882620, + -0.3805114627, 1.1109890938, -0.2052619010, 1.1042164564}, + { 0.4673211575, 0.4196352065, 0.4227792621, -0.5219387412, + 0.0205618106, 0.1554571241, 0.0663015544, -0.9489745498}, + {-1.0466605425, 0.5461820960, 0.6410058141, -1.3620399237, + -1.0813049078, 0.1793443412, -0.1706669927, -0.4602096379}, + {-1.7209292650, -0.2953333557, 1.7118937969, -0.4912810624, + -0.1289002448, 0.1328577548, 0.1901275814, -0.3454665840}, + { 0.8632204533, -0.5402244329, -0.7203282118, 0.1909070015, + -0.9598633051, -0.3567193747, -1.2241401672, -1.4285988808}, + {-0.0053109927, 0.5492724180, -0.6969500184, -0.0806153566, + 0.0108112106, 1.0541758537, 0.7616159320, 0.4281075597}, + {-0.4106684327, -0.1015827805, 0.3509808779, 0.5693497658, + -0.1018603221, 1.4039390087, -0.0777162984, 0.6270876527}, + {-1.0135161877, -1.8348534107, 0.4114980996, 0.3335310519, + 0.1524153352, 0.3699457347, 0.7671044469, -0.4262715578}}, + + {{-1.6966865063, -1.6478683949, -1.9790385962, 0.2527430952, + -0.0905110165, -0.0635698363, -0.0573330820, 1.0314729214}, + {-0.6635334492, 0.6282949448, -1.7904577255, -0.8557172418, + -0.0034775073, -1.8597713709, 0.0629400164, 1.8645975590}, + {-0.2417802215, -2.0907909870, -0.6610770226, 0.0370707251, + -0.3790464699, 0.4056134224, -0.6770003438, -0.1630496234}, + { 0.5351417661, 0.5648558736, 0.0339234099, -1.0622112751, + 0.2600688934, -0.9012448788, -0.1649469733, 1.0918279886}, + { 0.1757414639, -0.9574738145, -0.1036170200, 0.4119592607, + -0.1864566058, -0.7414899468, -1.0112596750, 0.5932180285}, + { 2.5017285347, -1.1581023932, 1.3093953133, -0.7198146582, + -1.5908024311, 0.8121448755, 0.7034096718, -0.9230241179}, + { 0.2434841692, -0.3430828154, -0.5492738485, 1.2563036680, + -0.7979055047, -0.5963852406, -1.2386600971, 1.0888540745}, + {-1.0111418962, -0.8539634347, -0.5791223645, -0.1067883298, + -1.5724925995, 2.1116111279, -1.1383074522, -1.0958977938}}, + + {{-1.1234122515, 0.1245573759, 1.0143580437, 1.5302449465, + 0.2031295598, -1.9043928385, 1.3508439064, -0.2800186276}, + {-0.9589865804, 0.1616169512, 0.2391400337, -1.4750261307, + 0.0099803358, -0.1171493977, 0.6740672588, -0.1370818019}, + {-0.1391109377, 0.0727060661, -0.4593637884, -0.7321376801, + 1.5460444689, -0.8100689054, -0.9888796806, -0.4120852351}, + {-0.0212404300, -0.4793040454, 0.6356511712, -0.4630482495, + 0.2108679265, -0.3083102107, 0.5054508448, 1.1919053793}, + {-0.0884858668, 3.4338581562, 0.7329123616, -2.1103246212, + 0.7225313187, 0.9016330242, 0.9377334714, -0.4367531836}, + {-0.3704574108, 1.6491469145, 0.3141648471, 0.4510190189, + 0.4318002760, -0.9045167565, 0.8301303983, 0.7684385180}, + {-0.8037292361, -0.5441461802, -0.7591288686, 0.0744654313, + 1.3689713478, -0.5011886954, 1.7228851318, 0.5603687763}, + {-0.6605531573, -0.6788222194, -0.7361486554, -0.5757233500, + 1.0730458498, -1.3874051571, -0.3474266231, -0.8265513778}}}} + }); + std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<float,4> {{ 0.0990660042, -0.1225081310, 0.1313948184, 0.1762123108}}); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,3,3,3> { //NCHW { - {{ 15226, 15577, 15928}, - { 16981, 17332, 17683}, - { 18736, 19087, 19438}}, - {{ 37818, 38898, 39978}, - { 43218, 44298, 45378}, - { 48618, 49698, 50778}}, - {{ 60426, 62235, 64044}, - { 69471, 71280, 73089}, - { 78516, 80325, 82134}}, - {{ 83016, 85554, 88092}, - { 95706, 98244, 100782}, - {108396, 110934, 113472}} - }, + { + {{-0.0742266178, -0.1015041918, -0.1837180406}, + { 0.1212819815, 0.0690591782, -0.1593577117}, + {-0.0843696669, -0.1086091846, -0.0535376482}}, + {{-0.1789309531, 0.0376559310, 0.0013079831}, + {-0.1449880898, -0.0421767347, -0.1624120027}, + {-0.0249556731, 0.0128599331, -0.0885131210}}, + {{ 0.0676055402, -0.0983952284, 0.1473675370}, + { 0.1839909703, 0.1130383387, -0.0591052435}, + { 0.0585059784, 0.1588494480, 0.0837457627}} + }, + { + {{-0.1567524076, 0.1625535488, 0.0182458963}, + { 0.1850223690, -0.0358172096, -0.1628112793}, + {-0.1376292855, -0.0344094560, -0.1884102225}}, + {{-0.0423617847, 0.0293476824, 0.1536137760}, + {-0.0270300750, -0.0305194370, -0.1655584276}, + {-0.1863301992, -0.1162049547, 0.1737580448}}, + {{ 0.0242811367, -0.1578935832, 0.0989272967}, + { 0.1372362673, -0.0220122132, 0.1455551833}, + {-0.1089558378, -0.0905081928, 0.0518258214}} + }, + { + {{ 0.0139021352, 0.0587250255, -0.1449213177}, + {-0.0330693051, -0.1812424064, -0.1139466539}, + {-0.1164239123, 0.1715961397, -0.0794166178}}, + {{ 0.1213000864, -0.0386510566, -0.1356311589}, + {-0.0149245840, 0.0932215229, -0.0739067197}, + {-0.0967126489, 0.1908948421, -0.1560722739}}, + {{ 0.0209723040, 0.1112291217, -0.0647701770}, + {-0.1392559260, -0.1320645362, -0.0444710776}, + { 0.0430435687, 0.0755403191, -0.0292177629}} + }, + { + {{-0.0096048070, -0.1168650612, 0.0795118213}, + { 0.1292631775, 0.1187104806, 0.0055516954}, + {-0.0513968319, 0.0119201895, -0.1323033124}}, + {{-0.0103231622, -0.0167140234, -0.1640086174}, + { 0.0006717141, -0.1033149883, 0.1532802731}, + {-0.1091904417, -0.0362586826, -0.0466702394}}, + {{ 0.1598802209, 0.0274402276, -0.0314338058}, + {-0.0309233274, -0.0172249842, 0.0602610074}, + {-0.0760313869, 0.1626167148, -0.1280857325}} + } + } + }); + std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<float,2,4,2,2> { { - {{ 41551, 41902, 42253}, - { 43306, 43657, 44008}, - { 45061, 45412, 45763}}, - {{118818, 119898, 120978}, - {124218, 125298, 126378}, - {129618, 130698, 131778}}, - {{196101, 197910, 199719}, - {205146, 206955, 208764}, - {214191, 216000, 217809}}, - {{273366, 275904, 278442}, - {286056, 288594, 291132}, - {298746, 301284, 303822}} + { + {{ 2.5577675551e-02, 3.3609107137e-01}, + {-4.1263043880e-01, -8.6513996124e-02}}, + {{ 3.3599328995e-01, 6.7887884378e-01}, + {-1.6548714638e+00, -4.3978449702e-01}}, + {{ 2.5655159354e-01, 1.6271021962e-01}, + { 1.2917815447e+00, -6.9365233183e-01}}, + {{ 7.6308822632e-01, 4.2617604136e-01}, + { 9.5820712158e-04, -8.5415104404e-03}} + }, + { + {{ 4.8839592934e-01, 1.6159484386e+00}, + { 8.3929586411e-01, 7.7249741554e-01}}, + {{ 4.6416598558e-01, -1.5206467360e-02}, + { 8.9938336611e-01, -3.4680870175e-01}}, + {{ 1.1395914108e-01, 2.1783047915e-01}, + { 2.8127232194e-01, 5.7355374098e-01}}, + {{ 3.7263277918e-02, 3.5707062483e-01}, + {-1.1406441033e-01, 2.7624604106e-01}} + } } - } - }); - op->associateInput(0,myInput); - op->associateInput(1,myWeights); - op->associateInput(2,myBias); - op->setDataType(DataType::Int32); - op->setBackend("cpu"); - myConv->forward(); - op->getOutput(0)->print(); - REQUIRE(*(op->getOutput(0)) == *myOutput); + }); + op->associateInput(0,myInput); + op->associateInput(1,myWeights); + op->associateInput(2,myBias); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + op->forwardDims(); + myConv->forward(); + REQUIRE(approxEq<float>(*(op->getOutput(0)),*myOutput, 1e-5f, 1e-8f)); + } } + SECTION("Point-wise") { - std::shared_ptr<Node> myConv = Conv(3,4,{1,1}, "myconv", {1,1}); - auto op = std::static_pointer_cast<OperatorTensor>(myConv -> getOperator()); - op->setInput(0, std::make_shared<Tensor>(Array4D<float,2,3,3,3> { - { - { - {{-1.38467371F, -0.87123615F, -0.22336592F}, - { 1.71736145F, 0.31888032F, -0.42451897F}, - { 0.30572093F, -0.77459252F, -1.55757248F}}, - {{ 0.99563611F, -0.87978584F, -0.60114205F}, - {-1.27415121F, 2.12278509F, -1.23465312F}, - {-0.48791388F, -0.91382301F, -0.65813726F}}, - {{ 0.07802387F, 0.52580875F, -0.48799172F}, - { 1.19136906F, -0.81400764F, -0.73599279F}, - {-1.40324783F, 0.03600367F, -0.06347727F}} - }, + SECTION("no stride, no dilation") { + std::shared_ptr<Node> myConv = Conv(3,4,{1,1}, "myconv", {1,1}); + auto op = std::static_pointer_cast<OperatorTensor>(myConv -> getOperator()); + op->setInput(0, std::make_shared<Tensor>(Array4D<float,2,3,3,3> { { - {{ 0.67561489F, -0.09780689F, 1.84459400F}, - {-1.18453741F, 1.38354933F, 1.44513381F}, - { 0.85641253F, 2.21807575F, 0.52316552F}}, - {{ 0.34664667F, -0.19733144F, 1.14120162F}, - { 0.05164360F, 0.72810954F, -0.71064192F}, - {-0.60206831F, 0.96044880F, 0.40481427F}}, - {{-1.35434294F, 1.33470297F, 0.48353928F}, - {-0.19756168F, 1.26831138F, 1.22426283F}, - { 0.09811721F, 1.74225271F, -1.35267365F}} + { + {{-1.38467371F, -0.87123615F, -0.22336592F}, + { 1.71736145F, 0.31888032F, -0.42451897F}, + { 0.30572093F, -0.77459252F, -1.55757248F}}, + {{ 0.99563611F, -0.87978584F, -0.60114205F}, + {-1.27415121F, 2.12278509F, -1.23465312F}, + {-0.48791388F, -0.91382301F, -0.65813726F}}, + {{ 0.07802387F, 0.52580875F, -0.48799172F}, + { 1.19136906F, -0.81400764F, -0.73599279F}, + {-1.40324783F, 0.03600367F, -0.06347727F}} + }, + { + {{ 0.67561489F, -0.09780689F, 1.84459400F}, + {-1.18453741F, 1.38354933F, 1.44513381F}, + { 0.85641253F, 2.21807575F, 0.52316552F}}, + {{ 0.34664667F, -0.19733144F, 1.14120162F}, + { 0.05164360F, 0.72810954F, -0.71064192F}, + {-0.60206831F, 0.96044880F, 0.40481427F}}, + {{-1.35434294F, 1.33470297F, 0.48353928F}, + {-0.19756168F, 1.26831138F, 1.22426283F}, + { 0.09811721F, 1.74225271F, -1.35267365F}} + } } - } - })); - op->setInput(1, std::make_shared<Tensor>(Array4D<float,4,3,1,1> { - { - { - {{ 0.33669037F}}, - {{ 0.12880941F}}, - {{ 0.23446237F}} - }, - { - {{ 0.23033303F}}, - {{-1.12285638F}}, - {{-0.18632829F}} - }, + })); + op->setInput(1, std::make_shared<Tensor>(Array4D<float,4,3,1,1> { { - {{ 2.20820141F}}, - {{-0.63799703F}}, - {{ 0.46165723F}}}, - { - {{ 0.26735088F}}, - {{ 0.53490466F}}, - {{ 0.80935723F}} + { + {{ 0.33669037F}}, + {{ 0.12880941F}}, + {{ 0.23446237F}} + }, + { + {{ 0.23033303F}}, + {{-1.12285638F}}, + {{-0.18632829F}} + }, + { + {{ 2.20820141F}}, + {{-0.63799703F}}, + {{ 0.46165723F}}}, + { + {{ 0.26735088F}}, + {{ 0.53490466F}}, + {{ 0.80935723F}} + } } - } - })); - op->setInput(2, std::make_shared<Tensor>(Array1D<float,4> {{ 1.11029029F, -1.68979895F, -0.98895991F, 0.95797181F}})); - Tensor expectedOutput = Array4D<float,2,4,3,3> { - { - { - {{ 0.79062498F, 0.82691115F, 0.84323663F}, - { 1.80371785F, 1.30023468F, 0.63576132F}, - { 0.82136691F, 0.74022496F, 0.48621333F}}, - {{-3.14122939F, -1.00057328F, -0.97532475F}, - {-0.08553087F, -3.84826040F, -0.26410526F}, - {-0.81005937F, -0.84882969F, -1.29773819F}}, - {{-4.64579105F, -2.10878062F, -1.32395494F}, - { 4.16622877F, -2.01493120F, -1.47845459F}, - {-0.65039843F, -2.09977841F, -4.03780890F}}, - {{ 1.18349767F, 0.68001163F, 0.18174142F}, - { 1.69980371F, 1.51988935F, -0.41162649F}, - {-0.35700959F, 0.29121545F, 0.13813695F}} - }, + })); + op->setInput(2, std::make_shared<Tensor>(Array1D<float,4> {{ 1.11029029F, -1.68979895F, -0.98895991F, 0.95797181F}})); + Tensor expectedOutput = Array4D<float,2,4,3,3> { { - {{ 1.06487226F, 1.36487913F, 1.99171650F}, - { 0.67179936F, 1.96727657F, 1.79235911F}, - { 1.34408879F, 2.38930249F, 1.02142799F}}, - {{-1.67106462F, -1.73944509F, -2.63643050F}, - {-1.98381400F, -2.42500663F, -0.78710288F}, - {-0.83478457F, -2.58197999F, -1.77180362F}}, - {{-0.34346789F, -0.46286502F, 2.57942152F}, - {-3.72881150F, 2.18718910F, 3.22076392F}, - { 1.33158576F, 4.10055828F, -0.71644694F}}, - {{ 0.22787374F, 1.90652108F, 2.45291567F}, - { 0.50901115F, 2.74385118F, 1.95506990F}, - { 0.94429719F, 3.47482967F, 0.21958135F}} + { + {{ 0.79062498F, 0.82691115F, 0.84323663F}, + { 1.80371785F, 1.30023468F, 0.63576132F}, + { 0.82136691F, 0.74022496F, 0.48621333F}}, + {{-3.14122939F, -1.00057328F, -0.97532475F}, + {-0.08553087F, -3.84826040F, -0.26410526F}, + {-0.81005937F, -0.84882969F, -1.29773819F}}, + {{-4.64579105F, -2.10878062F, -1.32395494F}, + { 4.16622877F, -2.01493120F, -1.47845459F}, + {-0.65039843F, -2.09977841F, -4.03780890F}}, + {{ 1.18349767F, 0.68001163F, 0.18174142F}, + { 1.69980371F, 1.51988935F, -0.41162649F}, + {-0.35700959F, 0.29121545F, 0.13813695F}} + }, + { + {{ 1.06487226F, 1.36487913F, 1.99171650F}, + { 0.67179936F, 1.96727657F, 1.79235911F}, + { 1.34408879F, 2.38930249F, 1.02142799F}}, + {{-1.67106462F, -1.73944509F, -2.63643050F}, + {-1.98381400F, -2.42500663F, -0.78710288F}, + {-0.83478457F, -2.58197999F, -1.77180362F}}, + {{-0.34346789F, -0.46286502F, 2.57942152F}, + {-3.72881150F, 2.18718910F, 3.22076392F}, + { 1.33158576F, 4.10055828F, -0.71644694F}}, + {{ 0.22787374F, 1.90652108F, 2.45291567F}, + { 0.50901115F, 2.74385118F, 1.95506990F}, + { 0.94429719F, 3.47482967F, 0.21958135F}} + } } - } - }; - op->setDataType(DataType::Float32); - op->setBackend("cpu"); - myConv->forward(); - - float* resPtr = static_cast<float*>(op->getOutput(0)->getImpl()->rawPtr()); - float* expectedPtr = static_cast<float*>(expectedOutput.getImpl()->rawPtr()); - for (std::size_t i = 0; i< expectedOutput.size(); ++i) { - REQUIRE(std::abs(resPtr[i]-expectedPtr[i]) < 0.00001); + }; + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + myConv->forward(); + REQUIRE(approxEq<float>(*(op->getOutput(0)),expectedOutput, 1e-5f, 1e-8f)); + } + SECTION("stride [3,3], no dilation") { + std::shared_ptr<Node> myConv = Conv(3,4,{1,1}, "myconv", {3,3}); + auto op = std::static_pointer_cast<OperatorTensor>(myConv -> getOperator()); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,1,3,3,3> { + {{{{ 0.5328165889, -0.4486202002, -0.4963828325}, + { 0.2804954648, 0.4619753063, -0.9826803803}, + { 0.4261451066, 0.5110000372, 1.9890428782}}, + + {{ 1.1952217817, -1.5133171082, 0.1646732241}, + {-0.7254997492, -0.1677423269, -1.0745935440}, + { 0.2478682548, -2.3416306973, 0.5321671963}}, + + {{ 0.7880555391, 2.6202778816, -1.0282965899}, + {-0.5344914198, -0.5824835896, 0.4730331898}, + {-1.3073508739, 0.2466813326, -0.3054754436}}}} + }); + std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<float,4> {{0.3835324347, 0.5210654140, 0.0670478195, 0.2723239064}}); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,3,1,1> { //NCHW + {{{{ 0.3910248578}}, + {{ 0.4545786083}}, + {{-0.1685599536}}}, + + {{{-0.2521645725}}, + {{ 0.1072919592}}, + {{ 0.1053370386}}}, + + {{{ 0.4627405703}}, + {{-0.4509093165}}, + {{ 0.3544224203}}}, + + {{{ 0.4944877923}}, + {{ 0.1518175900}}, + {{ 0.1704865843}}}} + }); + std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<float,1,4,1,1> { + {{{{1.0023646355}}, + {{0.5979570746}}, + {{0.0539716333}}, + {{0.8516038060}}}} + }); + op->associateInput(0,myInput); + op->associateInput(1,myWeights); + op->associateInput(2,myBias); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + op->forwardDims(); + myConv->forward(); + REQUIRE(approxEq<float>(*(op->getOutput(0)),*myOutput, 1e-5f, 1e-8f)); + } + SECTION("stride [3,3], dilation [2,2]") { // same as no dilation test + std::shared_ptr<Node> myConv = Conv(3,4,{1,1}, "myconv", {3,3}, {2,2}); + auto op = std::static_pointer_cast<OperatorTensor>(myConv -> getOperator()); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,1,3,3,3> { + {{{{ 0.5328165889, -0.4486202002, -0.4963828325}, + { 0.2804954648, 0.4619753063, -0.9826803803}, + { 0.4261451066, 0.5110000372, 1.9890428782}}, + + {{ 1.1952217817, -1.5133171082, 0.1646732241}, + {-0.7254997492, -0.1677423269, -1.0745935440}, + { 0.2478682548, -2.3416306973, 0.5321671963}}, + + {{ 0.7880555391, 2.6202778816, -1.0282965899}, + {-0.5344914198, -0.5824835896, 0.4730331898}, + {-1.3073508739, 0.2466813326, -0.3054754436}}}} + }); + std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<float,4> {{0.3835324347, 0.5210654140, 0.0670478195, 0.2723239064}}); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,3,1,1> { //NCHW + {{{{ 0.3910248578}}, + {{ 0.4545786083}}, + {{-0.1685599536}}}, + + {{{-0.2521645725}}, + {{ 0.1072919592}}, + {{ 0.1053370386}}}, + + {{{ 0.4627405703}}, + {{-0.4509093165}}, + {{ 0.3544224203}}}, + + {{{ 0.4944877923}}, + {{ 0.1518175900}}, + {{ 0.1704865843}}}} + }); + std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<float,1,4,1,1> { + {{{{1.0023646355}}, + {{0.5979570746}}, + {{0.0539716333}}, + {{0.8516038060}}}} + }); + op->associateInput(0,myInput); + op->associateInput(1,myWeights); + op->associateInput(2,myBias); + op->setDataType(DataType::Float32); + op->setBackend("cpu"); + op->forwardDims(); + myConv->forward(); + REQUIRE(approxEq<float>(*(op->getOutput(0)),*myOutput, 1e-5f, 1e-8f)); } } - SECTION("Strided and dilated Conv") { - std::shared_ptr<Node> myConv = Conv(3,4,{3,3}, "myconv", {3,3},{2,2}); - auto op = std::static_pointer_cast<OperatorTensor>(myConv -> getOperator()); - std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,3,8,8> { - {{{ - {0.0107F, 0.5076F, 0.2293F, 0.0486F, 0.7375F, 0.2637F, 0.9615F, 0.9138F}, - {0.0678F, 0.5604F, 0.1940F, 0.0287F, 0.1029F, 0.2059F, 0.5058F, 0.9885F}, - {0.9904F, 0.2890F, 0.4606F, 0.1055F, 0.9028F, 0.1654F, 0.6499F, 0.4775F}, - {0.9499F, 0.4695F, 0.1713F, 0.0731F, 0.4913F, 0.8921F, 0.1782F, 0.1111F}, - {0.2479F, 0.4669F, 0.1078F, 0.6153F, 0.0299F, 0.6484F, 0.2397F, 0.1814F}, - {0.3779F, 0.9032F, 0.5651F, 0.3896F, 0.8439F, 0.6404F, 0.3813F, 0.0841F}, - {0.5566F, 0.8950F, 0.1226F, 0.8881F, 0.9870F, 0.6256F, 0.6387F, 0.0628F}, - {0.2857F, 0.0579F, 0.6247F, 0.1286F, 0.0951F, 0.1268F, 0.9510F, 0.3789F}}, - - {{0.7648F, 0.5340F, 0.1024F, 0.4098F, 0.9958F, 0.7941F, 0.1190F, 0.7328F}, - {0.4532F, 0.6598F, 0.9146F, 0.1690F, 0.6041F, 0.7230F, 0.5719F, 0.9282F}, - {0.2862F, 0.2329F, 0.7302F, 0.6717F, 0.1983F, 0.1876F, 0.4561F, 0.2126F}, - {0.7849F, 0.0239F, 0.7977F, 0.5935F, 0.9958F, 0.4703F, 0.4612F, 0.1627F}, - {0.6393F, 0.3544F, 0.8643F, 0.5039F, 0.8087F, 0.6521F, 0.5086F, 0.9331F}, - {0.7749F, 0.9798F, 0.6820F, 0.7869F, 0.5144F, 0.2941F, 0.8137F, 0.4561F}, - {0.6505F, 0.3974F, 0.6909F, 0.7019F, 0.2729F, 0.4240F, 0.0162F, 0.1536F}, - {0.3529F, 0.8821F, 0.1812F, 0.3426F, 0.3472F, 0.0300F, 0.8841F, 0.8088F}}, - - {{0.5099F, 0.3323F, 0.1488F, 0.3424F, 0.1494F, 0.6225F, 0.8103F, 0.5995F}, - {0.9198F, 0.5635F, 0.8908F, 0.9378F, 0.6689F, 0.3176F, 0.3755F, 0.3883F}, - {0.0626F, 0.5309F, 0.0307F, 0.3955F, 0.2794F, 0.1420F, 0.4758F, 0.7558F}, - {0.6154F, 0.5280F, 0.2318F, 0.3832F, 0.4435F, 0.3490F, 0.4043F, 0.5872F}, - {0.3705F, 0.3848F, 0.2182F, 0.8332F, 0.4559F, 0.5310F, 0.4611F, 0.4236F}, - {0.6141F, 0.8103F, 0.2260F, 0.9907F, 0.5615F, 0.4520F, 0.6949F, 0.0175F}, - {0.3969F, 0.5021F, 0.0970F, 0.9937F, 0.9270F, 0.4302F, 0.2868F, 0.3891F}, - {0.8693F, 0.5170F, 0.5348F, 0.2676F, 0.9769F, 0.3356F, 0.9427F, 0.3908F}} - }, - { - {{0.4803F, 0.5223F, 0.6395F, 0.8402F, 0.4442F, 0.6377F, 0.7852F, 0.9063F}, - {0.0361F, 0.0470F, 0.3104F, 0.6921F, 0.0543F, 0.4490F, 0.9541F, 0.7395F}, - {0.3832F, 0.3828F, 0.2236F, 0.2068F, 0.4369F, 0.7443F, 0.6952F, 0.6394F}, - {0.5309F, 0.8483F, 0.1991F, 0.9756F, 0.8969F, 0.7284F, 0.4657F, 0.5486F}, - {0.8839F, 0.3260F, 0.6892F, 0.4074F, 0.9473F, 0.5526F, 0.4147F, 0.4786F}, - {0.9674F, 0.0952F, 0.8379F, 0.2163F, 0.9420F, 0.4046F, 0.1339F, 0.5234F}, - {0.4213F, 0.8392F, 0.3184F, 0.4576F, 0.9349F, 0.8267F, 0.0931F, 0.8009F}, - {0.5570F, 0.5871F, 0.4175F, 0.5465F, 0.6679F, 0.9224F, 0.0049F, 0.9421F}}, - - {{0.3739F, 0.6230F, 0.7613F, 0.1337F, 0.8527F, 0.0557F, 0.6424F, 0.8463F}, - {0.7179F, 0.5638F, 0.2457F, 0.4579F, 0.0487F, 0.8693F, 0.8216F, 0.0415F}, - {0.1724F, 0.5108F, 0.9103F, 0.0850F, 0.0080F, 0.8927F, 0.7706F, 0.3600F}, - {0.7751F, 0.8828F, 0.7872F, 0.4541F, 0.3181F, 0.1855F, 0.2486F, 0.0033F}, - {0.5558F, 0.3500F, 0.6034F, 0.1763F, 0.7418F, 0.5190F, 0.5147F, 0.4090F}, - {0.4476F, 0.1249F, 0.8116F, 0.9091F, 0.1738F, 0.6150F, 0.3285F, 0.3133F}, - {0.5657F, 0.4447F, 0.5049F, 0.3425F, 0.7443F, 0.2718F, 0.2466F, 0.5586F}, - {0.3684F, 0.7616F, 0.5165F, 0.9621F, 0.2864F, 0.7747F, 0.8110F, 0.7045F}}, - - {{0.4570F, 0.4577F, 0.0373F, 0.6084F, 0.4632F, 0.3472F, 0.9917F, 0.2011F}, - {0.7921F, 0.2202F, 0.9525F, 0.7274F, 0.3357F, 0.0076F, 0.5786F, 0.3034F}, - {0.6510F, 0.0798F, 0.2757F, 0.1738F, 0.3046F, 0.2197F, 0.3872F, 0.5650F}, - {0.1532F, 0.3204F, 0.6094F, 0.3287F, 0.8903F, 0.9773F, 0.7950F, 0.2845F}, - {0.2482F, 0.3395F, 0.8795F, 0.4325F, 0.1395F, 0.2457F, 0.2968F, 0.5424F}, - {0.8636F, 0.7426F, 0.2151F, 0.6900F, 0.3938F, 0.0062F, 0.4980F, 0.4098F}, - {0.8026F, 0.0464F, 0.2662F, 0.7835F, 0.8444F, 0.0688F, 0.8796F, 0.7625F}, - {0.2764F, 0.5341F, 0.1773F, 0.6671F, 0.7555F, 0.5235F, 0.7142F, 0.9423F}}}} - }); - std::shared_ptr<Tensor> myBias = std::make_shared<Tensor>(Array1D<float,4> {{ 0.1902F, -0.1789F, -0.0314F, -0.0589F}}); - std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,3,3,3> { //NCHW - { - { - {{ 0.0039F, 0.1098F, -0.0834F}, - {-0.0890F, 0.0725F, -0.1178F}, - { 0.1056F, -0.0924F, -0.0574F}}, - {{ 0.0070F, -0.0730F, -0.0674F}, - {-0.0380F, -0.1025F, -0.0085F}, - {-0.1451F, -0.0656F, 0.1137F}}, - {{ 0.1020F, 0.1025F, -0.0678F}, - { 0.0028F, 0.1512F, -0.0871F}, - { 0.1563F, -0.1446F, -0.1636F}} - }, - { - {{ 0.1472F, 0.0025F, -0.0281F}, - { 0.0350F, 0.0296F, -0.1711F}, - {-0.1197F, -0.1198F, -0.1130F}}, - {{-0.1492F, 0.1554F, -0.1044F}, - { 0.1203F, -0.1596F, 0.0589F}, - {-0.0436F, -0.1876F, -0.0816F}}, - {{ 0.1572F, -0.0982F, 0.1293F}, - { 0.1358F, 0.1559F, 0.1322F}, - { 0.0296F, -0.0354F, -0.0632F}} - }, - { - {{-0.0941F, -0.0479F, 0.0908F}, - {-0.1319F, -0.1333F, 0.1223F}, - {-0.1098F, 0.1924F, 0.1075F}}, - {{ 0.1796F, 0.0213F, 0.0626F}, - { 0.0275F, 0.1883F, -0.0818F}, - { 0.0363F, 0.0684F, 0.1094F}}, - {{ 0.1131F, 0.1258F, -0.0558F}, - { 0.1498F, 0.0322F, -0.0186F}, - {-0.1801F, -0.0358F, 0.1727F}} - }, - { - {{-0.1500F, -0.0554F, -0.0994F}, - {-0.0818F, -0.1223F, 0.1365F}, - { 0.1281F, 0.1507F, -0.0890F}}, - {{-0.0444F, -0.1071F, -0.1632F}, - { 0.0757F, -0.1235F, 0.0408F}, - { 0.0401F, -0.1914F, 0.1772F}}, - {{-0.0714F, 0.1582F, -0.0065F}, - {-0.0119F, 0.1375F, -0.0727F}, - {-0.1532F, -0.1826F, -0.0417F}} - } - } - }); - std::shared_ptr<Tensor> myOutput = std::make_shared<Tensor>(Array4D<float,2,4,2,2> { - { - { - {{-0.2174F, -0.0778F}, - {-0.2584F, 0.2303F}}, - {{-0.7686F, -0.3879F}, - {-0.1775F, 0.0119F}}, - {{ 0.5180F, 0.5087F}, - { 0.5398F, 0.3476F}}, - {{-0.5258F, -0.3128F}, - {-0.6673F, -0.1827F}} - }, - { - {{-0.1902F, -0.0467F}, - {-0.3327F, -0.1701F}}, - {{-0.5505F, -0.4875F}, - {-0.4119F, -0.5726F}}, - {{ 0.5777F, 0.4428F}, - { 0.6121F, 0.7221F}}, - {{-0.6009F, -0.6335F}, - {-0.5159F, -0.3353F}} - } - } - }); - op->associateInput(0,myInput); - op->associateInput(1,myWeights); - op->associateInput(2,myBias); - op->setDataType(DataType::Float32); - op->setBackend("cpu"); - op->forwardDims(); - myConv->forward(); - op->getOutput(0)->print(); - REQUIRE(approxEq<float>(*(op->getOutput(0)),*myOutput, 1e-3f, 1e-4f)); + SECTION("kernel size [5,5]") { + SECTION("no stride, no dilation") { + Conv_Op<2> conv_op = Conv_Op<2>({5,5}); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,3,7,7> { + {{{{-1.5819821358e+00, -1.7854875326e+00, 6.9175791740e-01, + 2.0996522903e+00, 9.6078282595e-01, -8.2813060284e-01, + 1.6968919337e-01}, + { 7.2002971172e-01, -1.8102344871e-01, -9.1400068998e-01, + 1.2503153086e+00, -7.6566308737e-01, 3.8362321258e-01, + -1.5427072346e-01}, + {-1.2361711264e+00, 7.7305638790e-01, 5.6913595647e-02, + -2.9433086514e-01, -2.4903245270e-01, 3.2472074032e-01, + 1.4886678457e+00}, + {-5.9757936001e-01, 8.0331146717e-01, -3.5785683990e-01, + 2.5472685695e-01, 5.5055916309e-01, 1.0456261635e+00, + 7.3776167631e-01}, + { 4.7463580966e-01, 4.0943801403e-01, 2.1823734045e-01, + 1.1387450248e-01, -1.2749063969e+00, -4.4165733457e-01, + -4.0056625009e-01}, + { 1.5264246464e+00, -6.9763243198e-01, 2.9963770509e-01, + -5.3752285242e-01, -1.3321673870e+00, -3.8046565652e-01, + 1.2825177908e+00}, + {-1.0958684683e+00, 8.3749586344e-01, -2.0998646319e-01, + 1.4721590281e+00, -5.6766863912e-02, -3.6367511749e-01, + -4.7073301673e-01}}, + + {{-5.6565409899e-01, 1.2466928959e+00, 1.7969487607e-01, + 7.1508967876e-01, 8.4663391113e-01, 4.8241707683e-01, + -1.1983201504e+00}, + { 8.1637918949e-02, -1.1711903811e+00, 1.1311115026e+00, + 7.1283179522e-01, 9.2079514265e-01, 2.7613803744e-01, + -1.1035090685e+00}, + {-2.1333950758e-01, -2.3015061021e-01, -1.9564050436e-01, + -5.0889742374e-01, -7.8319394588e-01, 1.0231192112e+00, + -9.3693393469e-01}, + { 1.8086779118e+00, -1.1668262482e+00, -1.6222233772e+00, + 1.9393971562e-01, 1.0758545399e+00, 7.6798349619e-01, + 9.5232710242e-02}, + { 4.4345301390e-01, 6.6509228945e-01, 3.1033048034e-01, + 9.1944271326e-01, -2.6410639286e-01, -5.5876952410e-01, + 4.4179755449e-01}, + {-5.4229873419e-01, -2.9905325174e-01, 1.0674524307e+00, + -5.4987430573e-01, 1.2865034342e+00, -2.9273465276e-01, + -1.2198672295e+00}, + { 5.6028085947e-01, 3.1508100033e-01, 1.1667277813e+00, + -1.3935315609e+00, 1.0823357105e+00, -5.4969668388e-01, + -1.0486271381e+00}}, + + {{ 4.7686126828e-01, -5.9134221077e-01, -1.9606289864e+00, + -1.6939393282e+00, -1.1419154406e+00, -1.9365699291e+00, + 5.9356534481e-01}, + { 7.7772504091e-01, -1.5844665766e+00, 4.1060101241e-02, + -1.2316554785e+00, 5.8156740665e-01, -4.3886345625e-01, + -1.6858860254e+00}, + {-2.2140502930e-01, 8.1183856726e-01, 8.9134818316e-01, + 2.3568744659e+00, -1.0616497993e+00, 1.6614040732e-01, + 2.5361647829e-02}, + { 2.4231983721e-01, -6.6797232628e-01, 6.0891377926e-01, + -1.1280845851e-01, 1.3479894400e+00, -1.0160627365e+00, + -3.1460383534e-01}, + { 4.8870971799e-01, 6.7049396038e-01, 9.0237140656e-01, + -7.7934461832e-01, -4.3115192652e-01, -4.7609877586e-01, + -1.3054919243e+00}, + {-3.8021788001e-01, -1.2455014884e-01, 4.5932388306e-01, + -8.7587934732e-01, 8.4449040890e-01, -3.2640647888e-01, + 1.1044296026e+00}, + { 5.3791737556e-01, -5.3963464499e-01, 5.8685314655e-01, + 1.0961996317e+00, 9.6712523699e-01, -1.5506522655e+00, + 5.9469139576e-01}}}, + + + {{{ 1.4693295956e+00, -7.3901087046e-01, 3.5381242633e-01, + 2.9703059793e-01, -2.4348771572e+00, 1.8592662811e+00, + -4.0466204286e-01}, + {-9.3013238907e-01, 2.0250107627e-03, 1.0819822550e+00, + 3.6489057541e-01, -2.3111122847e-01, -5.2143561840e-01, + 5.6165838242e-01}, + { 1.3814152777e-01, 2.3664423823e-01, 4.7802433372e-01, + -3.9427459240e-02, -2.2023189068e+00, -9.7024470568e-02, + 9.7239714861e-01}, + {-3.7371930480e-01, -9.4720816612e-01, -1.1839007139e+00, + -6.4284646511e-01, -1.9176955223e+00, -3.9734467864e-01, + -4.3485221267e-01}, + {-4.3412786722e-01, 5.6266564131e-01, -5.7092076540e-01, + 1.9345518351e+00, 1.1893578768e+00, 7.3246234655e-01, + -2.1017053127e+00}, + { 5.0433129072e-01, 1.0112253428e+00, 9.6384412050e-01, + 6.1708116531e-01, 2.0769470930e-01, 1.6922599077e-01, + -1.7680550814e+00}, + { 9.6860373020e-01, 6.1789232492e-01, -1.2975339890e+00, + 1.8161086738e-01, -6.1131346226e-01, 7.5746607780e-01, + 8.1318039447e-03}}, + + {{-2.9031080008e-01, -7.4931091070e-01, -8.4527724981e-01, + 1.1330122948e+00, -1.3819234371e+00, -4.6072882414e-01, + -9.6089905500e-01}, + {-6.5712004900e-02, 8.1384368241e-02, 2.2198197842e+00, + -2.9237899184e-01, -7.6382297277e-01, -2.2724790573e+00, + -2.5980213284e-01}, + { 5.4156178236e-01, 1.0466700792e+00, 1.5188544989e+00, + -1.7176572978e-01, 1.2582596540e+00, 1.9942443073e-01, + 2.8897064924e-01}, + {-9.4713824987e-01, -1.4355570078e-01, 1.4625415206e-01, + -6.9037866592e-01, -8.4468722343e-01, -6.6903305054e-01, + -4.6764537692e-01}, + {-7.0394933224e-01, 1.4981827736e+00, 1.7820091546e-01, + -6.0410326719e-01, 3.1616929173e-01, 3.3617347479e-01, + 9.6749967337e-01}, + {-1.3977780342e+00, -1.0755797625e+00, -1.2804145813e+00, + 3.8759008050e-01, -1.5185016394e+00, -1.7654757202e-01, + -6.6924899817e-01}, + { 9.1542047262e-01, -1.2073645592e+00, -1.0723612309e+00, + 5.2093189955e-01, 7.4030023813e-01, -4.8004593700e-03, + -3.8286569715e-01}}, + + {{-5.6654226035e-02, 5.2225315571e-01, 6.5561562777e-01, + 3.0832710862e-01, -1.0121858120e+00, -6.5822112560e-01, + 1.1624189615e+00}, + {-4.6575185657e-01, -5.9653341770e-02, 1.0733175278e+00, + -9.0637534857e-01, 4.2416542768e-01, 2.2279551029e+00, + 8.1080448627e-01}, + { 1.3819074631e+00, -4.1368013620e-01, -3.1706240773e-01, + -1.2126101255e+00, 3.6613631248e-01, 6.6122449934e-02, + 7.8346210718e-01}, + { 4.8448505998e-01, -1.4276403189e-01, -6.0243755579e-01, + 1.3074930906e+00, 1.4549188614e+00, 2.0044024289e-01, + 3.3380195498e-01}, + {-3.6014577746e-01, -1.3747563362e+00, -2.4885981083e+00, + 7.2047698498e-01, 5.3208362311e-02, 2.1174606681e-01, + -1.9557200372e-01}, + {-4.1237883270e-02, 8.6860567331e-01, 2.3110714555e-01, + 6.6041219234e-01, 2.2416541576e+00, -7.1505290270e-01, + 4.2335259914e-01}, + { 2.0214543343e+00, -5.1382958889e-01, 1.2030944824e+00, + 1.2382258177e+00, -5.3932261467e-01, 1.5783529282e+00, + 2.3444575071e-01}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,4> { + {-0.0668615103, 0.0096536176, 0.1032274216, 0.0085389474} + }); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,3,5,5> { + {{{{-0.0256651584, -0.0010102347, 0.0679412410, -0.1094588861, + -0.0231344569}, + { 0.0001781617, -0.0604480579, 0.0879044831, 0.0947295949, + 0.0732583329}, + { 0.0955041796, -0.0005995115, 0.0651906058, 0.0311463121, + -0.0287336204}, + { 0.0884799510, -0.0909402817, 0.0743730441, 0.0313041285, + -0.1135403663}, + {-0.0055052429, -0.0088848919, -0.0945433900, -0.0986202136, + 0.0405280553}}, + + {{-0.0124288145, 0.0954111964, 0.0848305821, -0.0119373864, + -0.0877429247}, + {-0.0569152571, -0.0907590389, 0.0886867270, -0.0892160609, + -0.0698366836}, + {-0.0321554989, -0.0878270566, 0.0040294859, 0.0923006833, + -0.0500101894}, + {-0.0134353461, 0.0157409590, 0.0994384289, -0.1078630984, + 0.1124769971}, + {-0.0414051823, -0.0653665662, 0.0932169184, -0.0028723227, + 0.1107295603}}, + + {{-0.0775901973, -0.0497837812, 0.0233222693, 0.0232233945, + 0.0641542301}, + {-0.0499767400, 0.0732239708, 0.0240622535, 0.0533199385, + 0.0095617082}, + { 0.0203404855, 0.0888396129, 0.0836194754, 0.0820384473, + -0.0988118276}, + { 0.0671891198, -0.0690774396, 0.0534351096, 0.1113787442, + 0.0424823835}, + { 0.0898227617, 0.0541929603, -0.0550950989, -0.0454236157, + -0.0338280164}}}, + + + {{{-0.0010399260, -0.0008750338, 0.0267464351, -0.0620367154, + -0.0695600584}, + { 0.0839569122, -0.0903549194, -0.0383492336, -0.1136710495, + -0.0013333842}, + { 0.0458775200, 0.0615439937, 0.0627180859, 0.0060472852, + 0.0809845850}, + { 0.0803537294, 0.0442582481, -0.0573564842, -0.0694455057, + 0.0130708050}, + { 0.0842651576, 0.0554903932, 0.0676512942, -0.0407176018, + -0.0405663215}}, + + {{-0.0050318469, -0.0642497912, 0.0112077948, -0.0386103839, + -0.0723103955}, + { 0.0766968951, 0.1154476032, 0.0771673843, -0.0140198674, + 0.0007884651}, + { 0.0696242377, -0.0228096284, -0.0823685005, 0.0049601034, + -0.0549438223}, + {-0.0744984299, -0.0340416208, -0.0165108684, 0.0726555437, + 0.0606348105}, + { 0.0737286732, 0.0553442463, -0.0405892283, 0.0447632186, + 0.0778847709}}, + + {{ 0.0277899262, -0.0676826686, 0.0787717402, -0.0618756786, + -0.0597167611}, + {-0.0498202555, 0.0110518495, 0.0751497000, -0.1083162725, + 0.0197963919}, + { 0.0177669358, 0.0064084125, -0.0998755395, 0.1101382077, + -0.0407040417}, + { 0.0109004201, -0.0155633893, 0.0020808843, 0.1043395475, + 0.0241470188}, + {-0.0832296610, 0.0045275348, -0.0467849411, 0.0846917927, + -0.0272654612}}}, + + + {{{-0.0315979160, -0.1152938455, -0.0286736861, 0.0225985274, + 0.0697309002}, + {-0.0238809530, -0.0416267440, 0.1103974208, 0.0990703627, + 0.0957340077}, + {-0.0198308472, 0.0932704657, -0.0856214613, -0.0321927778, + 0.0995117798}, + { 0.0178589821, -0.0603880435, 0.0693830997, 0.0229547676, + -0.0488339476}, + { 0.0703017265, -0.0215192195, -0.0029850313, 0.0313456170, + 0.0309588723}}, + + {{-0.0417468436, 0.0079163658, -0.0190789141, -0.0332297161, + -0.0559522100}, + { 0.0237717275, -0.0125475517, -0.0180152841, -0.0576930158, + 0.0851965323}, + { 0.0890824571, -0.0164867807, 0.0140483472, 0.0352087654, + 0.0737347007}, + {-0.0402757414, 0.0018327822, 0.0074899504, -0.0751349106, + -0.0055309837}, + {-0.0124193439, -0.0278801825, 0.1109836772, 0.0963729918, + -0.0748507902}}, + + {{ 0.0692440420, 0.0490804240, 0.0596088283, 0.0401321165, + -0.0519519635}, + { 0.0487414300, 0.0161796808, 0.1103035659, -0.0254597142, + -0.0346399099}, + { 0.0937067941, -0.0507467575, 0.0202949513, -0.1111455485, + -0.0643658787}, + {-0.0158051737, 0.0514767207, -0.1004245058, 0.0356756486, + 0.1005240306}, + {-0.0250654537, 0.0197961032, -0.0232695211, -0.0349051207, + 0.0455882438}}}, + + + {{{-0.0370168537, 0.0584056191, -0.0164116360, -0.0905687362, + -0.0273107067}, + { 0.1133346409, 0.0383886844, -0.1096662879, 0.0884051472, + -0.0351093002}, + {-0.0329314657, 0.0843854621, -0.0302542653, -0.0380227529, + -0.0992667377}, + {-0.0894243866, 0.0335155874, 0.1042508706, -0.0078598047, + -0.0247038864}, + {-0.0504006073, -0.1062115207, 0.0748218298, -0.0250077099, + 0.0837111399}}, + + {{ 0.0187441055, -0.0280515030, 0.0867191330, 0.0816542357, + 0.0727743357}, + { 0.0749907270, 0.0491249822, -0.0638230145, -0.0226818472, + -0.0130754299}, + { 0.1146481931, 0.0110172033, -0.0704981834, 0.0573508702, + 0.0468072817}, + {-0.0787231997, 0.0649248436, 0.0268039443, -0.0404104590, + -0.0502680764}, + { 0.1082614362, 0.0888254121, 0.0610329509, -0.1026768908, + 0.0593265593}}, + + {{-0.0167568102, -0.0386474803, 0.0916503966, -0.0223906469, + 0.0836452991}, + {-0.0814188346, 0.1102456823, -0.0443761721, 0.0915107206, + 0.0361284241}, + { 0.0015683545, 0.0883040577, -0.1119926274, 0.0363290384, + -0.0069464077}, + {-0.0272085574, 0.1144676507, -0.0756823123, -0.0254881941, + -0.0074769561}, + {-0.0568101220, -0.0335657224, 0.1033082232, 0.0875433162, + -0.0518697165}}}} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,2,4,3,3> { + {{{{-0.3871777952, 0.5219718218, 0.7862872481}, + {-0.1375074536, 0.8885341883, 0.5779658556}, + { 0.5681430101, -0.7254346609, -0.2990708053}}, + + {{ 0.0998495445, 0.3688535392, 0.5431413054}, + {-0.1215116680, 0.4525770843, -0.7176234722}, + {-0.3100067377, 0.4953409731, -0.7239166498}}, + + {{-0.0482354760, -0.1504757106, -1.0672740936}, + { 0.2462310642, 0.3214653134, 0.2619933784}, + { 0.6370882392, 0.3474441171, 1.1181805134}}, + + {{-1.0228110552, -0.5645121336, 0.4233708680}, + { 0.6847354174, -0.2050004154, -0.1628678143}, + { 0.4284777045, 0.8981749415, -1.3061262369}}}, + + + {{{ 0.2609502375, -0.5564229488, -1.2170201540}, + {-0.6059533358, 0.0885168016, -0.5935401320}, + {-0.5572131276, 0.8337767124, 0.4323115349}}, + + {{ 0.3608099222, 1.0569038391, 0.0491272137}, + { 0.2181473076, -0.1242226511, 0.0207630899}, + { 0.5975542665, 0.2953254580, -0.2222344875}}, + + {{ 0.8388710022, -0.0150405290, 0.5986887813}, + {-0.0973970443, -0.9291419387, -0.1056552008}, + {-0.6437571645, 0.0552899837, -0.4056833386}}, + + {{-0.3430148065, 1.5227048397, -0.2973098159}, + {-0.2977932394, -0.4785656929, 0.0867543519}, + { 0.4107315838, -0.3220899701, 0.9598766565}}}} + }); + conv_op.associateInput(0, myInput); + conv_op.associateInput(1, myWeights); + conv_op.associateInput(2, myBiases); + conv_op.setBackend("cpu"); + conv_op.setDataType(DataType::Float32); + conv_op.forwardDims(); + conv_op.forward(); + REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-8f)); + } + SECTION("stride [2,2], no dilation") { + Conv_Op<2> conv_op = Conv_Op<2>({5,5}, {2,2}); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,3,9,9> { + {{{{-1.0262345076e+00, 4.3911907077e-01, 1.8221689463e+00, + -9.0496063232e-02, 2.3124601841e+00, 1.2676076889e+00, + 2.9154178500e-01, -1.0054944754e+00, 1.4668885469e+00}, + {-2.6926884055e-01, 2.7054285631e-02, -6.5346807241e-02, + 8.3025145531e-01, -9.5455622673e-01, 4.1868144274e-01, + 1.3323990107e+00, 7.5650310516e-01, 8.5871744156e-01}, + {-6.7330420017e-01, -3.4846577048e-01, 1.4346588850e+00, + 2.7194222808e-01, 2.1521434784e+00, 1.6696736217e-01, + 1.8376970291e-01, 7.1468442678e-01, -1.2447526306e-01}, + { 8.0760914087e-01, -1.1017438173e+00, 5.1274144650e-01, + 2.2169539332e-01, -1.0916151851e-01, 1.3694372177e+00, + 1.1246484704e-02, -6.0367786884e-01, 1.1885926723e+00}, + { 1.6050251722e+00, 1.5227701664e+00, 2.2580787539e-01, + 3.3003208041e-01, -1.0569810867e+00, -6.6531229019e-01, + -1.5870132446e+00, -5.5765572935e-02, 1.2733308077e+00}, + {-2.3591136932e+00, -1.6602599621e+00, 1.3936065137e-02, + -8.6029511690e-01, 4.3284755945e-01, 9.5207196474e-01, + -5.1348048449e-01, -8.7345111370e-01, 6.0008174181e-01}, + { 1.5557394028e+00, -1.3211530447e+00, 1.1696275473e+00, + 6.4125038683e-02, -1.2432120740e-01, -5.2999585867e-01, + 6.8962436914e-01, 1.0586444139e+00, -1.5393349528e-01}, + {-4.3554434180e-01, -1.5950162709e-01, 1.0490014553e+00, + -5.0765627623e-01, -3.0467399955e-01, -5.3264838457e-01, + 9.8690450191e-01, -1.5088070631e+00, 8.9175784588e-01}, + { 1.4247367382e+00, -5.8848276734e-02, -1.0292690992e+00, + 1.0037636757e+00, -5.3313058615e-01, -8.4637176991e-01, + -2.3506495357e-01, 2.5999876857e-01, 1.0645623207e+00}}, + + {{ 1.1594262123e+00, 1.4667278528e-01, -1.3725358248e+00, + -6.8673837185e-01, -2.6230901480e-01, -1.0233192444e+00, + 1.0699504614e+00, -1.4223144948e-01, 8.3192962408e-01}, + { 1.0303292274e+00, -7.8168278933e-01, -8.6705344915e-01, + -1.4635435343e+00, 1.1048543453e+00, -5.7091850042e-01, + 1.5543711185e-01, -5.3300577402e-01, 7.5885367393e-01}, + { 3.8549888134e-01, 2.1701371670e+00, -1.8536053896e+00, + 5.6785058975e-01, -1.8138351440e+00, -1.4978441596e-01, + 6.9255560637e-01, -1.9088203087e-02, -9.3620312214e-01}, + { 7.6818950474e-02, -7.7815485001e-01, -3.6633202434e-01, + 3.7396013737e-01, 3.9735972881e-02, 5.3178119659e-01, + -6.2058432959e-03, 1.1743594408e+00, -1.3553234339e+00}, + {-1.2857575417e+00, 1.1403552294e+00, 7.6715612411e-01, + -3.8674977422e-01, 3.0915462971e-01, 1.6230952740e+00, + -1.0317113400e+00, -3.4624069929e-01, 2.3419447243e-01}, + { 2.0289704800e+00, 9.5344042778e-01, 1.3398091495e-01, + 2.3676842451e-01, -9.1961878538e-01, -7.9850316048e-02, + 1.7670296133e-01, 9.2019730806e-01, 3.2906150818e-01}, + { 1.5497472286e+00, -7.5789727271e-02, -1.5621168613e+00, + -1.3265879452e-01, 2.3210446835e+00, 6.2546260655e-02, + 1.3731292486e+00, -1.0590184927e+00, -1.1315354109e+00}, + { 9.0949285030e-01, 1.5912884474e+00, 9.7548440099e-02, + -9.5485979319e-01, 8.1781722605e-02, -8.9541131258e-01, + -9.7539001703e-01, 1.4973148108e+00, -7.3013925552e-01}, + {-2.6063349247e+00, -9.8948460817e-01, -4.6334408224e-02, + -5.6305401027e-02, 1.2275942564e+00, -1.0767419338e+00, + 3.5429549217e-01, -8.4239321947e-01, -2.3129728436e-01}}, + + {{-3.6824011803e-01, -1.3600775003e+00, -5.3082842380e-02, + 2.1362492442e-01, -3.3504801989e-01, -1.2447865009e+00, + 1.2259862423e+00, 1.2943927050e+00, 4.1017323732e-01}, + {-7.7389232814e-02, 5.6407207251e-01, -1.5415928364e+00, + 6.2723839283e-01, 6.6969829798e-01, 1.7052684724e-01, + 6.7901211977e-01, -9.1711390018e-01, 4.1649293900e-01}, + {-1.9805129766e+00, 2.6968449354e-02, 2.7286293507e+00, + 4.8898363113e-01, 1.0338652134e+00, -3.4376963973e-01, + 1.5369942784e-01, 2.1052715778e+00, 9.6360033751e-01}, + {-7.8345723450e-02, 1.7320346832e+00, 5.1241457462e-01, + -1.6069989204e+00, -8.2155573368e-01, 1.5159207582e+00, + -1.7706178427e+00, -3.5353070498e-01, -1.2306252122e-01}, + { 9.9549388885e-01, 6.6899424791e-01, 1.8666473031e-01, + 4.1127932072e-01, 1.6854909658e+00, 1.2500119209e+00, + 6.2952446938e-01, -7.8491973877e-01, -1.7457501590e-01}, + { 1.9429718256e+00, -1.9178773165e+00, -2.1454337239e-01, + 2.3576610088e+00, 1.7864210904e-01, -1.3109503984e+00, + 2.3597766459e-01, 2.8684207797e-01, 2.1074929833e-01}, + { 1.2090591192e+00, -2.1073739976e-02, -1.5082824230e+00, + -8.3251363039e-01, 1.0880084038e+00, -5.8158898354e-01, + -3.0504870415e-01, -2.9301109910e-01, 8.9690053463e-01}, + { 8.7137883902e-01, 4.2053112388e-01, -5.0221372396e-02, + -2.8163683414e-01, -1.3151681423e+00, -6.8825948238e-01, + 3.9207798243e-01, -1.0277284384e+00, 5.0744730234e-01}, + {-6.9271439314e-01, 5.3248941898e-01, 1.5569035895e-02, + 3.7575492263e-01, -8.3689816296e-02, 1.1159549952e+00, + -1.4623420238e+00, -4.3859991431e-01, 1.1101962328e+00}}}, + + + {{{-2.0423326641e-02, -4.2261308432e-01, -1.2248262167e+00, + -7.6990664005e-01, 6.8539634347e-02, -5.8175742626e-01, + 1.3995911926e-02, -2.4920943379e-01, 1.6195765138e-01}, + { 3.1753441691e-01, -4.5215657353e-01, 3.4099850059e-01, + 8.2994532585e-01, -1.4502160251e-01, 4.6974977851e-01, + -1.0577541590e+00, -3.8428103924e-01, -1.5933537483e-01}, + { 8.4931612015e-01, -1.4407234192e+00, -8.8770568371e-01, + 5.6812566519e-01, -5.3451889753e-01, -7.9881912470e-01, + 3.5436341166e-01, 6.9050423801e-02, 1.0797642469e+00}, + { 1.5346392393e+00, -8.0458652973e-01, 1.1945800781e+00, + 2.8993117809e-01, -3.4709338099e-02, -1.9005538225e+00, + 5.7719033957e-01, 9.8633068800e-01, -8.1702458858e-01}, + { 2.1732668877e+00, 9.3365108967e-01, -1.4125390053e+00, + -2.6723548770e-01, 7.6397609711e-01, -2.5253626704e-01, + 5.1223450899e-01, -1.3197456598e+00, 3.5206422210e-01}, + { 4.8497560620e-01, -5.9305703640e-01, -1.3207372427e+00, + -1.2734633684e+00, -1.8892226219e+00, -1.2254822254e+00, + -1.0012117624e+00, 4.4947278500e-01, 1.3996914029e-01}, + {-7.1806615591e-01, -2.1445353031e+00, 8.4149742126e-01, + -1.2808227539e+00, -1.4514193535e+00, 8.5352408886e-01, + 5.3722190857e-01, 1.0689587593e+00, 4.6941962838e-01}, + {-2.9689872265e-01, 5.7039666176e-01, -9.7570866346e-01, + -7.5906850398e-02, 1.9404630363e-01, -1.2686843872e+00, + 7.7697885036e-01, -9.5903050900e-01, -1.1655918360e+00}, + { 6.3129204512e-01, -3.5601881146e-01, 6.5524661541e-01, + -5.0732725859e-01, 1.3058322482e-02, -5.4524648190e-01, + 4.0899762511e-01, -9.8380684853e-01, 4.5132014155e-01}}, + + {{ 1.5725825727e-01, 2.1817979813e+00, -1.0229426622e+00, + 5.2571322769e-02, 1.7906796932e+00, 1.5359158516e+00, + -1.6435106993e+00, -3.8198566437e-01, -1.5808371305e+00}, + {-1.3824816942e+00, 7.8574791551e-02, -5.1695871353e-01, + -5.0357979536e-01, -1.1000699997e+00, 1.7837898433e-01, + 7.6670318842e-01, -1.2971758842e+00, 1.3056064844e+00}, + { 1.0061295033e+00, 2.9437178373e-01, -3.0505040288e-01, + 1.4037330151e+00, 1.5578675270e+00, -2.3277984560e-01, + 2.6896992326e-01, 2.0645604134e+00, -1.6063396931e+00}, + { 4.0633159876e-01, 4.3755510449e-01, -4.1917449236e-01, + -1.2625947595e-02, -1.3815000653e-02, 7.8905373812e-01, + 1.3740755618e-01, 1.4110846519e+00, -1.3393870592e+00}, + {-9.1028696299e-01, 9.2755252123e-01, -1.0052075386e+00, + -9.9492824078e-01, 6.7882398143e-03, 4.5161217451e-01, + 6.1087304354e-01, 1.1929332018e+00, -1.1769343615e+00}, + { 1.5572510660e-01, 4.4804865122e-01, 5.9232199192e-01, + 1.4278647900e+00, 3.1380197406e-01, -6.2812852859e-01, + -1.0075987577e+00, -5.8191227913e-01, 1.6295973063e+00}, + {-8.1633257866e-01, -3.1518262625e-01, 1.1550375223e+00, + -1.8897730112e+00, -4.2993515730e-01, -5.9119540453e-01, + 9.1979181767e-01, -1.7244141102e+00, 9.7749936581e-01}, + {-1.2826606035e+00, -8.3349563181e-02, -8.0689124763e-02, + -2.3780565262e+00, 6.4297276735e-01, -8.6600404978e-01, + -1.0059145689e+00, -4.3131682277e-01, 7.4153000116e-01}, + { 1.1657108068e+00, 1.3443695307e+00, -1.7663496733e-01, + 1.2084038258e+00, -1.0071879625e+00, -9.3671619892e-01, + 5.8391742408e-02, -9.3132650852e-01, 9.3861585855e-01}}, + + {{ 4.7928895219e-04, 1.5494421721e+00, 1.6211936474e+00, + 2.2041907310e+00, -3.2932338119e-01, -1.6941326857e+00, + 4.2259506881e-02, -3.0548694730e-01, -5.4214018583e-01}, + {-9.2042051256e-02, 3.2141461968e-01, 1.4343656301e+00, + 3.2310426235e-01, -4.1865095496e-01, 1.0167524815e+00, + 4.7122836113e-01, -1.1745406389e+00, 3.3083841205e-01}, + {-1.2731312513e+00, 1.6528505087e+00, -3.0167734623e-01, + 1.3895208836e+00, 1.4259627461e-01, 8.3795058727e-01, + -7.1655702591e-01, 1.0907325745e+00, -2.2553758621e+00}, + { 1.3037883043e+00, -4.7551321983e-01, -9.0221446753e-01, + -6.9939422607e-01, 8.2557731867e-01, 2.2087992728e-01, + -1.5934921503e+00, 3.7456196547e-01, 1.2232249975e-01}, + {-2.4604837894e+00, -7.6448154449e-01, 7.1209603548e-01, + 1.0618342161e+00, 3.1532841921e-01, -1.3785527945e+00, + -2.0559796691e-01, 5.9892934561e-01, -1.0338040590e+00}, + {-4.3527457118e-01, -2.2242832184e+00, 7.4645686150e-01, + -1.2474944592e+00, 9.7446382046e-01, 8.6570835114e-01, + -7.6936386526e-02, 1.1415704489e+00, 1.3671324253e+00}, + { 1.1514008045e-01, -4.2773807049e-01, -5.3324125707e-02, + -2.5783428550e-01, -1.0264251232e+00, -8.0453145504e-01, + -2.0028649271e-01, -8.4552615881e-03, -5.6716823578e-01}, + { 4.6886390448e-01, 5.2331888676e-01, 1.3241011649e-02, + 9.7727668285e-01, 3.6741080880e-01, 7.4033015966e-01, + -1.4176219702e-01, -1.7566013336e+00, 1.8248949945e-01}, + { 3.4004274011e-01, -3.9557147026e-01, -1.8804949522e+00, + 3.9474415779e-01, -7.9836857319e-01, -3.9796181023e-02, + -1.3347951174e+00, 1.0292435884e+00, 8.2486397028e-01}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,4> { + {-0.0256307721, -0.0979429483, -0.0857793465, 0.1046720892} + }); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,3,5,5> { + {{{{ 7.9161055386e-02, 1.0865640640e-01, -8.4078900516e-02, + -8.7769351900e-02, 5.5054081604e-03}, + {-5.2150800824e-02, -8.6442008615e-02, 6.6387809813e-02, + 7.4080288410e-02, -8.6096972227e-02}, + { 6.5568581223e-02, 1.9050860777e-02, -6.8354956806e-02, + -1.1355084926e-01, 1.0479265451e-01}, + { 8.2975491881e-02, 9.2087492347e-02, 2.0702988841e-03, + 1.1016045511e-01, -1.8409077078e-02}, + { 8.0884704366e-03, -6.1311736703e-02, -1.0344123840e-01, + -5.7280212641e-02, 9.0915616602e-03}}, + + {{-9.9937595427e-02, -7.4386410415e-02, -6.6606923938e-02, + -6.9912821054e-03, -5.1879696548e-02}, + { 3.3732347190e-02, 7.3814824224e-02, -6.5518431365e-02, + 1.1177737266e-01, 6.4293742180e-02}, + { 1.7644548789e-02, -8.2394741476e-02, -9.5198350027e-03, + -2.8538773768e-03, -5.5996231735e-02}, + {-9.7051627934e-02, 8.5470348597e-02, -5.2727516741e-02, + 5.3426343948e-02, -1.0500780493e-01}, + { 3.5254769027e-03, 2.5447849184e-02, 9.5501638949e-02, + 7.0650175214e-02, -1.1365779489e-01}}, + + {{ 2.2010399334e-05, 1.1450178921e-01, 2.8901636600e-02, + 5.7596616447e-02, 3.7809558213e-02}, + {-1.1890708469e-02, -1.1361743510e-01, -5.3352948278e-02, + -7.3011368513e-02, 9.1052189469e-02}, + { 1.1218705028e-01, 7.7470839024e-02, -3.3929269761e-02, + 3.4500412643e-02, 9.8039925098e-02}, + { 6.3609093428e-02, -2.8273850679e-02, -1.0159312189e-01, + -4.9110885710e-02, -8.1224292517e-02}, + {-1.6401037574e-02, 2.3994818330e-02, -2.5938203558e-02, + 7.3313489556e-03, 9.8718859255e-02}}}, + + + {{{ 3.4335671808e-04, 3.6441650242e-02, 3.0187530443e-02, + 6.0592081398e-02, -2.0527897403e-02}, + { 1.1346739531e-01, -3.2229032367e-02, -2.2989841178e-02, + 6.0040432960e-02, 4.8826828599e-02}, + {-6.0838893056e-02, -6.0655072331e-02, -8.7084382772e-02, + 1.0714148730e-01, 1.0812971741e-01}, + {-4.2028987082e-04, 8.5651963949e-02, 8.0970667303e-02, + -2.3986544460e-02, -1.8859704724e-03}, + {-6.6010192037e-02, 9.4307392836e-02, 1.1242634058e-01, + -8.5995316505e-02, -2.7027517557e-02}}, + + {{ 1.1397475004e-01, 6.3374929130e-02, -7.7009305358e-02, + -8.3151273429e-03, 8.2951277494e-02}, + { 9.8664939404e-02, -8.1143163145e-02, 3.9020900149e-03, + -1.4049283229e-02, -8.3722546697e-02}, + { 4.7419264913e-02, -1.6121247783e-02, -3.9537753910e-02, + 3.4721549600e-02, -6.9126158953e-02}, + {-2.2569455206e-02, 3.0018799007e-03, -1.1006606370e-01, + 7.5041048229e-02, 1.0107534379e-01}, + {-6.8557247519e-02, -3.2554015517e-02, -1.0497659445e-01, + 6.0661323369e-02, -4.7699000686e-02}}, + + {{ 3.5086035728e-02, -4.3413480744e-03, 2.1944919601e-02, + -3.3144496381e-02, -8.7255813181e-02}, + {-1.7425794154e-02, 6.0090426356e-02, 5.0702422857e-02, + 1.0287687927e-01, -1.0552648455e-01}, + { 7.2182372212e-02, -7.6782293618e-03, 1.6422374174e-02, + -5.1674857736e-02, -7.1635149419e-02}, + { 7.0604115725e-02, -6.8249620497e-02, -8.9388333261e-02, + -5.1136296242e-02, 6.3407994807e-02}, + { 6.1435334384e-02, 9.4730198383e-02, 8.3181746304e-02, + 5.7519987226e-02, 7.0465311408e-02}}}, + + + {{{ 1.5252918005e-02, 7.4633888900e-02, 9.0553842485e-02, + 3.8993149996e-02, 6.3962623477e-02}, + { 9.6443951130e-02, -5.6319754571e-02, -2.0676823333e-02, + 1.0050912201e-01, -1.0454320349e-02}, + { 7.1146171540e-03, -9.4763293862e-02, -1.1824182235e-02, + -8.2581162453e-02, 8.5433647037e-02}, + { 1.5876146033e-02, 1.0734396428e-01, -1.1871671304e-02, + 1.0982066393e-01, 1.5651857480e-02}, + { 3.9369460195e-02, -6.0143850744e-02, 1.0505072027e-01, + -8.4761910141e-02, 5.0331920385e-02}}, + + {{ 1.0066681542e-02, -1.0608792305e-01, 3.1187359244e-02, + -1.1222343892e-02, 1.2503461912e-02}, + {-9.5613434911e-02, -2.0962793380e-02, 6.2319990247e-02, + 3.6858320236e-02, 1.0154407471e-01}, + {-5.7882230729e-02, 4.0998354554e-02, 2.2802127525e-02, + -8.7465390563e-02, -7.8975915909e-02}, + {-1.7631363124e-02, -7.0353029296e-03, -1.0995418578e-01, + 2.3209381849e-02, 5.4516538978e-02}, + {-6.7430905998e-02, -3.2775081694e-02, -7.1572959423e-02, + 3.0015124008e-02, -8.2214266062e-02}}, + + {{ 7.8755617142e-02, 3.5310130566e-02, -2.2435920313e-02, + 9.6409514546e-02, 1.0338477045e-01}, + { 5.9903923422e-02, 1.8227061257e-02, 1.9898558035e-02, + 1.7521159723e-02, 2.3488936946e-02}, + {-9.0055637062e-02, 4.9969940446e-03, 5.6411098689e-02, + -4.5843422413e-02, 3.5857871175e-02}, + {-1.9720340148e-02, 1.4090083539e-02, -7.7931761742e-02, + 1.5021140687e-02, 6.3057549298e-02}, + {-3.7665259093e-02, 7.1552298963e-02, -5.8394841850e-02, + 8.0049857497e-02, -1.0659194738e-01}}}, + + + {{{-7.7228933573e-02, 5.1064457744e-02, -5.6017566472e-02, + 1.1644850485e-02, 1.6073303297e-02}, + {-3.2181475312e-02, 8.2790434361e-02, -1.1385639757e-01, + -1.1087367684e-01, 2.1422423422e-02}, + {-2.7187412605e-02, 7.9632617533e-02, 9.0410903096e-02, + 2.9836246744e-02, 7.0767119527e-02}, + { 4.1806723922e-02, 8.9509122074e-02, -7.6765663922e-02, + -4.1281420738e-02, 4.2842119932e-02}, + { 8.1620868295e-03, -6.4203604124e-03, 1.9168345258e-02, + 9.7423560917e-02, 8.7407097220e-02}}, + + {{ 5.9115942568e-02, -5.4798915982e-02, 7.6376385987e-02, + 2.8852632269e-02, -3.4071631730e-02}, + { 4.3502859771e-02, 6.9351151586e-02, -7.3672071099e-02, + -1.3623046689e-02, 9.9386192858e-02}, + { 8.2298137248e-02, 3.1567070633e-02, -5.2173703909e-02, + 1.0301522166e-01, -2.8413718566e-02}, + {-4.0537059307e-02, 2.9644003138e-02, 3.0859379098e-02, + 1.1227397621e-01, -5.6883804500e-02}, + {-1.0370600969e-01, -2.3154499009e-02, -7.5694397092e-02, + 7.3324322701e-02, -7.2433322668e-02}}, + + {{-9.0371906757e-02, 1.1529860646e-01, -6.1082314700e-02, + 1.0215951502e-01, 1.1091886461e-01}, + { 1.0044538975e-01, -8.9651204646e-02, -1.0203760862e-01, + -1.0518372059e-01, 1.0253529996e-01}, + { 1.1435522884e-01, -7.3617205024e-02, 6.6930335015e-03, + -4.1806902736e-02, -4.3869618326e-02}, + { 6.4512699842e-02, -6.9700092077e-02, 7.6211534441e-02, + 4.5039333403e-02, 6.5938614309e-02}, + {-8.8014036417e-02, 2.7664732188e-02, -3.6620914936e-02, + 6.0766555369e-02, -1.6977010295e-02}}}} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,2,4,3,3> { + {{{{-0.3815447390, 0.5393404961, 1.2082489729}, + {-0.7000376582, 0.5649882555, 0.6821430922}, + { 0.7830977440, 0.2265645415, -0.0234967973}}, + + {{ 0.5304207802, 0.3521062732, -0.4996214509}, + {-0.5668326020, -0.1662681103, -0.0246974919}, + {-0.4764020443, -0.6305201650, 0.1680680513}}, + + {{ 0.3380682766, 0.2957296073, 0.2884519398}, + {-0.6942415833, 0.2494415045, -0.1550499499}, + {-0.6718355417, 0.3543845117, -0.4612499774}}, + + {{-0.0101785604, -0.0853718519, -0.0978565961}, + {-0.4967738688, -1.4151864052, 0.4521864057}, + { 0.6482952833, -0.4406591058, 0.5528392196}}}, + + + {{{ 0.2647683024, -0.0899230987, 0.0251224432}, + { 0.2060495466, -0.5142026544, -0.6169960499}, + { 0.8739010096, -0.4528320730, -0.6620424986}}, + + {{ 0.4344040155, -0.3720431328, 0.3896343112}, + {-0.3978653252, -0.9195920229, -0.1681326628}, + {-1.1699988842, -0.5111203790, -0.2178955674}}, + + {{-0.2299620360, 0.0700573027, -0.2049577832}, + { 0.3746963441, -0.9416288137, 0.0695008487}, + { 0.5960077643, -0.2242757827, 0.0868191570}}, + + {{-0.4896622002, 0.5204730630, 0.3560196757}, + { 0.3295116723, 0.6170410514, 0.1678314358}, + { 0.1554045975, 0.2489441931, -0.3163521886}}}} + }); + conv_op.associateInput(0, myInput); + conv_op.associateInput(1, myWeights); + conv_op.associateInput(2, myBiases); + conv_op.setBackend("cpu"); + conv_op.setDataType(DataType::Float32); + conv_op.forwardDims(); + conv_op.forward(); + REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-8f)); + } + SECTION("stride [2,2], dilation [2,2]") { + Conv_Op<2> conv_op = Conv_Op<2>({5,5}, {2,2}, {2,2}); + std::shared_ptr<Tensor> myInput = std::make_shared<Tensor>(Array4D<float,2,3,9,9> { + {{{{ 6.5991479158e-01, -5.3738802671e-01, 1.6696882248e+00, + -1.1178103685e+00, -2.0366449356e+00, 7.6323038340e-01, + 8.8008441031e-02, -5.7445436716e-01, -1.1746516228e+00}, + { 1.1325144768e+00, 1.8485877514e+00, 6.6733115911e-01, + -1.1192236841e-01, -7.6075768471e-01, 9.3889570236e-01, + 1.1970713139e+00, -1.1396632344e-01, 3.7702596188e-01}, + { 1.0176507235e+00, -3.2929670811e-01, 2.2804117203e+00, + 2.6695098877e+00, 9.1759788990e-01, -7.7024054527e-01, + -1.1899918318e+00, -3.5644110292e-02, 1.4563606977e+00}, + { 1.3222569227e-01, 1.7781604528e+00, -5.3485935926e-01, + -1.1690793559e-02, -3.1795254350e-01, 1.3168191910e+00, + -9.8241758347e-01, -1.9394657016e-01, -2.6476421952e-01}, + { 1.7474385500e+00, -5.1397854090e-01, 4.1304263473e-01, + -2.2527515888e-01, 1.1935096979e+00, -1.8106347322e-01, + 1.7698837519e+00, -1.1245247126e+00, -8.0709517002e-01}, + {-6.6057407856e-01, 5.4547226429e-01, -4.9604213238e-01, + -1.5724540949e+00, 7.3046660423e-01, -1.8622325659e+00, + -1.7612577677e+00, 2.7059727907e-01, 2.0769243240e+00}, + { 1.6489045620e+00, -9.1659523547e-02, 6.5829491615e-01, + 9.5888656378e-01, -9.7225487232e-01, 9.9177438021e-01, + 4.1120108962e-01, 1.7129123211e+00, -1.3064719737e-01}, + { 5.6299138069e-01, 1.1197557449e+00, 1.0256431103e+00, + -1.0448687077e+00, -3.9633819461e-01, 2.1613819599e+00, + -2.9366123676e-01, 2.0935084820e+00, -3.1408703327e-01}, + {-2.2087215912e-03, -5.5481916666e-01, -1.0586282015e+00, + 1.6510184109e-01, -5.3518980742e-01, -1.5306407213e+00, + -1.4912575483e+00, 4.6741631627e-01, 1.6276098490e+00}}, + + {{-1.2711778879e+00, -6.6529834270e-01, 2.0430049896e+00, + 1.3407748938e+00, 1.5101557970e+00, 3.0264301300e+00, + 5.7267320156e-01, 1.9472989440e-01, -1.0449569672e-01}, + {-4.3861621618e-01, -8.4084004164e-01, 1.8874751031e-01, + -3.0964607000e-01, -3.3041627407e+00, 4.0943336487e-01, + -5.3273528814e-01, 1.1388880014e+00, 4.4220641255e-01}, + {-1.8995233774e+00, 2.4473433197e-01, -4.1401520371e-01, + -6.5818083286e-01, -1.1139613390e+00, 2.1693031490e-01, + -1.0517214537e+00, -2.3312714100e+00, 6.0954615474e-02}, + { 1.3127720915e-02, -2.2521468997e-01, -2.5984519720e-01, + 1.5528632700e-01, -4.9426975846e-01, -1.1347863674e+00, + 1.1981898546e-01, 2.3249061108e+00, -4.2492222786e-01}, + { 2.5971227884e-01, 6.9073438644e-02, -1.2523316145e+00, + -1.8091107905e-01, -1.7139790952e-01, 8.1327247620e-01, + -6.7866450548e-01, 2.2402961254e+00, -8.5352472961e-02}, + { 7.1751189232e-01, -2.3494932055e-01, -1.3409119844e+00, + 5.2470743656e-01, 7.6781928539e-01, -7.1144473553e-01, + -1.9754718542e+00, 2.6893837452e+00, 7.8437983990e-01}, + {-4.0214532614e-01, 1.7369346619e+00, -1.7632387578e-01, + -1.5825942755e+00, 1.0516833067e+00, 2.0817482471e+00, + -9.7633296251e-01, 7.8872179985e-01, -4.3127769232e-01}, + { 9.7235912085e-01, -5.8469034731e-02, 2.6687437296e-01, + -1.4018902779e+00, 1.2706108093e+00, 5.8360731602e-01, + -1.2217177153e+00, -8.2037007809e-01, -6.4826738834e-01}, + { 3.8622283936e-01, 1.1064618826e+00, -1.5179945230e+00, + 3.9867460728e-01, 8.5346035659e-02, -1.1623222828e+00, + 9.5119558275e-02, 8.7334537506e-01, 1.0425381660e+00}}, + + {{ 7.8875219822e-01, -1.1821738482e+00, 3.8991808891e-01, + -1.0048108101e+00, -2.0707476139e-01, -7.3082458973e-01, + 1.0729664564e+00, -1.2859574556e+00, -8.7584072351e-01}, + { 3.7907764316e-01, 1.1199241877e+00, 1.9296696186e+00, + 1.2730616331e-01, 6.0980606079e-01, -7.3303855956e-02, + -5.9152889252e-01, -1.2527221441e-01, 1.0408999920e+00}, + {-4.4774639606e-01, -8.6148458719e-01, 1.4992856979e+00, + 1.3516107798e+00, -1.1647412777e+00, 1.3260208368e+00, + -5.3640037775e-01, 8.5038894415e-01, 4.0011933446e-01}, + {-7.2440391779e-01, 6.6310149431e-01, -9.2786878347e-01, + 1.4332934618e+00, -1.2407248020e+00, -1.9271074235e-01, + 5.4011422396e-01, -2.0801360905e-01, 1.1701456308e+00}, + {-1.5601218939e+00, -1.3747200966e+00, 2.4702382088e-01, + 4.4083452225e-01, -6.4500576258e-01, -7.6838213205e-01, + 5.3242987394e-01, 2.4775697291e-01, -1.4102922678e+00}, + {-9.1009324789e-01, -8.6870056391e-01, -7.6603555679e-01, + -1.4323582649e+00, 1.5076324344e-01, -6.6071575880e-01, + 6.5643036366e-01, 7.0738911629e-01, -1.2404441833e+00}, + { 6.5142494440e-01, 2.4921335280e-02, 5.0163829327e-01, + 1.3338588476e+00, 1.7744785547e+00, -6.3132202625e-01, + 8.1679749489e-01, -4.4332244992e-01, -1.3621041775e+00}, + { 6.9176673889e-01, -6.8686300516e-01, 1.0556088686e+00, + 1.1115324497e+00, -3.8817191124e-01, -6.3901716471e-01, + 7.4065899849e-01, 5.2005118132e-01, 4.8783886433e-01}, + {-4.8411735892e-01, 8.3703887463e-01, 5.9305649996e-01, + 1.3562313318e+00, -7.5646054745e-01, -4.2536877096e-02, + 1.7571094036e+00, -4.7270351648e-01, 2.2838380337e+00}}}, + + + {{{-4.0447491407e-01, 6.2086683512e-01, 1.2934124470e+00, + 1.0094794035e+00, -1.0171808004e+00, -2.6295968890e-01, + -4.9549680948e-01, -7.9358913004e-02, 9.4110012054e-02}, + { 8.7407690287e-01, 4.2964717746e-01, -1.7619720697e+00, + -3.6411872506e-01, -1.7255870998e-01, 4.4035488367e-01, + 1.6617731750e-01, 3.5132259130e-01, -4.4415783882e-01}, + {-1.7608845234e+00, -2.3283758759e-01, 6.2204450369e-01, + 4.5604482293e-01, 4.8442709446e-01, -1.2630403042e+00, + -1.0600868613e-02, -9.3579161167e-01, 7.2725065053e-02}, + {-4.9674937129e-01, 2.7484998107e-01, -7.8241840005e-02, + -3.8036200404e-01, -1.1722209454e+00, 1.5537194014e+00, + 3.3076366782e-01, -8.7499739602e-03, 2.0694589615e+00}, + {-1.3299342394e+00, -1.5577025712e-01, 1.6106146574e+00, + 5.3044539690e-01, -1.2436783314e+00, -7.8046637774e-01, + -9.2501389980e-01, 1.5277367830e+00, 3.6043179035e-01}, + {-9.9188965559e-01, -1.2423185110e+00, 3.9069399238e-01, + -1.2723475695e+00, -1.7772109509e+00, -3.7175655365e-01, + -8.7014752626e-01, 1.3463658094e+00, -1.5951974392e+00}, + {-7.8578054905e-01, -6.2972821295e-02, -4.8052150011e-01, + -1.2783004045e+00, -3.8468798995e-01, 4.7666281462e-02, + 4.2015764117e-01, -4.4800898433e-01, 3.9750581980e-01}, + { 1.2391144037e+00, 4.4924438000e-01, -5.7612675428e-01, + -1.2152553797e+00, 6.7230182886e-01, 1.3609430790e+00, + -4.2446309328e-01, 3.0986270308e-01, 3.6792102456e-01}, + { 1.4776864648e-01, -9.7534912825e-01, -1.9648849964e-01, + -1.0378727913e+00, 2.5092484429e-02, 8.9258450270e-01, + -2.2762279212e-01, -2.3942720890e+00, 7.9677361250e-01}}, + + {{-4.4367963076e-01, 9.8137008026e-03, 1.6468089819e+00, + 3.9348766208e-01, 3.7895750999e-01, -1.8910832405e+00, + 1.6934220791e+00, -5.1142543554e-01, 2.1927893162e+00}, + {-1.2872399092e+00, 5.1995629072e-01, 2.8462198377e-01, + -7.7300745249e-01, 6.1586141586e-02, 7.9627609253e-01, + 5.2585881948e-01, 2.0059709251e-01, -1.0767682791e+00}, + { 1.2913355827e+00, -5.2280706167e-01, 9.3896692991e-01, + -2.7119943500e-01, -1.3428537548e-01, -7.7558577061e-02, + -1.4985687733e+00, 1.5150824785e+00, -1.3824665546e+00}, + {-3.4071408212e-02, -7.0768481493e-01, 3.9081773162e-01, + -1.0144554377e+00, -1.2199249268e+00, 9.7416710854e-01, + -2.1364924908e+00, -7.5508749485e-01, -1.3795818090e+00}, + { 4.5370283723e-01, -1.7424255610e+00, -8.5776680708e-01, + 3.9504718781e-01, 9.9192768335e-02, 7.1981537342e-01, + 2.7460846305e-01, -8.1848166883e-02, 7.6311039925e-01}, + { 1.6829998791e-01, 2.8629219532e-01, -1.3655959070e-01, + -1.2729966640e+00, 2.9406669736e-01, -3.6713847518e-01, + 6.3521367311e-01, 9.0642973781e-02, 4.6122816205e-01}, + {-5.9019500017e-01, 6.1684101820e-01, 5.7554990053e-01, + -7.1885848045e-01, 1.5339116752e-01, 7.2704249620e-01, + -1.1901499033e+00, 1.8046575785e-01, -3.2128947973e-01}, + { 6.9699871540e-01, -1.5316461325e+00, -1.0008054972e+00, + 1.8971544504e+00, 1.6860273480e-01, 4.6585604548e-01, + -7.4088859558e-01, -2.0486815274e-01, 2.4802033603e-01}, + {-8.8578667492e-03, -8.0224162340e-01, 1.5357034206e+00, + 1.2365963459e+00, 1.4597702026e+00, -5.4030877352e-01, + 7.9093635082e-01, 1.1919885874e+00, -1.9415197372e+00}}, + + {{ 1.6230100393e-01, 1.7142108679e+00, 1.5414776802e+00, + -4.2192405462e-01, 4.9785825610e-01, 2.1395962238e+00, + 9.2708784342e-01, -8.3940023184e-01, -8.0437123775e-01}, + {-9.4176328182e-01, 2.6041597128e-01, -1.0130367279e+00, + -3.5772189498e-01, 1.6592922211e+00, 1.9243527651e+00, + 1.4461495876e+00, 1.2969638109e+00, 2.9279315472e+00}, + { 5.7384677231e-02, -6.6253073514e-02, -7.2724334896e-02, + -2.3743975163e-01, 9.5880138874e-01, 3.6361989379e-01, + 1.2075768709e+00, 5.1945459843e-01, -2.4960200787e+00}, + { 1.5223011971e+00, 6.6761517525e-01, -7.1185566485e-02, + 1.3005328178e+00, -1.6010546684e+00, -8.3948358893e-02, + 1.3929320872e-01, 5.7007002831e-01, 1.5402120352e+00}, + {-1.0000891685e+00, 5.6669050455e-01, 1.1230304241e+00, + -9.2030251026e-01, -2.2001121044e+00, 7.8229683638e-01, + -3.0678564310e-01, -1.7156904936e-01, -8.6419695616e-01}, + { 8.8148981333e-01, -1.7107343674e+00, -2.9174503684e-01, + 7.9814374447e-01, -9.9373269081e-01, -5.2981477231e-02, + -1.0876508951e+00, -4.6575244516e-02, 8.7985855341e-01}, + {-2.6840454340e-01, 7.7923542261e-01, 9.3854445219e-01, + 4.0857732296e-01, 5.7850652933e-01, 1.4003618956e+00, + -7.1249789000e-01, -3.1672206521e-01, 1.6309374571e-01}, + { 8.0395114422e-01, 5.8513361216e-01, 1.3350075483e+00, + 7.1663349867e-01, 7.2658437490e-01, 4.9841433764e-01, + -1.7066024542e+00, -8.7015740573e-02, -3.9591795206e-01}, + {-2.6701366901e-01, -2.9452782869e-01, 2.2669389844e-02, + -8.6450153589e-01, -4.1996881366e-01, -5.4280841351e-01, + 1.0260531902e+00, -1.0119054317e+00, -5.9093588591e-01}}}} + }); + std::shared_ptr<Tensor> myBiases = std::make_shared<Tensor>(Array1D<float,4> { + { 0.0219541416, -0.0854099169, 0.0740336627, 0.0793448612} + }); + std::shared_ptr<Tensor> myWeights = std::make_shared<Tensor>(Array4D<float,4,3,5,5> { + {{{{-0.0879265219, -0.0899311453, 0.1075298190, -0.0439606048, + 0.0142062334}, + {-0.0589592122, -0.0045246030, -0.0316865370, 0.0141878976, + 0.1153501272}, + { 0.0643049031, -0.0598805100, -0.0832074434, 0.0077776825, + 0.0623578280}, + {-0.0301217884, -0.0073744338, 0.1121563762, -0.0863516033, + 0.0386641659}, + { 0.0234115086, -0.0829083994, 0.0418176949, -0.0570392683, + -0.0580074638}}, + + {{-0.0012715239, 0.0863351673, -0.1069950312, 0.0660321191, + -0.0725978017}, + { 0.0047315061, -0.0222416110, 0.0019143816, 0.0152916117, + -0.0825890377}, + { 0.1134350598, 0.0128993327, -0.0935105979, 0.0104216831, + 0.0318991244}, + {-0.0340792425, 0.1136758402, -0.0319192223, 0.0048904521, + -0.0254838988}, + { 0.0453651547, -0.1142466366, 0.0613734871, 0.0306740720, + 0.0760814846}}, + + {{ 0.0366517194, 0.0485932752, -0.0874269605, -0.0775868669, + -0.0356033146}, + { 0.1102349758, 0.0213746168, -0.0696568191, 0.0301248170, + 0.0069745299}, + {-0.0293842964, -0.0122400951, 0.0375012197, -0.1118148938, + -0.0948114246}, + {-0.1078331321, -0.0746667907, -0.0589634664, -0.0510230660, + 0.0664016604}, + {-0.0519216098, 0.0763967782, 0.0384426974, -0.0007680102, + 0.1003586799}}}, + + + {{{-0.0771117881, -0.0127118658, -0.1025340930, -0.0193717945, + -0.0718696415}, + { 0.1134984195, -0.0469665602, -0.0328875706, -0.0890083611, + 0.0217132103}, + {-0.0783405602, -0.0964497253, 0.0073435861, -0.0290560126, + -0.0307252090}, + { 0.0957401842, 0.0746279880, -0.0221166238, -0.1134141684, + 0.0168936905}, + { 0.0833431259, 0.0800427124, -0.0428234823, 0.0058708852, + 0.0055020354}}, + + {{-0.0619548559, 0.0440408550, -0.0580734424, -0.0271147881, + 0.0551881120}, + {-0.1015950069, -0.0655148402, -0.0656934455, 0.0734511092, + 0.0593257882}, + {-0.0388379470, 0.0538616925, 0.0215578172, 0.1115155891, + 0.0267907996}, + {-0.0117263123, 0.1013097093, -0.0503486842, -0.0227387249, + 0.0769604445}, + {-0.0400040187, -0.0017201286, 0.0305580869, -0.1087302193, + -0.0778466389}}, + + {{ 0.0157795977, 0.0284815803, -0.0283647832, -0.0756585225, + 0.0766027272}, + { 0.0176657196, -0.1124993339, -0.0154858092, -0.0368758999, + 0.0100407479}, + {-0.0125693697, 0.0512169749, -0.0256510209, -0.0971343294, + -0.0872697607}, + {-0.0426635183, 0.0547962859, -0.0496184044, 0.0890550837, + 0.1007452309}, + {-0.1043196693, -0.0133433538, -0.0131574012, 0.0442749150, + 0.0401787795}}}, + + + {{{-0.1098219156, 0.0145482961, -0.0767832026, 0.0287463516, + 0.0936923251}, + { 0.0129374759, -0.0915895551, 0.0694310442, 0.0978608951, + -0.0756938607}, + { 0.0162203833, -0.0620732345, -0.0158150289, -0.0646957755, + -0.0085407924}, + {-0.0168146919, -0.0887613446, -0.0721658245, 0.0921881124, + 0.0079541644}, + {-0.1055120081, -0.0643930957, -0.0260313656, -0.0003582919, + 0.0954318866}}, + + {{ 0.0811082423, 0.1095830873, -0.0475429185, -0.0180855002, + 0.0421846844}, + {-0.0270713326, -0.0276994482, -0.0893911272, -0.0372085199, + 0.0332398191}, + {-0.0295267235, 0.1145598739, 0.0082155224, 0.0932523906, + 0.0545750260}, + { 0.0845445022, 0.0105949445, 0.0310290195, -0.0396258235, + 0.0049864636}, + {-0.0088399630, 0.0189545881, 0.0030256936, -0.1071743071, + 0.0308798887}}, + + {{-0.0736748204, 0.1093047708, 0.0833086222, 0.0749989003, + 0.0896537676}, + { 0.0174796991, -0.0201664530, -0.1130384058, 0.0199203752, + 0.0047632209}, + {-0.0784958825, 0.0911915898, -0.1046215370, 0.0936999246, + -0.0141570913}, + {-0.0699632838, -0.1020562500, 0.0693636611, 0.0441951603, + -0.1042146608}, + {-0.0897461325, -0.0261580031, 0.0266584475, -0.0242327619, + -0.0091426708}}}, + + + {{{-0.0294735078, -0.0490118340, -0.0102116829, -0.0089236405, + 0.0638454780}, + { 0.0318225212, -0.0750555843, -0.0006857774, -0.0149815530, + 0.0697304457}, + { 0.0654041991, -0.0882468224, 0.1105226800, 0.0033364408, + -0.0604107007}, + {-0.0374080427, 0.0036495004, 0.0383783057, -0.0263143200, + 0.1003912762}, + { 0.0507353209, 0.0155910300, 0.1017192900, 0.0231544450, + -0.1084882244}}, + + {{-0.0090856012, -0.0697057769, 0.0872633979, 0.0332706533, + -0.0800112858}, + {-0.0645161867, -0.0415929221, -0.0418168269, -0.0703653470, + -0.0586728156}, + {-0.0439705029, 0.1082400829, 0.0359866321, -0.0057579288, + 0.0186352655}, + { 0.0343357958, 0.0320450775, -0.0980449989, -0.0822830945, + -0.0560547598}, + {-0.0906760171, -0.0421256870, 0.0518113375, 0.0785672292, + 0.0198612548}}, + + {{ 0.1009553447, 0.0794672221, 0.0007000518, -0.0483656302, + 0.0281033702}, + {-0.0185741764, 0.0940017775, 0.0855893716, 0.1091704220, + -0.0968915299}, + { 0.0257219113, -0.1015173718, -0.1027367935, -0.0880665779, + -0.0726886243}, + { 0.0171099678, -0.0688400492, -0.0629827529, -0.0427498519, + 0.0592775978}, + { 0.0118976049, -0.0184275638, 0.0676623732, -0.0120042292, + 0.0371227749}}}} + }); + std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<float,2,4,1,1> { + {{{{-0.1854654104}}, + {{ 0.0502447225}}, + {{ 0.3770641685}}, + {{-0.2695779204}}}, + + {{{-0.1893050075}}, + {{-0.3626379371}}, + {{ 0.7148165107}}, + {{-0.0087520313}}}} + }); + conv_op.associateInput(0, myInput); + conv_op.associateInput(1, myWeights); + conv_op.associateInput(2, myBiases); + conv_op.setBackend("cpu"); + conv_op.setDataType(DataType::Float32); + conv_op.forwardDims(); + conv_op.forward(); + REQUIRE(approxEq<float>(*(conv_op.getOutput(0)),*expectedOutput, 1e-5f, 1e-6f)); + } } } \ No newline at end of file