Skip to content
Snippets Groups Projects
Commit 714604ff authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'dev' into 'master'

version 0.2.0 fix

See merge request !53
parents 26594930 c8bfdc62
No related branches found
No related tags found
1 merge request!53version 0.2.0 fix
Pipeline #43639 passed
......@@ -59,21 +59,17 @@ void Aidge::MatMulImpl_cpu::forward()
const std::size_t nbDims = dims0.size();
// initialize strides to iterate through data because of broadcasting
std::size_t *stride_post0;
std::size_t *stride_post1;
std::int32_t *stride_step0;
std::int32_t *stride_step1;
std::unique_ptr<std::size_t[]> stride_post0 = std::make_unique<std::size_t[]>(nbDims - 2);
std::unique_ptr<std::size_t[]> stride_post1 = std::make_unique<std::size_t[]>(nbDims - 2);
std::unique_ptr<std::int32_t[]> stride_step0 = std::make_unique<std::int32_t[]>(nbDims - 2);
std::unique_ptr<std::int32_t[]> stride_step1 = std::make_unique<std::int32_t[]>(nbDims - 2);
if (nbDims > 2) {
stride_post0 = new std::size_t[nbDims-2];
stride_post0[nbDims - 3] = 1;
stride_post1 = new std::size_t[nbDims-2];
stride_post1[nbDims - 3] = 1;
for (std::size_t i = nbDims-4; i != static_cast<std::size_t>(-1); --i) {
stride_post0[i] = stride_post0[i+1]*dims0[i+1];
stride_post1[i] = stride_post1[i+1]*dims1[i+1];
}
stride_step0 = new std::int32_t[nbDims-2];
stride_step1 = new std::int32_t[nbDims-2];
for (std::size_t i = 0; i != nbDims-2; ++i) {
stride_step0[i] = (dims0[i] == 1) ? 1 - static_cast<std::int32_t>(stride_post0[i]) : 1;
stride_step1[i] = (dims1[i] == 1) ? 1 - static_cast<std::int32_t>(stride_post1[i]) : 1;
......@@ -111,12 +107,6 @@ void Aidge::MatMulImpl_cpu::forward()
dim = outDims.size() - 1 - keepDim0 - keepDim1;
}
}
if (nbDims > 2) {
delete[] stride_post0;
delete[] stride_post1;
delete[] stride_step0;
delete[] stride_step1;
}
}
// void Aidge::MatMulImpl_cpu::forward()
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment