Skip to content
Snippets Groups Projects
Commit 7aec113a authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

fix computOutputDims

parent 59b3c0a9
No related branches found
No related tags found
2 merge requests!105version 0.2.0,!76Matmul rework
......@@ -55,23 +55,7 @@ public:
}
void computeOutputDims() override final {
if (!getInput(0)->empty() && !getInput(1)->empty())
{
std::vector<std::size_t> outDims;
for (std::size_t i = 0; i < getInput(0)->nbDims()-1; i++)
{
outDims.push_back(getInput(0)->dims()[i]);
}
size_t secondToLastIdx = getInput(1)->nbDims() > 1 ? getInput(1)->nbDims() - 2 : 0;
for (std::size_t i = 0; i < getInput(1)->nbDims(); i++)
{
if(i != secondToLastIdx)
outDims.push_back(getInput(1)->dims()[i]);
}
mOutputs[0]->resize(outDims);
}
}
void computeOutputDims() override final;
void setBackend(const std::string& name, DeviceIdx_t device = 0) override {
......
......@@ -9,8 +9,56 @@
*
********************************************************************************/
#include <algorithm>
#include <string>
#include <vector>
#include "aidge/operator/MatMul.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
const std::string Aidge::MatMul_Op::Type = "MatMul";
\ No newline at end of file
const std::string Aidge::MatMul_Op::Type = "MatMul";
void Aidge::MatMul_Op::computeOutputDims() {
if (!getInput(0)->empty() && !getInput(1)->empty())
{
const auto dims0 = getInput(0)->dims();
const auto dims1 = getInput(1)->dims();
if (dims0.size() > 2 && dims1.size() > 2)
{
bool supportedSizes = true;
std::size_t d0 = dims0.size()-3, d1 = dims1.size()-3;
while(d0>0 && d1>0 && supportedSizes)
{
if(dims0[d0] != dims1[d1])
supportedSizes = false;
d0--;
d1--;
}
if(!supportedSizes)
AIDGE_THROW_OR_ABORT(std::runtime_error, "Unsupported sizes for MatMul!");
}
std::size_t secondToLastIdx2 = dims1.size()>1 ? dims1.size() - 2 : dims1.size() - 1;
if(dims0[dims0.size() - 1] != dims1[secondToLastIdx2])
AIDGE_THROW_OR_ABORT(std::runtime_error, "Inner dimension missmatch for MatMul!");
std::vector<std::size_t> outDims;
if(dims0.size() > 2 || dims1.size() > 2)
{
if(dims0.size() > dims1.size())
std::copy_n(dims0.begin(), dims0.size()-2, std::back_inserter(outDims));
else
std::copy_n(dims1.begin(), dims1.size()-2, std::back_inserter(outDims));
}
if(dims0.size() > 1)
outDims.push_back(dims0[dims0.size()-2]);
if(dims1.size() > 1)
outDims.push_back(dims1[dims1.size() - 1]);
mOutputs[0]->resize(outDims);
}
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment