diff --git a/include/aidge/backend/cpu/operator/BatchNormImpl.hpp b/include/aidge/backend/cpu/operator/BatchNormImpl.hpp index 36a100b21edc6cd63a0176c89f2f1e57c10001c7..03dd5d1d04d5263eb84843925a1ce9ee3263423f 100644 --- a/include/aidge/backend/cpu/operator/BatchNormImpl.hpp +++ b/include/aidge/backend/cpu/operator/BatchNormImpl.hpp @@ -29,7 +29,7 @@ using BatchNorm2D_Op = BatchNorm_Op<2>; using BatchNormImpl2D_cpu = OperatorImpl_cpu<BatchNorm_Op<2>, void(float, float, - const std::array<DimSize_t, 4> &, + const std::vector<DimSize_t> &, const void *, const void *, const void *, diff --git a/include/aidge/backend/cpu/operator/BatchNormImpl_kernels.hpp b/include/aidge/backend/cpu/operator/BatchNormImpl_kernels.hpp index ec71e3b8e37e344c551fd643dc7b3957bdddcb67..cf97f7372ac528ef28d0f378beb2650af32bfa30 100644 --- a/include/aidge/backend/cpu/operator/BatchNormImpl_kernels.hpp +++ b/include/aidge/backend/cpu/operator/BatchNormImpl_kernels.hpp @@ -38,7 +38,7 @@ namespace Aidge { * @param output_ Output Tensor. */ template <class I, class P, class O> -void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std::array<DimSize_t, 4> &dims, +void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std::vector<DimSize_t> &dims, const void *input_, const void *scale_, const void *shift_, void *batchMean_, void *batchVar_, void *output_, const bool freeze) { // FIXME: missing convolution attributes as arguments const I *input = static_cast<const I *>(input_); @@ -49,9 +49,8 @@ void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std O *output = static_cast<O *>(output_); const DimSize_t nbBatch = dims[0]; - const DimSize_t nbChannels = dims[1]; - const DimSize_t featureMapSize = dims[2]*dims[3]; - + const DimSize_t nbChannels = (dims.size() > 1) ? dims[1] : 1; + const DimSize_t featureMapSize = (dims.size() > 2) ? 
std::accumulate(dims.begin() + 2, dims.end(), DimSize_t(1), std::multiplies<DimSize_t>()) : 1; if ((freeze == true) || (momentum == 0.0f)) { for (std::size_t batch = 0; batch < nbBatch; ++batch) { diff --git a/src/operator/BatchNormImpl.cpp b/src/operator/BatchNormImpl.cpp index 9f1d986e63f14e6038c80054e5e3bc631ec24224..af59310830a865b496019e7620cfb661721ff39a 100644 --- a/src/operator/BatchNormImpl.cpp +++ b/src/operator/BatchNormImpl.cpp @@ -30,15 +30,13 @@ void Aidge::BatchNormImpl2D_cpu::forward() { AIDGE_ASSERT(op_.getInput(3), "missing input #3 for BatchNorm Operator"); AIDGE_ASSERT(op_.getInput(4), "missing input #4 for BatchNorm Operator"); - AIDGE_ASSERT(op_.getOutput(0)->nbDims() == 4, ""); - // Find the correct kernel type const auto impl = Registrar<BatchNormImpl2D_cpu>::create(getBestMatch(getRequiredSpec())); // Call kernel impl.forward(op_.epsilon(), op_.momentum(), - op_.getInput(0)->template dims<4>(), + op_.getInput(0)->dims(), getCPUPtr(op_.getRawInput(0)), getCPUPtr(op_.getRawInput(1)), getCPUPtr(op_.getRawInput(2)),