Skip to content
Snippets Groups Projects
Commit c4098189 authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

chore: format Conv 1D/2D forward kernels

parent f9d0c517
No related branches found
No related tags found
No related merge requests found
...@@ -450,16 +450,15 @@ REGISTRAR(ConvImpl1D_cpu, ...@@ -450,16 +450,15 @@ REGISTRAR(ConvImpl1D_cpu,
* @param output_ Output Tensor. * @param output_ Output Tensor.
*/ */
template <class I, class W, class B, class O> template <class I, class W, class B, class O>
void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, void ConvImpl2D_cpu_forward_kernel(const array<DimSize_t, 2> &strideDims,
const std::array<DimSize_t, 2>& dilationDims, const array<DimSize_t, 2> &dilationDims,
const std::array<DimSize_t, 2>& kernelDims, const array<DimSize_t, 2> &kernelDims,
const std::array<DimSize_t, 4> &inputDims, const array<DimSize_t, 4> &inputDims,
DimSize_t outChannels, DimSize_t outChannels,
const void *input_, const void *input_,
const void *weights_, const void *weights_,
const void *biases_, const void *biases_,
void *output_) void *output_) {
{
// FIXME: missing convolution attributes as arguments // FIXME: missing convolution attributes as arguments
const I *input = static_cast<const I *>(input_); const I *input = static_cast<const I *>(input_);
const W *weights = static_cast<const W *>(weights_); const W *weights = static_cast<const W *>(weights_);
...@@ -467,59 +466,102 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, ...@@ -467,59 +466,102 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
O *output = static_cast<O *>(output_); O *output = static_cast<O *>(output_);
// output H size // output H size
const DimSize_t dilated_kernel_x = dilationDims[0]*(kernelDims[0] - 1) + 1; const DimSize_t dilated_kernel_x =
const std::size_t oxSize = dilationDims[0] * (kernelDims[0] - 1) + 1;
static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[2] - dilated_kernel_x + strideDims[0]) / const std::size_t oxSize = static_cast<std::size_t>(std::floor(
static_cast<float>(strideDims[0]))); static_cast<float>(inputDims[2] - dilated_kernel_x + strideDims[0]) /
static_cast<float>(strideDims[0])));
// output W size // output W size
const DimSize_t dilated_kernel_y = dilationDims[1]*(kernelDims[1] - 1) + 1; const DimSize_t dilated_kernel_y =
const std::size_t oySize = dilationDims[1] * (kernelDims[1] - 1) + 1;
static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[3] - dilated_kernel_y + strideDims[1]) / const std::size_t oySize = static_cast<std::size_t>(std::floor(
static_cast<float>(strideDims[1]))); static_cast<float>(inputDims[3] - dilated_kernel_y + strideDims[1]) /
static_cast<float>(strideDims[1])));
// TODO: kernel computation // TODO: kernel computation
// output (batch, outCh, Xout, Yout) // output (batch, outCh, Xout, Yout)
// input (batch, inCh, Xin, Yin) // input (batch, inCh, Xin, Yin)
// weight (outCh, inCh, kernelX, kernelY) // weight (outCh, inCh, kernelX, kernelY)
// does not take Dilation attribute into account // does not take Dilation attribute into account
const std::size_t outChannels_s = oxSize * oySize; const std::size_t outChannels_s = oxSize * oySize;
if (dilated_kernel_x == 3 && dilated_kernel_y == 3) { if (dilated_kernel_x == 3 && dilated_kernel_y == 3) {
for (std::size_t batch = 0; batch < inputDims[0]; ++batch) { for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
for (std::size_t outCh = 0; outCh < outChannels; ++outCh) { for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
// If bias = nullptr, set B(0) // If bias = nullptr, set B(0)
B biasVal = (biases != nullptr) ? biases[outCh] : B(0); B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
std::fill(output, output+outChannels_s, biasVal); std::fill(output, output + outChannels_s, biasVal);
for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) { for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3]; std::size_t iIndex = (inCh + batch * inputDims[1]) *
const std::size_t wIndex = (inCh + outCh*inputDims[1]) * 9; inputDims[2] * inputDims[3];
if (strideDims[0] == 1 && strideDims[1]==1) { const std::size_t wIndex =
for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex-=inputDims[3]) { (inCh + outCh * inputDims[1]) * 9;
if (strideDims[0] == 1 && strideDims[1] == 1) {
for (std::size_t ox = 0, oIndex = 0; ox < oxSize;
++ox, oIndex += oySize, iIndex -= inputDims[3]) {
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy]+weights[wIndex+1]*input[iIndex+oy+1]+weights[wIndex+2]*input[iIndex+oy+2]; output[oIndex + oy] +=
weights[wIndex + 0] * input[iIndex + oy] +
weights[wIndex + 1] *
input[iIndex + oy + 1] +
weights[wIndex + 2] *
input[iIndex + oy + 2];
} }
iIndex+=inputDims[3]; iIndex += inputDims[3];
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy]+weights[wIndex+4]*input[iIndex+oy+1]+weights[wIndex+5]*input[iIndex+oy+2]; output[oIndex + oy] +=
weights[wIndex + 3] * input[iIndex + oy] +
weights[wIndex + 4] *
input[iIndex + oy + 1] +
weights[wIndex + 5] *
input[iIndex + oy + 2];
} }
iIndex+=inputDims[3]; iIndex += inputDims[3];
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy]+weights[wIndex+7]*input[iIndex+oy+1]+weights[wIndex+8]*input[iIndex+oy+2]; output[oIndex + oy] +=
weights[wIndex + 6] * input[iIndex + oy] +
weights[wIndex + 7] *
input[iIndex + oy + 1] +
weights[wIndex + 8] *
input[iIndex + oy + 2];
} }
} }
} else { } else {
for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=(strideDims[0]-2)*inputDims[3]) { for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox,
oIndex += oySize,
iIndex += (strideDims[0] -
2) * inputDims[3]) {
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy*strideDims[1]]+weights[wIndex+1]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+2]*input[iIndex+oy*strideDims[1]+2]; output[oIndex + oy] +=
weights[wIndex + 0] *
input[iIndex + oy * strideDims[1]] +
weights[wIndex + 1] *
input[iIndex + oy * strideDims[1] +
1] +
weights[wIndex + 2] *
input[iIndex + oy * strideDims[1] + 2];
} }
iIndex+=inputDims[3]; iIndex += inputDims[3];
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy*strideDims[1]]+weights[wIndex+4]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+5]*input[iIndex+oy*strideDims[1]+2]; output[oIndex + oy] +=
weights[wIndex + 3] *
input[iIndex + oy * strideDims[1]] +
weights[wIndex + 4] *
input[iIndex + oy * strideDims[1] +
1] +
weights[wIndex + 5] *
input[iIndex + oy * strideDims[1] + 2];
} }
iIndex+=inputDims[3]; iIndex += inputDims[3];
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy*strideDims[1]]+weights[wIndex+7]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+8]*input[iIndex+oy*strideDims[1]+2]; output[oIndex + oy] +=
weights[wIndex + 6] *
input[iIndex + oy * strideDims[1]] +
weights[wIndex + 7] *
input[iIndex + oy * strideDims[1] +
1] +
weights[wIndex + 8] *
input[iIndex + oy * strideDims[1] + 2];
} }
} }
} }
...@@ -532,18 +574,26 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, ...@@ -532,18 +574,26 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
for (std::size_t outCh = 0; outCh < outChannels; ++outCh) { for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
// If bias = nullptr, set B(0) // If bias = nullptr, set B(0)
B biasVal = (biases != nullptr) ? biases[outCh] : B(0); B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
std::fill(output, output+outChannels_s, biasVal); std::fill(output, output + outChannels_s, biasVal);
for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) { for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3]; std::size_t iIndex = (inCh + batch * inputDims[1]) *
const std::size_t wIndex = (inCh + outCh*inputDims[1]); inputDims[2] * inputDims[3];
const std::size_t wIndex = (inCh + outCh * inputDims[1]);
if (strideDims[0] == 1 && strideDims[1] == 1) { if (strideDims[0] == 1 && strideDims[1] == 1) {
for (std::size_t oIndex = 0; oIndex < oxSize*oySize; ++oIndex, ++iIndex) { for (std::size_t oIndex = 0; oIndex < oxSize * oySize;
++oIndex, ++iIndex) {
output[oIndex] += weights[wIndex] * input[iIndex]; output[oIndex] += weights[wIndex] * input[iIndex];
} }
} else { } else {
for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=inputDims[3]*strideDims[0]) { for (std::size_t ox = 0, oIndex = 0; ox < oxSize;
for (std::size_t oy = 0, iy = 0; oy < oySize; ++oy, iy+=strideDims[1]) { ++ox,
output[oIndex + oy] += weights[wIndex+0]*input[iIndex+iy]; oIndex += oySize,
iIndex +=
inputDims[3] * strideDims[0]) {
for (std::size_t oy = 0, iy = 0; oy < oySize;
++oy, iy += strideDims[1]) {
output[oIndex + oy] +=
weights[wIndex + 0] * input[iIndex + iy];
} }
} }
} }
...@@ -556,21 +606,36 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, ...@@ -556,21 +606,36 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
for (std::size_t outCh = 0; outCh < outChannels; ++outCh) { for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
// If bias = nullptr, set B(0) // If bias = nullptr, set B(0)
B biasVal = (biases != nullptr) ? biases[outCh] : B(0); B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
std::fill(output, output+outChannels_s, biasVal); std::fill(output, output + outChannels_s, biasVal);
for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) { for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
std::size_t iIndex_channel = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3]; std::size_t iIndex_channel =
const std::size_t wIndex = (inCh + outCh*inputDims[1]) * kernelDims[0] * kernelDims[1]; (inCh + batch * inputDims[1]) * inputDims[2] *
inputDims[3];
const std::size_t wIndex = (inCh + outCh * inputDims[1]) *
kernelDims[0] * kernelDims[1];
// loop over each ouput line // loop over each ouput line
for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex_channel+=inputDims[3]*strideDims[0]) { for (std::size_t ox = 0, oIndex = 0; ox < oxSize;
++ox,
oIndex += oySize,
iIndex_channel +=
inputDims[3] * strideDims[0]) {
// loop over associated input line // loop over associated input line
for (std::size_t ky = 0, ix = 0; ky < kernelDims[0]; ++ky, ix += inputDims[3]*dilationDims[0]) { for (std::size_t ky = 0, ix = 0; ky < kernelDims[0];
++ky, ix += inputDims[3] * dilationDims[0]) {
// loop over the entire line // loop over the entire line
for (std::size_t oy = 0, iy = 0; oy < oySize; ++oy, iy+=strideDims[1]) { for (std::size_t oy = 0, iy = 0; oy < oySize;
const std::size_t iIndex = iIndex_channel + ix + iy; ++oy, iy += strideDims[1]) {
// loop over elements assosicated with one output const std::size_t iIndex =
for (std::size_t kx = 0; kx < kernelDims[0]; ++kx) { iIndex_channel + ix + iy;
output[oIndex + oy] += weights[wIndex+kernelDims[0]*ky+kx]*input[iIndex+kx*dilationDims[1]]; // loop over elements assosicated with one
// output
for (std::size_t kx = 0; kx < kernelDims[0];
++kx) {
output[oIndex + oy] +=
weights[wIndex + kernelDims[0] * ky +
kx] *
input[iIndex + kx * dilationDims[1]];
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment