Skip to content
Snippets Groups Projects
Commit 3d34c30d authored by Grégoire Kubler's avatar Grégoire Kubler
Browse files

chore: conv forward 1/2D formatting

parent 2e013a01
No related branches found
No related tags found
No related merge requests found
This commit is part of merge request !142. Comments created here will be created in the context of that merge request.
...@@ -450,16 +450,15 @@ REGISTRAR(ConvImpl1D_cpu, ...@@ -450,16 +450,15 @@ REGISTRAR(ConvImpl1D_cpu,
* @param output_ Output Tensor. * @param output_ Output Tensor.
*/ */
template <class I, class W, class B, class O> template <class I, class W, class B, class O>
void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, void ConvImpl2D_cpu_forward_kernel(const array<DimSize_t, 2> &strideDims,
const std::array<DimSize_t, 2>& dilationDims, const array<DimSize_t, 2> &dilationDims,
const std::array<DimSize_t, 2>& kernelDims, const array<DimSize_t, 2> &kernelDims,
const std::array<DimSize_t, 4> &inputDims, const array<DimSize_t, 4> &inputDims,
DimSize_t outChannels, DimSize_t outChannels,
const void *input_, const void *input_,
const void *weights_, const void *weights_,
const void *biases_, const void *biases_,
void *output_) void *output_) {
{
// FIXME: missing convolution attributes as arguments // FIXME: missing convolution attributes as arguments
const I *input = static_cast<const I *>(input_); const I *input = static_cast<const I *>(input_);
const W *weights = static_cast<const W *>(weights_); const W *weights = static_cast<const W *>(weights_);
...@@ -467,59 +466,102 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, ...@@ -467,59 +466,102 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
O *output = static_cast<O *>(output_); O *output = static_cast<O *>(output_);
// output H size // output H size
const DimSize_t dilated_kernel_x = dilationDims[0]*(kernelDims[0] - 1) + 1; const DimSize_t dilated_kernel_x =
const std::size_t oxSize = dilationDims[0] * (kernelDims[0] - 1) + 1;
static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[2] - dilated_kernel_x + strideDims[0]) / const std::size_t oxSize = static_cast<std::size_t>(std::floor(
static_cast<float>(strideDims[0]))); static_cast<float>(inputDims[2] - dilated_kernel_x + strideDims[0]) /
static_cast<float>(strideDims[0])));
// output W size // output W size
const DimSize_t dilated_kernel_y = dilationDims[1]*(kernelDims[1] - 1) + 1; const DimSize_t dilated_kernel_y =
const std::size_t oySize = dilationDims[1] * (kernelDims[1] - 1) + 1;
static_cast<std::size_t>(std::floor(static_cast<float>(inputDims[3] - dilated_kernel_y + strideDims[1]) / const std::size_t oySize = static_cast<std::size_t>(std::floor(
static_cast<float>(strideDims[1]))); static_cast<float>(inputDims[3] - dilated_kernel_y + strideDims[1]) /
static_cast<float>(strideDims[1])));
// TODO: kernel computation // TODO: kernel computation
// output (batch, outCh, Xout, Yout) // output (batch, outCh, Xout, Yout)
// input (batch, inCh, Xin, Yin) // input (batch, inCh, Xin, Yin)
// weight (outCh, inCh, kernelX, kernelY) // weight (outCh, inCh, kernelX, kernelY)
// does not take Dilation attribute into account // does not take Dilation attribute into account
const std::size_t outChannels_s = oxSize * oySize; const std::size_t outChannels_s = oxSize * oySize;
if (dilated_kernel_x == 3 && dilated_kernel_y == 3) { if (dilated_kernel_x == 3 && dilated_kernel_y == 3) {
for (std::size_t batch = 0; batch < inputDims[0]; ++batch) { for (std::size_t batch = 0; batch < inputDims[0]; ++batch) {
for (std::size_t outCh = 0; outCh < outChannels; ++outCh) { for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
// If bias = nullptr, set B(0) // If bias = nullptr, set B(0)
B biasVal = (biases != nullptr) ? biases[outCh] : B(0); B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
std::fill(output, output+outChannels_s, biasVal); std::fill(output, output + outChannels_s, biasVal);
for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) { for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3]; std::size_t iIndex = (inCh + batch * inputDims[1]) *
const std::size_t wIndex = (inCh + outCh*inputDims[1]) * 9; inputDims[2] * inputDims[3];
if (strideDims[0] == 1 && strideDims[1]==1) { const std::size_t wIndex =
for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex-=inputDims[3]) { (inCh + outCh * inputDims[1]) * 9;
if (strideDims[0] == 1 && strideDims[1] == 1) {
for (std::size_t ox = 0, oIndex = 0; ox < oxSize;
++ox, oIndex += oySize, iIndex -= inputDims[3]) {
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy]+weights[wIndex+1]*input[iIndex+oy+1]+weights[wIndex+2]*input[iIndex+oy+2]; output[oIndex + oy] +=
weights[wIndex + 0] * input[iIndex + oy] +
weights[wIndex + 1] *
input[iIndex + oy + 1] +
weights[wIndex + 2] *
input[iIndex + oy + 2];
} }
iIndex+=inputDims[3]; iIndex += inputDims[3];
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy]+weights[wIndex+4]*input[iIndex+oy+1]+weights[wIndex+5]*input[iIndex+oy+2]; output[oIndex + oy] +=
weights[wIndex + 3] * input[iIndex + oy] +
weights[wIndex + 4] *
input[iIndex + oy + 1] +
weights[wIndex + 5] *
input[iIndex + oy + 2];
} }
iIndex+=inputDims[3]; iIndex += inputDims[3];
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy]+weights[wIndex+7]*input[iIndex+oy+1]+weights[wIndex+8]*input[iIndex+oy+2]; output[oIndex + oy] +=
weights[wIndex + 6] * input[iIndex + oy] +
weights[wIndex + 7] *
input[iIndex + oy + 1] +
weights[wIndex + 8] *
input[iIndex + oy + 2];
} }
} }
} else { } else {
for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=(strideDims[0]-2)*inputDims[3]) { for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox,
oIndex += oySize,
iIndex += (strideDims[0] -
2) * inputDims[3]) {
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+0]*input[iIndex+oy*strideDims[1]]+weights[wIndex+1]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+2]*input[iIndex+oy*strideDims[1]+2]; output[oIndex + oy] +=
weights[wIndex + 0] *
input[iIndex + oy * strideDims[1]] +
weights[wIndex + 1] *
input[iIndex + oy * strideDims[1] +
1] +
weights[wIndex + 2] *
input[iIndex + oy * strideDims[1] + 2];
} }
iIndex+=inputDims[3]; iIndex += inputDims[3];
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+3]*input[iIndex+oy*strideDims[1]]+weights[wIndex+4]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+5]*input[iIndex+oy*strideDims[1]+2]; output[oIndex + oy] +=
weights[wIndex + 3] *
input[iIndex + oy * strideDims[1]] +
weights[wIndex + 4] *
input[iIndex + oy * strideDims[1] +
1] +
weights[wIndex + 5] *
input[iIndex + oy * strideDims[1] + 2];
} }
iIndex+=inputDims[3]; iIndex += inputDims[3];
for (std::size_t oy = 0; oy < oySize; ++oy) { for (std::size_t oy = 0; oy < oySize; ++oy) {
output[oIndex + oy] += weights[wIndex+6]*input[iIndex+oy*strideDims[1]]+weights[wIndex+7]*input[iIndex+oy*strideDims[1]+1]+weights[wIndex+8]*input[iIndex+oy*strideDims[1]+2]; output[oIndex + oy] +=
weights[wIndex + 6] *
input[iIndex + oy * strideDims[1]] +
weights[wIndex + 7] *
input[iIndex + oy * strideDims[1] +
1] +
weights[wIndex + 8] *
input[iIndex + oy * strideDims[1] + 2];
} }
} }
} }
...@@ -532,18 +574,26 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, ...@@ -532,18 +574,26 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
for (std::size_t outCh = 0; outCh < outChannels; ++outCh) { for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
// If bias = nullptr, set B(0) // If bias = nullptr, set B(0)
B biasVal = (biases != nullptr) ? biases[outCh] : B(0); B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
std::fill(output, output+outChannels_s, biasVal); std::fill(output, output + outChannels_s, biasVal);
for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) { for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
std::size_t iIndex = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3]; std::size_t iIndex = (inCh + batch * inputDims[1]) *
const std::size_t wIndex = (inCh + outCh*inputDims[1]); inputDims[2] * inputDims[3];
const std::size_t wIndex = (inCh + outCh * inputDims[1]);
if (strideDims[0] == 1 && strideDims[1] == 1) { if (strideDims[0] == 1 && strideDims[1] == 1) {
for (std::size_t oIndex = 0; oIndex < oxSize*oySize; ++oIndex, ++iIndex) { for (std::size_t oIndex = 0; oIndex < oxSize * oySize;
++oIndex, ++iIndex) {
output[oIndex] += weights[wIndex] * input[iIndex]; output[oIndex] += weights[wIndex] * input[iIndex];
} }
} else { } else {
for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex+=inputDims[3]*strideDims[0]) { for (std::size_t ox = 0, oIndex = 0; ox < oxSize;
for (std::size_t oy = 0, iy = 0; oy < oySize; ++oy, iy+=strideDims[1]) { ++ox,
output[oIndex + oy] += weights[wIndex+0]*input[iIndex+iy]; oIndex += oySize,
iIndex +=
inputDims[3] * strideDims[0]) {
for (std::size_t oy = 0, iy = 0; oy < oySize;
++oy, iy += strideDims[1]) {
output[oIndex + oy] +=
weights[wIndex + 0] * input[iIndex + iy];
} }
} }
} }
...@@ -556,21 +606,36 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims, ...@@ -556,21 +606,36 @@ void ConvImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
for (std::size_t outCh = 0; outCh < outChannels; ++outCh) { for (std::size_t outCh = 0; outCh < outChannels; ++outCh) {
// If bias = nullptr, set B(0) // If bias = nullptr, set B(0)
B biasVal = (biases != nullptr) ? biases[outCh] : B(0); B biasVal = (biases != nullptr) ? biases[outCh] : B(0);
std::fill(output, output+outChannels_s, biasVal); std::fill(output, output + outChannels_s, biasVal);
for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) { for (std::size_t inCh = 0; inCh < inputDims[1]; ++inCh) {
std::size_t iIndex_channel = (inCh + batch*inputDims[1]) * inputDims[2] * inputDims[3]; std::size_t iIndex_channel =
const std::size_t wIndex = (inCh + outCh*inputDims[1]) * kernelDims[0] * kernelDims[1]; (inCh + batch * inputDims[1]) * inputDims[2] *
inputDims[3];
const std::size_t wIndex = (inCh + outCh * inputDims[1]) *
kernelDims[0] * kernelDims[1];
// loop over each ouput line // loop over each ouput line
for (std::size_t ox = 0, oIndex = 0; ox < oxSize; ++ox, oIndex+=oySize, iIndex_channel+=inputDims[3]*strideDims[0]) { for (std::size_t ox = 0, oIndex = 0; ox < oxSize;
++ox,
oIndex += oySize,
iIndex_channel +=
inputDims[3] * strideDims[0]) {
// loop over associated input line // loop over associated input line
for (std::size_t ky = 0, ix = 0; ky < kernelDims[0]; ++ky, ix += inputDims[3]*dilationDims[0]) { for (std::size_t ky = 0, ix = 0; ky < kernelDims[0];
++ky, ix += inputDims[3] * dilationDims[0]) {
// loop over the entire line // loop over the entire line
for (std::size_t oy = 0, iy = 0; oy < oySize; ++oy, iy+=strideDims[1]) { for (std::size_t oy = 0, iy = 0; oy < oySize;
const std::size_t iIndex = iIndex_channel + ix + iy; ++oy, iy += strideDims[1]) {
// loop over elements assosicated with one output const std::size_t iIndex =
for (std::size_t kx = 0; kx < kernelDims[0]; ++kx) { iIndex_channel + ix + iy;
output[oIndex + oy] += weights[wIndex+kernelDims[0]*ky+kx]*input[iIndex+kx*dilationDims[1]]; // loop over elements assosicated with one
// output
for (std::size_t kx = 0; kx < kernelDims[0];
++kx) {
output[oIndex + oy] +=
weights[wIndex + kernelDims[0] * ky +
kx] *
input[iIndex + kx * dilationDims[1]];
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment