diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index ffb01bc6af6eb54ec4690d276434482fc3eb71da..a6904740f2b745b0c22ba1e7dd8933bd13e2e8dd 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -55,23 +55,7 @@ public: } - void computeOutputDims() override final { - if (!getInput(0)->empty() && !getInput(1)->empty()) - { - std::vector<std::size_t> outDims; - for (std::size_t i = 0; i < getInput(0)->nbDims()-1; i++) - { - outDims.push_back(getInput(0)->dims()[i]); - } - size_t secondToLastIdx = getInput(1)->nbDims() > 1 ? getInput(1)->nbDims() - 2 : 0; - for (std::size_t i = 0; i < getInput(1)->nbDims(); i++) - { - if(i != secondToLastIdx) - outDims.push_back(getInput(1)->dims()[i]); - } - mOutputs[0]->resize(outDims); - } - } + void computeOutputDims() override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override { diff --git a/src/operator/MatMul.cpp b/src/operator/MatMul.cpp index 666ed3921ed1190a91935bd9f38303e23963d912..4bb54e83b69b200215a5ae7a274e6b2692656380 100644 --- a/src/operator/MatMul.cpp +++ b/src/operator/MatMul.cpp @@ -9,8 +9,56 @@ * ********************************************************************************/ +#include <algorithm> #include <string> +#include <vector> #include "aidge/operator/MatMul.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" -const std::string Aidge::MatMul_Op::Type = "MatMul"; \ No newline at end of file +const std::string Aidge::MatMul_Op::Type = "MatMul"; + +void Aidge::MatMul_Op::computeOutputDims() { + if (!getInput(0)->empty() && !getInput(1)->empty()) + { + const auto dims0 = getInput(0)->dims(); + const auto dims1 = getInput(1)->dims(); + + if (dims0.size() > 2 && dims1.size() > 2) + { + bool supportedSizes = true; + std::size_t d0 = dims0.size()-3, d1 = dims1.size()-3; + while(d0>0 && d1>0 && supportedSizes) + { + if(dims0[d0] != dims1[d1]) + supportedSizes = false; + + d0--; + d1--; + } + if(!supportedSizes) + AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported sizes for MatMul!"); + } + + std::size_t secondToLastIdx2 = dims1.size()>1 ? dims1.size() - 2 : dims1.size() - 1; + if(dims0[dims0.size() - 1] != dims1[secondToLastIdx2]) + AIDGE_THROW_OR_ABORT(std::runtime_error, "Inner dimension missmatch for MatMul!"); + + std::vector<std::size_t> outDims; + if(dims0.size() > 2 || dims1.size() > 2) + { + if(dims0.size() > dims1.size()) + std::copy_n(dims0.begin(), dims0.size()-2, std::back_inserter(outDims)); + else + std::copy_n(dims1.begin(), dims1.size()-2, std::back_inserter(outDims)); + } + + if(dims0.size() > 1) + outDims.push_back(dims0[dims0.size()-2]); + if(dims1.size() > 1) + outDims.push_back(dims1[dims1.size() - 1]); + + mOutputs[0]->resize(outDims); + } +} \ No newline at end of file