Skip to content
Snippets Groups Projects
Commit ad799c3a authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

fix computOutputDims

parent db487d89
No related branches found
No related tags found
No related merge requests found
......@@ -55,23 +55,7 @@ public:
}
void computeOutputDims() override final {
if (!getInput(0)->empty() && !getInput(1)->empty())
{
std::vector<std::size_t> outDims;
for (std::size_t i = 0; i < getInput(0)->nbDims()-1; i++)
{
outDims.push_back(getInput(0)->dims()[i]);
}
size_t secondToLastIdx = getInput(1)->nbDims() > 1 ? getInput(1)->nbDims() - 2 : 0;
for (std::size_t i = 0; i < getInput(1)->nbDims(); i++)
{
if(i != secondToLastIdx)
outDims.push_back(getInput(1)->dims()[i]);
}
mOutputs[0]->resize(outDims);
}
}
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
......
......@@ -9,8 +9,56 @@
*
********************************************************************************/
#include <algorithm>
#include <string>
#include <vector>
#include "aidge/operator/MatMul.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
const std::string Aidge::MatMul_Op::Type = "MatMul";
\ No newline at end of file
const std::string Aidge::MatMul_Op::Type = "MatMul";
void Aidge::MatMul_Op::computeOutputDims() {
if (!getInput(0)->empty() && !getInput(1)->empty())
{
const auto dims0 = getInput(0)->dims();
const auto dims1 = getInput(1)->dims();
if (dims0.size() > 2 && dims1.size() > 2)
{
bool supportedSizes = true;
std::size_t d0 = dims0.size()-3, d1 = dims1.size()-3;
while(d0>0 && d1>0 && supportedSizes)
{
if(dims0[d0] != dims1[d1])
supportedSizes = false;
d0--;
d1--;
}
if(!supportedSizes)
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported sizes for MatMul!");
}
std::size_t secondToLastIdx2 = dims1.size()>1 ? dims1.size() - 2 : dims1.size() - 1;
if(dims0[dims0.size() - 1] != dims1[secondToLastIdx2])
AIDGE_THROW_OR_ABORT(std::runtime_error, "Inner dimension missmatch for MatMul!");
std::vector<std::size_t> outDims;
if(dims0.size() > 2 || dims1.size() > 2)
{
if(dims0.size() > dims1.size())
std::copy_n(dims0.begin(), dims0.size()-2, std::back_inserter(outDims));
else
std::copy_n(dims1.begin(), dims1.size()-2, std::back_inserter(outDims));
}
if(dims0.size() > 1)
outDims.push_back(dims0[dims0.size()-2]);
if(dims1.size() > 1)
outDims.push_back(dims1[dims1.size() - 1]);
mOutputs[0]->resize(outDims);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment