diff --git a/src/operator/FCImpl.cpp b/src/operator/FCImpl.cpp index bc4a7a7cab91049c623e9a9e95ee63367da00722..995245907c8c87b0367c7edfa4493bd6b7faf660 100644 --- a/src/operator/FCImpl.cpp +++ b/src/operator/FCImpl.cpp @@ -57,9 +57,10 @@ void Aidge::FCImpl_cpu::forward() const auto& input2 = std::static_pointer_cast<Tensor>(mOp.getRawInput(2))->refCastFrom(input2Fallback, *std::static_pointer_cast<Tensor>(mOp.getRawOutput(0))); // Call kernel + const auto batchSize = (input0.dims().size() > 1) ? input0.dims()[0] : 1; kernelFunc(dynamic_cast<const FC_Op&>(mOp).getStaticAttributes(), - input0.dims()[0], - input0.size() / input0.dims()[0], + batchSize, + input0.size() / batchSize, input0.getImpl()->rawPtr(), input1.getImpl()->rawPtr(), input2.getImpl()->rawPtr(), getCPUPtr(mOp.getRawOutput(0))); }