Skip to content
Snippets Groups Projects
Commit 0a88775b authored by Maxence Naud's avatar Maxence Naud
Browse files

[Fix] 'keepDim' variable

parent 24e2dec1
No related branches found
No related tags found
No related merge requests found
......@@ -32,6 +32,11 @@ void Aidge::MatMul_Op::computeOutputDims() {
std::vector<std::size_t> dims0 = getInput(0)->dims();
std::vector<std::size_t> dims1 = getInput(1)->dims();
// keep second-to-last dimension of dims0
const bool keepDim0 = dims0.size() > 1;
// keep last dimension of dims1
const bool keepDim1 = dims1.size() > 1;
if (dims0.size() == 1) {
dims0.insert(dims0.cbegin(), 1);
}
......@@ -42,10 +47,10 @@ void Aidge::MatMul_Op::computeOutputDims() {
if (dims0.size() > dims1.size()) {
dims1.insert(dims1.cbegin(), dims0.begin(), dims0.end() - dims1.size());
dims1.insert(dims1.cbegin(), dims0.begin(), dims0.end() - dims1.size());
}
else if (dims1.size() > dims0.size()) {
dims0.insert(dims0.cbegin(), dims1.begin(), dims1.end() - dims0.size());
dims0.insert(dims0.cbegin(), dims1.begin(), dims1.end() - dims0.size());
}
AIDGE_ASSERT(dims0[dims_size-1] == dims1[dims_size-2], "Incompatible matrices sizes.");
......@@ -56,11 +61,10 @@ void Aidge::MatMul_Op::computeOutputDims() {
outDims[i] = std::max(dims0[i], dims1[i]);
}
// keep second-to-last dimension of dims0
if (dims0.size() > 1)
// use keepDim0 instead of dims0.size() because dims0 has been modified
if (keepDim0)
outDims.push_back(dims0[dims_size-2]);
// keep last dimension of dims1
if (dims1.size() > 1)
if (keepDim1)
outDims.push_back(dims1[dims_size-1]);
mOutputs[0]->resize(outDims);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment