Skip to content
Snippets Groups Projects
Commit b3f9c722 authored by Maxence Naud's avatar Maxence Naud
Browse files

FIX: Conv[DepthWise] forward implementation for some cases that were not...

FIX: Conv[DepthWise] forward implementation for some cases that were not tested in the new implementation
parent 740a27ba
No related branches found
No related tags found
1 merge request!124FIX: Conv[DepthWise] forward implementation and add tests
...@@ -175,17 +175,17 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& stri ...@@ -175,17 +175,17 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& stri
} }
} }
} else { } else {
for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex-=strideDims[0]*inputDims[3]) { for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=(strideDims[0]-2)*inputDims[3]) {
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy]+weights[wIndex+1]*input[iIndex+oy+strideDims[0]]+weights[wIndex+2]*input[iIndex+oy+strideDims[0]*2]; output[oIndex + oy] = biasVal + weights[wIndex+0]*input[iIndex+oy*strideDims[1]]+weights[wIndex+1]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+2]*input[iIndex+oy*strideDims[1]+2];
} }
iIndex+=strideDims[0]*inputDims[3]; iIndex+=inputDims[3];
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy]+weights[wIndex+4]*input[iIndex+oy+strideDims[0]]+weights[wIndex+5]*input[iIndex+oy+strideDims[0]*2]; output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy*strideDims[1]]+weights[wIndex+4]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+5]*input[iIndex+oy*strideDims[1]+2];
} }
iIndex+=strideDims[0]*inputDims[3]; iIndex+=inputDims[3];
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy]+weights[wIndex+7]*input[iIndex+oy+strideDims[0]]+weights[wIndex+8]*input[iIndex+oy+strideDims[0]*2]; output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy*strideDims[1]]+weights[wIndex+7]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+8]*input[iIndex+oy*strideDims[1]+2];
} }
} }
} }
...@@ -193,25 +193,23 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& stri ...@@ -193,25 +193,23 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& stri
} }
} }
} else if (dilated_kernel_x == 1 && dilated_kernel_y == 1) { } else if (dilated_kernel_x == 1 && dilated_kernel_y == 1) {
std::size_t index = 0;
for (std::size_t batch = 0; batch < inputDims[0]; ++batch) { for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
for (std::size_t ch = 0; ch < inputDims[1]; ++ch) { for (std::size_t ch = 0; ch < inputDims[1]; ++ch) {
B biasVal = (biases != nullptr) ? biases[ch] : B(0); B biasVal = (biases != nullptr) ? biases[ch] : B(0);
const std::size_t iIndex = (ch + batch*inputDims[1]) * inputDims[2] * inputDims[3]; std::size_t iIndex = (ch + batch*inputDims[1]) * inputDims[2] * inputDims[3];
const std::size_t wIndex = ch; const std::size_t wIndex = ch;
if (strideDims[0] == 1 && strideDims[1] == 1) { if (strideDims[0] == 1 && strideDims[1] == 1) {
for (; index < iIndex + oxSize*oySize; ++index) { for (std::size_t i = iIndex; i < iIndex + oxSize*oySize; ++i) {
output[index] = biasVal + weights[wIndex] * input[index]; output[i] = biasVal + weights[wIndex] * input[i];
} }
} else { } else {
std::size_t oIndex = (ch + batch*inputDims[1]) * oxSize * oySize; std::size_t oIndex = (ch + batch*inputDims[1]) * oxSize * oySize;
for (std::size_t ox = 0; ox < oxSize; ++ox, oIndex+=oySize) { for (std::size_t ox = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=strideDims[0]*inputDims[3]) {
index = iIndex + strideDims[0]*inputDims[3];
for (std::size_t oy = 0, iy = 0; oy < oySize; ++oy, iy+=strideDims[1]) { for (std::size_t oy = 0, iy = 0; oy < oySize; ++oy, iy+=strideDims[1]) {
output[oIndex + oy] += weights[wIndex]*input[index+iy]; output[oIndex + oy] = biasVal + weights[wIndex]*input[iIndex+iy];
} }
} }
} }
...@@ -234,16 +232,16 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& stri ...@@ -234,16 +232,16 @@ void ConvDepthWiseImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& stri
const std::size_t ix = ox * strideDims[0]; const std::size_t ix = ox * strideDims[0];
const std::size_t iy = oy * strideDims[1]; const std::size_t iy = oy * strideDims[1];
for (std::size_t sx = 0; sx*dilationDims[0] < dilated_kernel_x; ++sx) { for (std::size_t kx = 0; kx*dilationDims[0] < dilated_kernel_x; ++kx) {
for (std::size_t sy = 0; sy*dilationDims[1] < dilated_kernel_y; ++sy) { for (std::size_t ky = 0; ky*dilationDims[1] < dilated_kernel_y; ++ky) {
output[oIndexFull] += weights[wIndex + sx*kernelDims[1] + sy] * output[oIndexFull] += weights[wIndex + kx*kernelDims[1] + ky] *
input[iIndex + static_cast<std::size_t>(ix + sx*dilationDims[0])*inputDims[3] + static_cast<std::size_t>(iy + sy*dilationDims[1])]; input[iIndex + (ix + kx*dilationDims[0])*inputDims[3] + (iy + ky*dilationDims[1])];
} }
} }
} }
} }
output += outChannels_s;
} }
output += outChannels_s;
} }
} }
} }
......
...@@ -183,17 +183,17 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, ...@@ -183,17 +183,17 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
} }
} }
} else { } else {
for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex-=strideDims[0]*inputDims[3]) { for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=(strideDims[0]-2)*inputDims[3]) {
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy]+weights[wIndex+1]*input[iIndex+oy+strideDims[0]]+weights[wIndex+2]*input[iIndex+oy+strideDims[0]*2]; output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy*strideDims[1]]+weights[wIndex+1]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+2]*input[iIndex+oy*strideDims[1]+2];
} }
iIndex+=strideDims[0]*inputDims[3]; iIndex+=inputDims[3];
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy]+weights[wIndex+4]*input[iIndex+oy+strideDims[0]]+weights[wIndex+5]*input[iIndex+oy+strideDims[0]*2]; output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy*strideDims[1]]+weights[wIndex+4]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+5]*input[iIndex+oy*strideDims[1]+2];
} }
iIndex+=strideDims[0]*inputDims[3]; iIndex+=inputDims[3];
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy]+weights[wIndex+7]*input[iIndex+oy+strideDims[0]]+weights[wIndex+8]*input[iIndex+oy+strideDims[0]*2]; output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy*strideDims[1]]+weights[wIndex+7]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+8]*input[iIndex+oy*strideDims[1]+2];
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment