Commit d8ae9830 authored by Olivier BICHLER

Allow BatchNorm implementation to handle any number of dimensions

parent 9e272f0c
Part of 2 merge requests: !118 (v0.4.0) and !107 (Partial fix for issue eclipse/aidge/aidge_core#167)
Pipeline #59522 passed
@@ -29,7 +29,7 @@ using BatchNorm2D_Op = BatchNorm_Op<2>;
 using BatchNormImpl2D_cpu = OperatorImpl_cpu<BatchNorm_Op<2>,
     void(float,
          float,
-         const std::array<DimSize_t, 4> &,
+         const std::vector<DimSize_t> &,
          const void *,
          const void *,
          const void *,
...
@@ -38,7 +38,7 @@ namespace Aidge {
  * @param output_ Output Tensor.
  */
 template <class I, class P, class O>
-void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std::array<DimSize_t, 4> &dims,
+void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std::vector<DimSize_t> &dims,
     const void *input_, const void *scale_, const void *shift_, void *batchMean_, void *batchVar_, void *output_, const bool freeze) {
     // FIXME: missing convolution attributes as arguments
     const I *input = static_cast<const I *>(input_);
@@ -49,9 +49,8 @@ void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std
     O *output = static_cast<O *>(output_);
     const DimSize_t nbBatch = dims[0];
-    const DimSize_t nbChannels = dims[1];
-    const DimSize_t featureMapSize = dims[2]*dims[3];
+    const DimSize_t nbChannels = (dims.size() > 1) ? dims[1] : 1;
+    const DimSize_t featureMapSize = (dims.size() > 2) ? std::accumulate(dims.begin() + 2, dims.end(), 1, std::multiplies<DimSize_t>()) : 1;
     if ((freeze == true) || (momentum == 0.0f)) {
         for (std::size_t batch = 0; batch < nbBatch; ++batch) {
...
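For reference, here is a minimal standalone sketch (not part of the commit; the helper name is illustrative) of the flattening trick the new kernel relies on: a tensor of shape (N, C, d2, ..., dk) is treated as (N, C, featureMapSize) with featureMapSize = d2 * ... * dk, so the same loops cover 2-D, 4-D, or 5-D inputs.

    #include <cstddef>
    #include <functional>
    #include <iostream>
    #include <numeric>
    #include <vector>

    using DimSize_t = std::size_t;

    // Fold every dimension after (batch, channels) into one spatial size,
    // mirroring the std::accumulate call introduced by this commit.
    DimSize_t featureMapSize(const std::vector<DimSize_t>& dims) {
        return (dims.size() > 2)
            ? std::accumulate(dims.begin() + 2, dims.end(),
                              DimSize_t(1), std::multiplies<DimSize_t>())
            : 1;
    }

    int main() {
        // 4-D NCHW input: the only shape the old std::array<DimSize_t, 4> allowed
        std::cout << featureMapSize({2, 3, 8, 8}) << '\n';    // 64
        // 2-D (N, C) input: the spatial size degenerates to 1
        std::cout << featureMapSize({2, 3}) << '\n';          // 1
        // 5-D NCDHW input, now handled by the same kernel
        std::cout << featureMapSize({2, 3, 4, 8, 8}) << '\n'; // 256
    }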
@@ -30,15 +30,13 @@ void Aidge::BatchNormImpl2D_cpu::forward() {
     AIDGE_ASSERT(op_.getInput(3), "missing input #3 for BatchNorm Operator");
     AIDGE_ASSERT(op_.getInput(4), "missing input #4 for BatchNorm Operator");
-    AIDGE_ASSERT(op_.getOutput(0)->nbDims() == 4, "");
     // Find the correct kernel type
     const auto impl = Registrar<BatchNormImpl2D_cpu>::create(getBestMatch(getRequiredSpec()));
     // Call kernel
     impl.forward(op_.epsilon(),
                  op_.momentum(),
-                 op_.getInput(0)->template dims<4>(),
+                 op_.getInput(0)->dims(),
                  getCPUPtr(op_.getRawInput(0)),
                  getCPUPtr(op_.getRawInput(1)),
                  getCPUPtr(op_.getRawInput(2)),
...
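Because the kernel only ever consumes (nbBatch, nbChannels, featureMapSize), dropping the nbDims() == 4 assert and passing the full dims() vector is enough to make the operator shape-agnostic. Conceptually, the frozen-statistics branch the diff touches works like the sketch below (an illustration, not the actual Aidge kernel; parameter names are assumptions). The layout of each contiguous run is identical whether the original tensor was (N, C), (N, C, H, W), or (N, C, D, H, W).

    #include <cmath>
    #include <cstddef>
    #include <vector>

    // Frozen BatchNorm forward: normalize with the stored running statistics,
    // iterating over one contiguous run of featureMapSize values per
    // (batch, channel) pair.
    void batchNormForwardFrozen(const std::vector<float>& input,
                                const std::vector<float>& scale,
                                const std::vector<float>& shift,
                                const std::vector<float>& runningMean,
                                const std::vector<float>& runningVar,
                                std::vector<float>& output,
                                std::size_t nbBatch,
                                std::size_t nbChannels,
                                std::size_t featureMapSize,
                                float epsilon) {
        for (std::size_t batch = 0; batch < nbBatch; ++batch) {
            for (std::size_t ch = 0; ch < nbChannels; ++ch) {
                const std::size_t ioIndex = (batch * nbChannels + ch) * featureMapSize;
                // Per-channel affine transform: y = (x - mean) / sqrt(var + eps) * scale + shift
                const float rescale = scale[ch] / std::sqrt(runningVar[ch] + epsilon);
                for (std::size_t i = 0; i < featureMapSize; ++i) {
                    output[ioIndex + i] = (input[ioIndex + i] - runningMean[ch]) * rescale + shift[ch];
                }
            }
        }
    }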