Skip to content
Snippets Groups Projects
Commit 6424edc9 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'main' into dev

parents 1bd36647 b9d7bf31
No related branches found
No related tags found
1 merge request!105version 0.2.0
Pipeline #38205 passed
......@@ -9,38 +9,48 @@
*
********************************************************************************/
#include <cstddef>
#include <cstddef> // std::size_t
#include <cstdint> // std::int64_t
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
#include "aidge/operator/Reshape.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
const std::string Aidge::Reshape_Op::Type = "Reshape";
void Aidge::Reshape_Op::computeOutputDims() {
// check inputs have been associated
// check input has been associated
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
}
DimSize_t nbOutDims = this->template getAttr<ReshapeAttr::Shape>().size();
std::vector<DimSize_t> outDims;
// variables to handle a negative dimension
bool foundNegativeDimension = false;
std::size_t outSize = 1;
for(std::size_t i=0; i<nbOutDims; ++i)
DimIdx_t negativeIndex = 0;
for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
{
int dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 1)
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "bad dimension value");
std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 0) {
if (foundNegativeDimension) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
}
foundNegativeDimension = true;
dimSize = 1;
negativeIndex = static_cast<DimIdx_t>(i);
}
outDims.push_back(dimSize);
outSize *= dimSize;
outDims.push_back(static_cast<DimSize_t>(dimSize));
outSize *= static_cast<DimSize_t>(dimSize);
}
if (getInput(0)->size() != outSize){
AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input");
if (foundNegativeDimension) {
outDims[negativeIndex] = (getInput(0) -> size()) / outSize;
}
mOutputs[0]->resize(outDims);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment