Skip to content
Snippets Groups Projects
Commit dd7b1b22 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed batchSize for unidimensional inputs

parent 7bef21e1
No related branches found
No related tags found
2 merge requests!50version 0.2.0,!37Support for recurrent networks
Pipeline #40593 passed
...@@ -57,9 +57,10 @@ void Aidge::FCImpl_cpu::forward() ...@@ -57,9 +57,10 @@ void Aidge::FCImpl_cpu::forward()
const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))); const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0)));
// Call kernel // Call kernel
const auto batchSize = (input0.dims().size() > 1) ? input0.dims()[0] : 1;
kernelFunc(dynamic_cast<const FC_Op&>(mOp).getStaticAttributes(), kernelFunc(dynamic_cast<const FC_Op&>(mOp).getStaticAttributes(),
input0.dims()[0], batchSize,
input0.size() / input0.dims()[0], input0.size() / batchSize,
input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(), input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(),
getCPUPtr(mOp.getRawOutput(0))); getCPUPtr(mOp.getRawOutput(0)));
} }
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment