Skip to content
Snippets Groups Projects
Commit cff1ec49 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'fix_lstm2' into 'dev'

Partial fix for issue eclipse/aidge/aidge_core#167

See merge request !107
parents 9e272f0c d8ae9830
No related branches found
No related tags found
2 merge requests!118v0.4.0,!107Partial fix for issue eclipse/aidge/aidge_core#167
Pipeline #59535 passed
...@@ -29,7 +29,7 @@ using BatchNorm2D_Op = BatchNorm_Op<2>; ...@@ -29,7 +29,7 @@ using BatchNorm2D_Op = BatchNorm_Op<2>;
using BatchNormImpl2D_cpu = OperatorImpl_cpu<BatchNorm_Op<2>, using BatchNormImpl2D_cpu = OperatorImpl_cpu<BatchNorm_Op<2>,
void(float, void(float,
float, float,
const std::array<DimSize_t, 4> &, const std::vector<DimSize_t> &,
const void *, const void *,
const void *, const void *,
const void *, const void *,
......
...@@ -38,7 +38,7 @@ namespace Aidge { ...@@ -38,7 +38,7 @@ namespace Aidge {
* @param output_ Output Tensor. * @param output_ Output Tensor.
*/ */
template <class I, class P, class O> template <class I, class P, class O>
void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std::array<DimSize_t, 4> &dims, void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std::vector<DimSize_t> &dims,
const void *input_, const void *scale_, const void *shift_, void *batchMean_, void *batchVar_, void *output_, const bool freeze) { const void *input_, const void *scale_, const void *shift_, void *batchMean_, void *batchVar_, void *output_, const bool freeze) {
// FIXME: missing convolution attributes as arguments // FIXME: missing convolution attributes as arguments
const I *input = static_cast<const I *>(input_); const I *input = static_cast<const I *>(input_);
...@@ -49,9 +49,8 @@ void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std ...@@ -49,9 +49,8 @@ void BatchNormImpl2D_cpu_forward_kernel(float epsilon, float momentum, const std
O *output = static_cast<O *>(output_); O *output = static_cast<O *>(output_);
const DimSize_t nbBatch = dims[0]; const DimSize_t nbBatch = dims[0];
const DimSize_t nbChannels = dims[1]; const DimSize_t nbChannels = (dims.size() > 1) ? dims[1] : 1;
const DimSize_t featureMapSize = dims[2]*dims[3]; const DimSize_t featureMapSize = (dims.size() > 2) ? std::accumulate(dims.begin() + 2, dims.end(), 1, std::multiplies<DimSize_t>()) : 1;
if ((freeze == true) || (momentum == 0.0f)) { if ((freeze == true) || (momentum == 0.0f)) {
for (std::size_t batch = 0; batch < nbBatch; ++batch) { for (std::size_t batch = 0; batch < nbBatch; ++batch) {
......
...@@ -30,15 +30,13 @@ void Aidge::BatchNormImpl2D_cpu::forward() { ...@@ -30,15 +30,13 @@ void Aidge::BatchNormImpl2D_cpu::forward() {
AIDGE_ASSERT(op_.getInput(3), "missing input #3 for BatchNorm Operator"); AIDGE_ASSERT(op_.getInput(3), "missing input #3 for BatchNorm Operator");
AIDGE_ASSERT(op_.getInput(4), "missing input #4 for BatchNorm Operator"); AIDGE_ASSERT(op_.getInput(4), "missing input #4 for BatchNorm Operator");
AIDGE_ASSERT(op_.getOutput(0)->nbDims() == 4, "");
// Find the correct kernel type // Find the correct kernel type
const auto impl = Registrar<BatchNormImpl2D_cpu>::create(getBestMatch(getRequiredSpec())); const auto impl = Registrar<BatchNormImpl2D_cpu>::create(getBestMatch(getRequiredSpec()));
// Call kernel // Call kernel
impl.forward(op_.epsilon(), impl.forward(op_.epsilon(),
op_.momentum(), op_.momentum(),
op_.getInput(0)->template dims<4>(), op_.getInput(0)->dims(),
getCPUPtr(op_.getRawInput(0)), getCPUPtr(op_.getRawInput(0)),
getCPUPtr(op_.getRawInput(1)), getCPUPtr(op_.getRawInput(1)),
getCPUPtr(op_.getRawInput(2)), getCPUPtr(op_.getRawInput(2)),
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment