Skip to content
Snippets Groups Projects
Commit e7eec22f authored by Maxence Naud's avatar Maxence Naud Committed by Maxence Naud
Browse files

improve conv kernel speed

parent 2a745b7d
No related branches found
No related tags found
1 merge request!101Upd 2D Conv[DepthWise] kernels
This commit is part of merge request !101. Comments created here will be created in the context of that merge request.
...@@ -157,57 +157,52 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, ...@@ -157,57 +157,52 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
// input (batch, inCh, Xin, Yin) // input (batch, inCh, Xin, Yin)
// weight (outCh, inCh, kernelX, kernelY) // weight (outCh, inCh, kernelX, kernelY)
// does not take Dilation attribute into account // does not take Dilation attribute into account
using signedsize = std::make_signed<std::size_t>::type; const std::size_t outChannels_s = oxSize * oySize;
for (std::size_t batch = 0; batch < inputDims[0]; ++batch) { for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
for (std::size_t outCh = 0; outCh < outChannels; ++outCh) { for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
const std::size_t oIndex = (outCh + batch*outChannels) * oxSize * oySize;
// If bias = nullptr, set B(0) // If bias = nullptr, set B(0)
B biasVal = (biases != nullptr) ? biases[outCh] : B(0); B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
std::fill(output + oIndex, output+(oIndex+oxSize*oySize), biasVal); std::fill(output, output+outChannels_s, biasVal);
for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) { for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
const std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3]; const std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3];
const std::size_t wIndex = (inCh + outCh*inputDims[1]) * kernelDims[0] * kernelDims[1]; const std::size_t wIndex = (inCh + outCh*inputDims[1]) * kernelDims[0] * kernelDims[1];
for (std::size_t ox = 0; ox < oxSize; ++ox) { for (std::size_t ox = 0; ox < oxSize; ++ox) {
// const signedsize difx = static_cast<signedsize>(- ox * strideDims[0]);
// const std::size_t sxMin = static_cast<std::size_t>(std::max(difx, signedsize(0)));
// const std::size_t sxMax = (static_cast<signedsize>(inputDims[2]) + difx) < 0 ? 0 : ((inputDims[2] + difx) > kernelDims[0] ? kernelDims[0] : inputDims[2] + difx);
const std::size_t sxMin = 0;
const std::size_t sxMax = dilated_kernel_x;
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
// const signedsize dify = static_cast<signedsize>(- oy * strideDims[1]);
// const std::size_t syMin = static_cast<std::size_t>(std::max(dify, signedsize(0))); const std::size_t oIndexFull = ox*oySize + oy;
// const std::size_t syMax = (static_cast<signedsize>(inputDims[3]) + dify) < 0 ? 0 : ((inputDims[3] + dify) > kernelDims[1] ? kernelDims[1] : inputDims[3] + dify); const size_t ix = ox * strideDims[0];
const std::size_t syMin = 0; const size_t iy = oy * strideDims[1];
const std::size_t syMax = dilated_kernel_y;
const std::size_t oIndexFull = oIndex + ox*oySize + oy; if (kernelDims[0] == 3 && kernelDims[1] == 3 && dilationDims[0] == 1 && dilationDims[1] == 1) {
const signedsize ix = static_cast<signedsize>(ox * strideDims[0]); output[oIndexFull] += (weights[wIndex] * input[iIndex + static_cast<std::size_t>(ix)*inputDims[3] + static_cast<std::size_t>(iy)] +
const signedsize iy = static_cast<signedsize>(oy * strideDims[1]); weights[wIndex + 1] * input[iIndex + static_cast<std::size_t>(ix)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
weights[wIndex + 2] * input[iIndex + static_cast<std::size_t>(ix)*inputDims[3] + static_cast<std::size_t>(iy+2)] +
if (sxMin == 0 && syMin == 0 && sxMax == 3 && syMax == 3) { weights[wIndex + kernelDims[1]] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy)] +
output[oIndexFull] += (weights[wIndex + 0*kernelDims[1] + 0] * input[iIndex + static_cast<std::size_t>(ix+0)*inputDims[3] + static_cast<std::size_t>(iy+0)] + weights[wIndex + kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
weights[wIndex + 0*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+0)*inputDims[3] + static_cast<std::size_t>(iy+1)] + weights[wIndex + kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+2)] +
weights[wIndex + 0*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+0)*inputDims[3] + static_cast<std::size_t>(iy+2)] + weights[wIndex + 2*kernelDims[1]] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy)] +
weights[wIndex + 1*kernelDims[1] + 0] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+0)] +
weights[wIndex + 1*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
weights[wIndex + 1*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+2)] +
weights[wIndex + 2*kernelDims[1] + 0] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+0)] +
weights[wIndex + 2*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+1)] + weights[wIndex + 2*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
weights[wIndex + 2*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+2)]); weights[wIndex + 2*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+2)]);
} else { } else {
for (std::size_t sx = sxMin; sx*dilationDims[0] < sxMax; ++sx) { for (std::size_t sx = 0; sx*dilationDims[0] < dilated_kernel_x; ++sx) {
for (std::size_t sy = syMin; sy*dilationDims[1] < syMax; ++sy) { for (std::size_t sy = 0; sy*dilationDims[1] < dilated_kernel_y; ++sy) {
output[oIndexFull] += weights[wIndex + sx*kernelDims[1] + sy] * output[oIndexFull] += weights[wIndex + sx*kernelDims[1] + sy] *
input[iIndex + static_cast<std::size_t>(ix+static_cast<signedsize>(sx*dilationDims[0]))*inputDims[3] + static_cast<std::size_t>(iy+static_cast<signedsize>(sy*dilationDims[1]))]; input[iIndex + (ix + (sx*dilationDims[0]))*inputDims[3] + (iy + (sy*dilationDims[1]))];
} }
} }
} }
} }
} }
} }
output += outChannels_s;
} }
} }
} }
// Kernels registration to implementation entry point // Kernels registration to implementation entry point
REGISTRAR(ConvImpl2D_cpu, REGISTRAR(ConvImpl2D_cpu,
{{DataType::Any, DataFormat::NCHW}, {DataType::Float32, DataFormat::NCHW}}, {{DataType::Any, DataFormat::NCHW}, {DataType::Float32, DataFormat::NCHW}},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment