Skip to content
Snippets Groups Projects
Commit 8e7e8c69 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Basic MatMul optimization

parent 5835bf04
No related branches found
No related tags found
2 merge requests!93Release v0.3.0,!79Refactor OperatorImpl for backend/export
Pipeline #54704 failed
...@@ -18,19 +18,19 @@ namespace Aidge { ...@@ -18,19 +18,19 @@ namespace Aidge {
template <class I, class O> template <class I, class O>
void MatMulImpl_cpu_forward_kernel(const std::size_t n, const std::size_t k, const std::size_t m, void MatMulImpl_cpu_forward_kernel(const std::size_t n, const std::size_t k, const std::size_t m,
const void* input1_, const void* input2_, void* output_) { const void* input1_, const void* input2_, void* __restrict__ output_) {
// FIXME: missing MatMul parameters as arguments // FIXME: missing MatMul parameters as arguments
const I* input1 = static_cast<const I*>(input1_); const I* input1 = static_cast<const I*>(input1_);
const I* input2 = static_cast<const I*>(input2_); const I* input2 = static_cast<const I*>(input2_);
O* output = static_cast<O*>(output_); O* __restrict__ output = static_cast<O* __restrict__>(output_);
std::memset(output, O(0), n * m * sizeof(O));
for (std::size_t i = 0; i < n; ++i) { for (std::size_t i = 0; i < n; ++i) {
for (std::size_t j = 0; j < m; ++j) { for (std::size_t l = 0; l < k; ++l) {
O sum = O(0); for (std::size_t j = 0; j < m; ++j) {
for (std::size_t l = 0; l < k; ++l) { output[i*m + j] += static_cast<O>(input1[i*k + l] * input2[l*m + j]);
sum += static_cast<O>(input1[i*k + l] * input2[l*m + j]);
} }
output[i*m + j] = sum;
} }
} }
} }
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment