From ecbcec98b20a44b629dea1a66a561a8a3e56a44d Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Mon, 12 Feb 2024 10:44:20 +0000 Subject: [PATCH] [Fix] 'keepDim' variable --- src/operator/MatMul.cpp | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/operator/MatMul.cpp b/src/operator/MatMul.cpp index fa63608c7..4c7195af2 100644 --- a/src/operator/MatMul.cpp +++ b/src/operator/MatMul.cpp @@ -32,6 +32,11 @@ void Aidge::MatMul_Op::computeOutputDims() { std::vector<std::size_t> dims0 = getInput(0)->dims(); std::vector<std::size_t> dims1 = getInput(1)->dims(); + // keep second-to-last dimension of dims0 + const bool keepDim0 = dims0.size() > 1; + // keep last dimension of dims1 + const bool keepDim1 = dims1.size() > 1; + if (dims0.size() == 1) { dims0.insert(dims0.cbegin(), 1); } @@ -42,10 +47,10 @@ void Aidge::MatMul_Op::computeOutputDims() { if (dims0.size() > dims1.size()) { - dims1.insert(dims1.cbegin(), dims0.begin(), dims0.end() - dims1.size()); + dims1.insert(dims1.cbegin(), dims0.begin(), dims0.end() - dims1.size()); } else if (dims1.size() > dims0.size()) { - dims0.insert(dims0.cbegin(), dims1.begin(), dims1.end() - dims0.size()); + dims0.insert(dims0.cbegin(), dims1.begin(), dims1.end() - dims0.size()); } AIDGE_ASSERT(dims0[dims_size-1] == dims1[dims_size-2], "Incompatible matrices sizes."); @@ -56,11 +61,10 @@ void Aidge::MatMul_Op::computeOutputDims() { outDims[i] = std::max(dims0[i], dims1[i]); } - // keep second-to-last dimension of dims0 - if (dims0.size() > 1) + // use keepDim0 instead of dims0.size() because dims0 has been modified + if (keepDim0) outDims.push_back(dims0[dims_size-2]); - // keep last dimension of dims1 - if (dims1.size() > 1) + if (keepDim1) outDims.push_back(dims1[dims_size-1]); mOutputs[0]->resize(outDims); -- GitLab