Skip to content
Snippets Groups Projects
Commit b9d7bf31 authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge branch 'fix' into 'main'

Negative values for shape parameter in reshape

See merge request eclipse/aidge/aidge_core!75
parents a5fb8430 8cd7106f
No related branches found
No related tags found
1 merge request!75Negative values for shape parameter in reshape
Pipeline #37664 passed
......@@ -9,38 +9,48 @@
*
********************************************************************************/
#include <cstddef>
#include <cstddef> // std::size_t
#include <cstdint> // std::int64_t
#include <stdexcept> // std::runtime_error
#include <string>
#include <vector>
#include "aidge/operator/Reshape.hpp"
#include "aidge/utils/Types.h"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/utils/Types.h"
const std::string Aidge::Reshape_Op::Type = "Reshape";
void Aidge::Reshape_Op::computeOutputDims() {
// check inputs have been associated
// check input has been associated
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected");
}
DimSize_t nbOutDims = this->template getAttr<ReshapeAttr::Shape>().size();
std::vector<DimSize_t> outDims;
// variables to handle a negative dimension
bool foundNegativeDimension = false;
std::size_t outSize = 1;
for(std::size_t i=0; i<nbOutDims; ++i)
DimIdx_t negativeIndex = 0;
for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i)
{
int dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 1)
{
AIDGE_THROW_OR_ABORT(std::runtime_error, "bad dimension value");
std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i];
if (dimSize < 0) {
if (foundNegativeDimension) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator.");
}
foundNegativeDimension = true;
dimSize = 1;
negativeIndex = static_cast<DimIdx_t>(i);
}
outDims.push_back(dimSize);
outSize *= dimSize;
outDims.push_back(static_cast<DimSize_t>(dimSize));
outSize *= static_cast<DimSize_t>(dimSize);
}
if (getInput(0)->size() != outSize){
AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input");
if (foundNegativeDimension) {
outDims[negativeIndex] = (getInput(0) -> size()) / outSize;
}
mOutputs[0]->resize(outDims);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment