From 7aec113a6869024a0e2e46835ddc30703b83d7d9 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Tue, 30 Jan 2024 16:23:37 +0100 Subject: [PATCH] fix computOutputDims --- include/aidge/operator/MatMul.hpp | 18 +---------- src/operator/MatMul.cpp | 50 ++++++++++++++++++++++++++++++- 2 files changed, 50 insertions(+), 18 deletions(-) diff --git a/include/aidge/operator/MatMul.hpp b/include/aidge/operator/MatMul.hpp index 5f06e8c2a..a6904740f 100644 --- a/include/aidge/operator/MatMul.hpp +++ b/include/aidge/operator/MatMul.hpp @@ -55,23 +55,7 @@ public: } - void computeOutputDims() override final { - if (!getInput(0)->empty() && !getInput(1)->empty()) - { - std::vector<std::size_t> outDims; - for (std::size_t i = 0; i < getInput(0)->nbDims()-1; i++) - { - outDims.push_back(getInput(0)->dims()[i]); - } - size_t secondToLastIdx = getInput(1)->nbDims() > 1 ? getInput(1)->nbDims() - 2 : 0; - for (std::size_t i = 0; i < getInput(1)->nbDims(); i++) - { - if(i != secondToLastIdx) - outDims.push_back(getInput(1)->dims()[i]); - } - mOutputs[0]->resize(outDims); - } - } + void computeOutputDims() override final; void setBackend(const std::string& name, DeviceIdx_t device = 0) override { diff --git a/src/operator/MatMul.cpp b/src/operator/MatMul.cpp index 666ed3921..4bb54e83b 100644 --- a/src/operator/MatMul.cpp +++ b/src/operator/MatMul.cpp @@ -9,8 +9,56 @@ * ********************************************************************************/ +#include <algorithm> #include <string> +#include <vector> #include "aidge/operator/MatMul.hpp" +#include "aidge/utils/Types.h" +#include "aidge/utils/ErrorHandling.hpp" -const std::string Aidge::MatMul_Op::Type = "MatMul"; \ No newline at end of file +const std::string Aidge::MatMul_Op::Type = "MatMul"; + +void Aidge::MatMul_Op::computeOutputDims() { + if (!getInput(0)->empty() && !getInput(1)->empty()) + { + const auto dims0 = getInput(0)->dims(); + const auto dims1 = getInput(1)->dims(); + + if (dims0.size() > 2 && dims1.size() > 2) + { + bool supportedSizes = true; + std::size_t d0 = dims0.size()-3, d1 = dims1.size()-3; + while(d0>0 && d1>0 && supportedSizes) + { + if(dims0[d0] != dims1[d1]) + supportedSizes = false; + + d0--; + d1--; + } + if(!supportedSizes) + AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported sizes for MatMul!"); + } + + std::size_t secondToLastIdx2 = dims1.size()>1 ? dims1.size() - 2 : dims1.size() - 1; + if(dims0[dims0.size() - 1] != dims1[secondToLastIdx2]) + AIDGE_THROW_OR_ABORT(std::runtime_error, "Inner dimension missmatch for MatMul!"); + + std::vector<std::size_t> outDims; + if(dims0.size() > 2 || dims1.size() > 2) + { + if(dims0.size() > dims1.size()) + std::copy_n(dims0.begin(), dims0.size()-2, std::back_inserter(outDims)); + else + std::copy_n(dims1.begin(), dims1.size()-2, std::back_inserter(outDims)); + } + + if(dims0.size() > 1) + outDims.push_back(dims0[dims0.size()-2]); + if(dims1.size() > 1) + outDims.push_back(dims1[dims1.size() - 1]); + + mOutputs[0]->resize(outDims); + } +} \ No newline at end of file -- GitLab