Commit 1096664a authored by Houssem ROUIS, committed by Maxence Naud

add dilations and ceilMode to AvgPooling

parent 39117d5a
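For context: with dilation d, the effective kernel extent along one axis becomes (k - 1) * d + 1, and ceil mode rounds the output-size division up instead of down. A minimal standalone check of the formula this commit introduces (not part of the diff; values chosen only for illustration):

#include <cmath>
#include <cstddef>
#include <iostream>

// Output size along one spatial axis, mirroring the commit's
// compute_output_size lambda.
std::size_t outSize(std::size_t in, std::size_t k, std::size_t s,
                    std::size_t d, bool ceilMode) {
    const std::size_t effK = (k - 1) * d + 1;  // effective kernel extent
    const float r = static_cast<float>(in - effK + s) / static_cast<float>(s);
    return static_cast<std::size_t>(ceilMode ? std::ceil(r) : std::floor(r));
}

int main() {
    // input 8, kernel 3, stride 2, dilation 2 -> effective kernel 5
    std::cout << outSize(8, 3, 2, 2, false) << '\n';  // floor((8-5+2)/2) = 2
    std::cout << outSize(8, 3, 2, 2, true)  << '\n';  // ceil(2.5) = 3
}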
@@ -28,8 +28,10 @@ namespace Aidge {
 using AvgPooling2D_Op = AvgPooling_Op<2>;
 using AvgPoolingImpl2D_cpu = OperatorImpl_cpu<AvgPooling_Op<2>,
     void(const std::array<DimSize_t, 2>&,
+         const std::array<DimSize_t, 2>&,
          const std::array<DimSize_t, 2>&,
          const std::array<DimSize_t, 4>&,
+         bool,
          const void *,
          void *)>;
...
@@ -35,66 +35,54 @@ namespace Aidge {
 template <class I, class O>
 void AvgPoolingImpl2D_cpu_forward_kernel(const std::array<DimSize_t, 2>& strideDims,
                                          const std::array<DimSize_t, 2>& kernelDims,
+                                         const std::array<DimSize_t, 2>& dilations,
                                          const std::array<DimSize_t, 4> &dims,
+                                         bool ceilMode,
                                          const void *input_,
                                          void *output_) {
-    // FIXME: missing convolution attributes as arguments
     const I *input = static_cast<const I *>(input_);
     O *output = static_cast<O *>(output_);
 
-    // output H size
-    const std::size_t oxSize =
-        static_cast<std::size_t>(std::floor(static_cast<float>(dims[2] - kernelDims[0] + strideDims[0]) /
-                                            static_cast<float>(strideDims[0])));
-    // output W size
-    const std::size_t oySize =
-        static_cast<std::size_t>(std::floor(static_cast<float>(dims[3] - kernelDims[1] + strideDims[1]) /
-                                            static_cast<float>(strideDims[1])));
+    // Calculate output dimensions based on ceilMode and dilations
+    auto compute_output_size = [&](DimSize_t inputDim, DimSize_t kernelDim, DimSize_t stride, DimSize_t dilation) {
+        DimSize_t effectiveKernelDim = (kernelDim - 1) * dilation + 1;
+        float result = static_cast<float>(inputDim - effectiveKernelDim + stride) / static_cast<float>(stride);
+        return ceilMode ? static_cast<DimSize_t>(std::ceil(result)) : static_cast<DimSize_t>(std::floor(result));
+    };
+
+    const std::size_t oxSize = compute_output_size(dims[2], kernelDims[0], strideDims[0], dilations[0]);
+    const std::size_t oySize = compute_output_size(dims[3], kernelDims[1], strideDims[1], dilations[1]);
 
-    // TODO: kernel computation
-    // output (batch, outCh, Xout, Yout)
-    // input  (batch, ch, Xin, Yin)
-    // weight (outCh, ch, kernelX, kernelY)
-    // does not take Dilation attribute into account
     using signedsize = std::make_signed<std::size_t>::type;
     for (std::size_t batch = 0; batch < dims[0]; ++batch) {
         for (std::size_t ch = 0; ch < dims[1]; ++ch) {
-            const std::size_t oIndex = (ch + batch*dims[1]) * oxSize * oySize;
-            const std::size_t iIndex = (ch + batch*dims[1]) * dims[2] * dims[3];
-            std::fill(output + oIndex, output+(oIndex+oxSize*oySize), 0);
+            const std::size_t oIndex = (ch + batch * dims[1]) * oxSize * oySize;
+            const std::size_t iIndex = (ch + batch * dims[1]) * dims[2] * dims[3];
+            std::fill(output + oIndex, output + (oIndex + oxSize * oySize), 0);
             for (std::size_t ox = 0; ox < oxSize; ++ox) {
-                const signedsize difx = static_cast<signedsize>(- ox * strideDims[0]);
-                const std::size_t sxMin = static_cast<std::size_t>(std::max(difx, signedsize(0)));
-                const std::size_t sxMax = (static_cast<signedsize>(dims[2]) + difx) < 0 ? 0 : ((dims[2] + difx) > kernelDims[0] ? kernelDims[0] : dims[2] + difx);
+                const signedsize startx = static_cast<signedsize>(ox * strideDims[0]) - (dilations[0] - 1);
+                const std::size_t sxMin = static_cast<std::size_t>(std::max(startx, signedsize(0)));
+                const std::size_t sxMax = std::min(dims[2], static_cast<std::size_t>(startx + kernelDims[0] * dilations[0]));
                 for (std::size_t oy = 0; oy < oySize; ++oy) {
-                    const signedsize dify = static_cast<signedsize>(- oy * strideDims[1]);
-                    const std::size_t syMin = static_cast<std::size_t>(std::max(dify, signedsize(0)));
-                    const std::size_t syMax = (static_cast<signedsize>(dims[3]) + dify) < 0 ? 0 : ((dims[3] + dify) > kernelDims[1] ? kernelDims[1] : dims[3] + dify);
-                    const std::size_t oIndexFull = oIndex + ox*oySize + oy;
-                    const std::size_t ix = ox * strideDims[0];
-                    const std::size_t iy = oy * strideDims[1];
-
-                    if (sxMin == 0 && syMin == 0 && sxMax == 3 && syMax == 3) {
-                        output[oIndexFull] += static_cast<O>(
-                                input[iIndex + (ix+0)*dims[3] + (iy+0)] +
-                                input[iIndex + (ix+0)*dims[3] + (iy+1)] +
-                                input[iIndex + (ix+0)*dims[3] + (iy+2)] +
-                                input[iIndex + (ix+1)*dims[3] + (iy+0)] +
-                                input[iIndex + (ix+1)*dims[3] + (iy+1)] +
-                                input[iIndex + (ix+1)*dims[3] + (iy+2)] +
-                                input[iIndex + (ix+2)*dims[3] + (iy+0)] +
-                                input[iIndex + (ix+2)*dims[3] + (iy+1)] +
-                                input[iIndex + (ix+2)*dims[3] + (iy+2)]) / O(9);
-                    } else {
-                        for (std::size_t sx = sxMin; sx < sxMax; ++sx) {
-                            for (std::size_t sy = syMin; sy < syMax; ++sy) {
-                                output[oIndexFull] += input[iIndex + (ix+sx)*dims[3] + (iy+sy)];
-                            }
-                        }
-                        // padding not used
-                        output[oIndexFull] /= (sxMax - sxMin) * (syMax - syMin);
-                    }
+                    const signedsize starty = static_cast<signedsize>(oy * strideDims[1]) - (dilations[1] - 1);
+                    const std::size_t syMin = static_cast<std::size_t>(std::max(starty, signedsize(0)));
+                    const std::size_t syMax = std::min(dims[3], static_cast<std::size_t>(starty + kernelDims[1] * dilations[1]));
+
+                    const std::size_t oIndexFull = oIndex + ox * oySize + oy;
+                    O sum = static_cast<O>(0);
+                    std::size_t count = 0;
+
+                    for (std::size_t sx = sxMin; sx < sxMax; sx += dilations[0]) {
+                        for (std::size_t sy = syMin; sy < syMax; sy += dilations[1]) {
+                            sum += static_cast<O>(input[iIndex + sx * dims[3] + sy]);
+                            ++count;
+                        }
+                    }
+
+                    output[oIndexFull] = sum / static_cast<O>(count);
                 }
             }
         }
     }
 }
...
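To make the windowing concrete, here is a simplified standalone sketch (not from the repository) of dilated average pooling on a single 4x4 channel: the window samples every d-th element and divides by the number of sampled elements, as the new kernel does with its sum/count accumulation. Boundary clamping, ceil mode, and the kernel's (dilation - 1) start offset are left out for brevity.

#include <array>
#include <cstddef>
#include <iostream>

int main() {
    constexpr std::size_t H = 4, W = 4;        // input height/width
    constexpr std::size_t k = 2, s = 1, d = 2; // kernel, stride, dilation
    std::array<float, H * W> in{};
    for (std::size_t i = 0; i < H * W; ++i) in[i] = static_cast<float>(i);

    constexpr std::size_t effK = (k - 1) * d + 1;  // 3
    constexpr std::size_t oH = (H - effK + s) / s; // 2 (floor mode)
    constexpr std::size_t oW = (W - effK + s) / s; // 2

    for (std::size_t ox = 0; ox < oH; ++ox) {
        for (std::size_t oy = 0; oy < oW; ++oy) {
            float sum = 0.f;
            std::size_t count = 0;
            // sample every d-th element inside the effective window
            for (std::size_t sx = ox * s; sx < ox * s + effK; sx += d)
                for (std::size_t sy = oy * s; sy < oy * s + effK; sy += d) {
                    sum += in[sx * W + sy];
                    ++count;
                }
            std::cout << "out(" << ox << ',' << oy << ") = "
                      << sum / static_cast<float>(count) << '\n';
        }
    }
}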
@@ -32,7 +32,9 @@ void Aidge::AvgPoolingImpl2D_cpu::forward() {
     // Call kernel
     impl.forward(op_.strideDims(),
                  op_.kernelDims(),
+                 op_.dilations(),
                  op_.getInput(0)->template dims<4>(),
+                 op_.ceilMode(),
                  getCPUPtr(op_.getInput(0)),
                  getCPUPtr(op_.getOutput(0)));
 }
...
@@ -160,28 +160,6 @@ TEST_CASE("[cpu/operator] And(forward)", "[And][CPU]") {
     }
     SECTION("Broadcasting") {
-<<<<<<< HEAD
-        std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
-            {
-                {
-                    {{10, 20},{22, 23},{20, 20}},
-                    {{10, 15},{10, 29},{20, 20}},
-                    {{26, 25},{33, 20},{10, 20}}
-                }
-            }
-        });
-        std::shared_ptr<Tensor> input_2 = std::make_shared<Tensor>(Array1D<int,2> {{10, 20}});
-        std::shared_ptr<Tensor> expectedOutput = std::make_shared<Tensor>(Array4D<int,1,3,3,2> {
-            {
-                {
-                    {{ 1, 1},{ 0, 0},{ 0, 1}},
-                    {{ 1, 0},{ 1, 0},{ 0, 1}},
-                    {{ 0, 0},{ 0, 1},{ 1, 1}}
-                }
-            }
-        });
-=======
         std::shared_ptr<Tensor> input_1 = std::make_shared<Tensor>(Array4D<float, 1, 2, 2, 2>{
             {
                 {{{1, 0}, {1, 0}},
...
@@ -193,7 +171,6 @@ TEST_CASE("[cpu/operator] And(forward)", "[And][CPU]") {
                 {{{1, 0}, {1, 0}},
                  {{1, 0}, {0, 0}}}}
         });
->>>>>>> fix and kernel and unit tests
         std::shared_ptr<Node> myAnd = And();
         auto op = std::static_pointer_cast<OperatorTensor>(myAnd->getOperator());
...