Skip to content
Snippets Groups Projects

Refactor OperatorImpl for backend/export

Merged Olivier BICHLER requested to merge backend_export into dev
2 unresolved threads
1 file
+ 7
- 7
Compare changes
  • Side-by-side
  • Inline
@@ -18,19 +18,19 @@ namespace Aidge {
@@ -18,19 +18,19 @@ namespace Aidge {
/**
 * @brief CPU forward kernel for a dense matrix product.
 *
 * Computes output = input1 x input2, where the indexing used below treats
 * input1 as an n-by-k matrix, input2 as a k-by-m matrix and output as an
 * n-by-m matrix, all stored contiguously in row-major order.
 *
 * @tparam I Element type the input buffers are reinterpreted as.
 * @tparam O Element type the output buffer is reinterpreted as.
 * @param n       Number of rows of input1 and of output.
 * @param k       Number of columns of input1 / rows of input2.
 * @param m       Number of columns of input2 and of output.
 * @param input1_ Type-erased pointer to the first operand (read as const I*).
 * @param input2_ Type-erased pointer to the second operand (read as const I*).
 * @param output_ Type-erased pointer to the result (written as O*). It is
 *                declared __restrict__, so the caller must guarantee that it
 *                does not overlap either input buffer.
 *
 * Every output element is written exactly once, so the output buffer does not
 * need to be zero-initialised beforehand. When k is 0 every element is set
 * to O(0).
 */
template <class I, class O>
void MatMulImpl_cpu_forward_kernel(const std::size_t n, const std::size_t k, const std::size_t m,
                                   const void* input1_, const void* input2_, void* __restrict__ output_) {
    // FIXME: missing MatMul parameters as arguments
    const I* input1 = static_cast<const I*>(input1_);
    const I* input2 = static_cast<const I*>(input2_);
    O* __restrict__ output = static_cast<O*>(output_);

    for (std::size_t i = 0; i < n; ++i) {
        for (std::size_t j = 0; j < m; ++j) {
            // Accumulate in a local so the output buffer is stored to only
            // once per element instead of being read and written k times.
            O sum = O(0);
            for (std::size_t l = 0; l < k; ++l) {
                // The product is formed from the I operands first and only
                // then converted to O.
                sum += static_cast<O>(input1[i*k + l] * input2[l*m + j]);
            }
            output[i*m + j] = sum;
        }
    }
}
Loading