Skip to content
Snippets Groups Projects
Commit 24e2dec1 authored by Maxence Naud's avatar Maxence Naud
Browse files

Refactor 'computeOutputDims()' member function of MatMul Operator to add dimensions broadcasting

parent 9743cb8f
No related branches found
No related tags found
No related merge requests found
......@@ -23,40 +23,46 @@ void Aidge::MatMul_Op::computeOutputDims() {
if (!getInput(0) || !getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Missing input. Cannot compute output dimensions for MatMul Operator.");
}
if (!getInput(0)->empty() && !getInput(1)->empty())
if (getInput(0)->empty() && getInput(1)->empty()) {
// both inputs are scalar
mOutputs[0]->resize({});
}
else if (!getInput(0)->empty() && !getInput(1)->empty())
{
const auto dims0 = getInput(0)->dims();
const auto dims1 = getInput(1)->dims();
if (dims0.size() > 2 && dims1.size() > 2)
{
for (std::size_t d0 = dims0.size()-3, d1 = dims1.size()-3;
(d0>0) && (d1>0);
--d0, --d1)
{
if(dims0[d0] != dims1[d1])
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported sizes for MatMul!");
}
std::vector<std::size_t> dims0 = getInput(0)->dims();
std::vector<std::size_t> dims1 = getInput(1)->dims();
if (dims0.size() == 1) {
dims0.insert(dims0.cbegin(), 1);
}
if (dims1.size() == 1) {
dims1.push_back(1);
}
const std::size_t dims_size = std::max(dims0.size(), dims1.size());
if (dims0.size() > dims1.size()) {
dims1.insert(dims1.cbegin(), dims0.begin(), dims0.end() - dims1.size());
}
else if (dims1.size() > dims0.size()) {
dims0.insert(dims0.cbegin(), dims1.begin(), dims1.end() - dims0.size());
}
AIDGE_ASSERT(dims0[dims_size-1] == dims1[dims_size-2], "Incompatible matrices sizes.");
std::size_t secondToLastIdx2 = dims1.size()>1 ? dims1.size() - 2 : dims1.size() - 1;
if(dims0[dims0.size() - 1] != dims1[secondToLastIdx2])
AIDGE_THROW_OR_ABORT(std::runtime_error, "Inner dimension missmatch for MatMul!");
std::vector<std::size_t> outDims;
if(dims0.size() > 2 || dims1.size() > 2)
{
if(dims0.size() > dims1.size())
std::copy_n(dims0.begin(), dims0.size()-2, std::back_inserter(outDims));
else
std::copy_n(dims1.begin(), dims1.size()-2, std::back_inserter(outDims));
std::vector<std::size_t> outDims = std::vector<std::size_t>(dims_size-2, 1);
for (std::size_t i = 0; i < dims_size-2; ++i) {
AIDGE_ASSERT((dims0[i] == dims1[i]) || (dims0[i] == 1) || (dims1[i] == 1), "Bad vector dimension.");
outDims[i] = std::max(dims0[i], dims1[i]);
}
if(dims0.size() > 1)
outDims.push_back(dims0[dims0.size()-2]);
if(dims1.size() > 1)
outDims.push_back(dims1[dims1.size() - 1]);
// keep second-to-last dimension of dims0
if (dims0.size() > 1)
outDims.push_back(dims0[dims_size-2]);
// keep last dimension of dims1
if (dims1.size() > 1)
outDims.push_back(dims1[dims_size-1]);
mOutputs[0]->resize(outDims);
}
}
\ No newline at end of file
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment