Skip to content
Snippets Groups Projects
Commit e7eec22f authored by Maxence Naud's avatar Maxence Naud Committed by Maxence Naud
Browse files

improve conv kernel speed

parent 2a745b7d
No related branches found
No related tags found
1 merge request!101Upd 2D Conv[DepthWise] kernels
......@@ -157,57 +157,52 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
// input (batch, inCh, Xin, Yin)
// weight (outCh, inCh, kernelX, kernelY)
// does not take Dilation attribute into account
using signedsize = std::make_signed<std::size_t>::type;
const std::size_t outChannels_s = oxSize * oySize;
for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
const std::size_t oIndex = (outCh + batch*outChannels) * oxSize * oySize;
// If bias = nullptr, set B(0)
B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
std::fill(output + oIndex, output+(oIndex+oxSize*oySize), biasVal);
std::fill(output, output+outChannels_s, biasVal);
for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
const std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3];
const std::size_t wIndex = (inCh + outCh*inputDims[1]) * kernelDims[0] * kernelDims[1];
for (std::size_t ox = 0; ox < oxSize; ++ox) {
// const signedsize difx = static_cast<signedsize>(- ox * strideDims[0]);
// const std::size_t sxMin = static_cast<std::size_t>(std::max(difx, signedsize(0)));
// const std::size_t sxMax = (static_cast<signedsize>(inputDims[2]) + difx) < 0 ? 0 : ((inputDims[2] + difx) > kernelDims[0] ? kernelDims[0] : inputDims[2] + difx);
const std::size_t sxMin = 0;
const std::size_t sxMax = dilated_kernel_x;
for (std::size_t oy = 0; oy < oySize; ++oy) {
// const signedsize dify = static_cast<signedsize>(- oy * strideDims[1]);
// const std::size_t syMin = static_cast<std::size_t>(std::max(dify, signedsize(0)));
// const std::size_t syMax = (static_cast<signedsize>(inputDims[3]) + dify) < 0 ? 0 : ((inputDims[3] + dify) > kernelDims[1] ? kernelDims[1] : inputDims[3] + dify);
const std::size_t syMin = 0;
const std::size_t syMax = dilated_kernel_y;
const std::size_t oIndexFull = oIndex + ox*oySize + oy;
const signedsize ix = static_cast<signedsize>(ox * strideDims[0]);
const signedsize iy = static_cast<signedsize>(oy * strideDims[1]);
if (sxMin == 0 && syMin == 0 && sxMax == 3 && syMax == 3) {
output[oIndexFull] += (weights[wIndex + 0*kernelDims[1] + 0] * input[iIndex + static_cast<std::size_t>(ix+0)*inputDims[3] + static_cast<std::size_t>(iy+0)] +
weights[wIndex + 0*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+0)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
weights[wIndex + 0*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+0)*inputDims[3] + static_cast<std::size_t>(iy+2)] +
weights[wIndex + 1*kernelDims[1] + 0] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+0)] +
weights[wIndex + 1*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
weights[wIndex + 1*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+2)] +
weights[wIndex + 2*kernelDims[1] + 0] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+0)] +
const std::size_t oIndexFull = ox*oySize + oy;
const size_t ix = ox * strideDims[0];
const size_t iy = oy * strideDims[1];
if (kernelDims[0] == 3 && kernelDims[1] == 3 && dilationDims[0] == 1 && dilationDims[1] == 1) {
output[oIndexFull] += (weights[wIndex] * input[iIndex + static_cast<std::size_t>(ix)*inputDims[3] + static_cast<std::size_t>(iy)] +
weights[wIndex + 1] * input[iIndex + static_cast<std::size_t>(ix)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
weights[wIndex + 2] * input[iIndex + static_cast<std::size_t>(ix)*inputDims[3] + static_cast<std::size_t>(iy+2)] +
weights[wIndex + kernelDims[1]] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy)] +
weights[wIndex + kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
weights[wIndex + kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+1)*inputDims[3] + static_cast<std::size_t>(iy+2)] +
weights[wIndex + 2*kernelDims[1]] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy)] +
weights[wIndex + 2*kernelDims[1] + 1] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+1)] +
weights[wIndex + 2*kernelDims[1] + 2] * input[iIndex + static_cast<std::size_t>(ix+2)*inputDims[3] + static_cast<std::size_t>(iy+2)]);
} else {
for (std::size_t sx = sxMin; sx*dilationDims[0] < sxMax; ++sx) {
for (std::size_t sy = syMin; sy*dilationDims[1] < syMax; ++sy) {
for (std::size_t sx = 0; sx*dilationDims[0] < dilated_kernel_x; ++sx) {
for (std::size_t sy = 0; sy*dilationDims[1] < dilated_kernel_y; ++sy) {
output[oIndexFull] += weights[wIndex + sx*kernelDims[1] + sy] *
input[iIndex + static_cast<std::size_t>(ix+static_cast<signedsize>(sx*dilationDims[0]))*inputDims[3] + static_cast<std::size_t>(iy+static_cast<signedsize>(sy*dilationDims[1]))];
input[iIndex + (ix + (sx*dilationDims[0]))*inputDims[3] + (iy + (sy*dilationDims[1]))];
}
}
}
}
}
}
output += outChannels_s;
}
}
}
// Kernels registration to implementation entry point
REGISTRAR(ConvImpl2D_cpu,
{{DataType::Any, DataFormat::NCHW}, {DataType::Float32, DataFormat::NCHW}},
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment