Commit f230cc4b authored by Olivier BICHLER

Allow BatchNorm implementation to handle any number of dimensions

parent a2bcd467
Pipeline #59300 passed
@@ -29,7 +29,7 @@ using BatchNorm2D_Op = BatchNorm_Op<2>;
 using BatchNormImpl2D_cpu = OperatorImpl_cpu<BatchNorm_Op<2>,
     void(float,
          float,
-         const std::array<DimSize_t, 4> &,
+         const std::vector<DimSize_t> &,
          const void *,
          const void *,
          const void *,
...
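The only change in this hunk is the type of the dims parameter in the kernel's function signature: a fixed std::array<DimSize_t, 4> can only describe rank-4 (NCHW) tensors, while std::vector<DimSize_t> carries its rank at runtime. A minimal standalone sketch of the difference (the DimSize_t alias below is an assumption standing in for Aidge's own typedef):

#include <array>
#include <cstddef>
#include <vector>

using DimSize_t = std::size_t; // assumption: stand-in for Aidge's DimSize_t

int main() {
    // Before: the signature pinned the rank to 4 at compile time.
    const std::array<DimSize_t, 4> dims4{2, 3, 8, 8};  // N, C, H, W only

    // After: the same parameter can describe any rank.
    const std::vector<DimSize_t> dims2{2, 3};           // N, C
    const std::vector<DimSize_t> dims5{2, 3, 4, 8, 8};  // N, C, D, H, W
    return static_cast<int>(dims4[0] + dims2[0] + dims5[0]); // silence unused warnings
}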
@@ -38,7 +38,7 @@ namespace Aidge {
  * @param output_ Output Tensor.
  */
 template <class I, class P, class O>
-void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std::array<DimSize_t, 4> &dims,
+void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std::vector<DimSize_t> &dims,
                                         const void *input_, const void *scale_, const void *shift_, void *batchMean_, void *batchVar_, void *output_, const bool freeze) {
     // FIXME: missing convolution attributes as arguments
     const I *input = static_cast<const I *>(input_);
@@ -49,9 +49,8 @@ void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std
     O *output = static_cast<O *>(output_);
 
     const DimSize_t nbBatch = dims[0];
-    const DimSize_t nbChannels = dims[1];
-    const DimSize_t featureMapSize = dims[2]*dims[3];
-
+    const DimSize_t nbChannels = (dims.size() > 1) ? dims[1] : 1;
+    const DimSize_t featureMapSize = (dims.size() > 2) ? std::accumulate(dims.begin() + 2, dims.end(), 1, std::multiplies<DimSize_t>()) : 1;
 
     if ((freeze == true) || (momentum == 0.0f)) {
         for (std::size_t batch = 0; batch < nbBatch; ++batch) {
...
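This is the core of the change: instead of hard-coding featureMapSize as dims[2]*dims[3], the kernel now multiplies together all dimensions after the channel axis, and falls back to 1 when they are absent. To see how the rewritten kernel derives its loop bounds for any rank, here is a standalone sketch of the same computation (layout assumption: dims is NC..., i.e. batch first, channels second, everything after that flattened into one feature-map size):

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

using DimSize_t = std::size_t; // assumption: stand-in for Aidge's DimSize_t

int main() {
    for (const auto &dims : {std::vector<DimSize_t>{2},             // N only
                             std::vector<DimSize_t>{2, 3},          // N, C
                             std::vector<DimSize_t>{2, 3, 8, 8},    // N, C, H, W
                             std::vector<DimSize_t>{2, 3, 4, 8, 8}}) { // N, C, D, H, W
        const DimSize_t nbBatch = dims[0];
        const DimSize_t nbChannels = (dims.size() > 1) ? dims[1] : 1;
        // Product of all trailing dimensions, exactly as in the new kernel.
        const DimSize_t featureMapSize = (dims.size() > 2)
            ? std::accumulate(dims.begin() + 2, dims.end(), DimSize_t(1),
                              std::multiplies<DimSize_t>())
            : DimSize_t(1);
        std::cout << nbBatch << " batches, " << nbChannels << " channels, "
                  << featureMapSize << " elements per feature map\n";
    }
    return 0;
}

One subtlety worth noting: std::accumulate deduces its accumulator type from the initial value, so the committed code's plain 1 seed accumulates in int before the final assignment to DimSize_t. The sketch seeds with DimSize_t(1) so the product stays in the wider type throughout.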
@@ -30,15 +30,13 @@ void Aidge::BatchNormImpl2D_cpu::forward() {
     AIDGE_ASSERT(op_.getInput(3), "missing input #3 for BatchNorm Operator");
     AIDGE_ASSERT(op_.getInput(4), "missing input #4 for BatchNorm Operator");
 
-    AIDGE_ASSERT(op_.getOutput(0)->nbDims() == 4, "");
-
     // Find the correct kernel type
     const auto impl = Registrar<BatchNormImpl2D_cpu>::create(getBestMatch(getRequiredSpec()));
 
     // Call kernel
     impl.forward(op_.epsilon(),
            op_.momentum(),
-           op_.getInput(0)->template dims<4>(),
+           op_.getInput(0)->dims(),
            getCPUPtr(op_.getRawInput(0)),
            getCPUPtr(op_.getRawInput(1)),
            getCPUPtr(op_.getRawInput(2)),
...
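On the caller side, the templated dims<4>() call, which requires a rank-4 tensor, is replaced by dims(), which returns the shape as a std::vector<DimSize_t>; accordingly the nbDims() == 4 assertion is dropped. Flattening the trailing dimensions is safe for BatchNorm because it normalizes per channel, and in a contiguous NC... tensor every (batch, channel) pair occupies one contiguous run of featureMapSize elements regardless of rank. A standalone sketch of the flat indexing this relies on (all names local to the sketch):

#include <cstddef>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

using DimSize_t = std::size_t; // assumption: stand-in for Aidge's DimSize_t

int main() {
    const std::vector<DimSize_t> dims{2, 3, 4, 5}; // N, C, H, W; any trailing dims work
    const DimSize_t featureMapSize = std::accumulate(
        dims.begin() + 2, dims.end(), DimSize_t(1), std::multiplies<DimSize_t>());
    // Offset of element `feature` of channel `ch` in batch `batch`, matching
    // the kernel's flat, rank-independent traversal of contiguous storage.
    const DimSize_t batch = 1, ch = 2, feature = 7;
    const DimSize_t offset = (batch * dims[1] + ch) * featureMapSize + feature;
    std::cout << "flat offset = " << offset << '\n'; // (1*3+2)*20+7 = 107
    return 0;
}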