diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index b0eea3c1f9f7054021b631c85e0f80e7f8845da6..c1a7c35e395418995a720efd49c7cfce0801863e 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -9,38 +9,48 @@ * ********************************************************************************/ -#include <cstddef> +#include <cstddef> // std::size_t +#include <cstdint> // std::int64_t +#include <stdexcept> // std::runtime_error #include <string> #include <vector> #include "aidge/operator/Reshape.hpp" -#include "aidge/utils/Types.h" #include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" const std::string Aidge::Reshape_Op::Type = "Reshape"; void Aidge::Reshape_Op::computeOutputDims() { - // check inputs have been associated + // check input has been associated if (!getInput(0)) { AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); } - DimSize_t nbOutDims = this->template getAttr<ReshapeAttr::Shape>().size(); std::vector<DimSize_t> outDims; + + // variables to handle a negative dimension + bool foundNegativeDimension = false; std::size_t outSize = 1; - for(std::size_t i=0; i<nbOutDims; ++i) + DimIdx_t negativeIndex = 0; + + for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) { - int dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; - if (dimSize < 1) - { - AIDGE_THROW_OR_ABORT(std::runtime_error, "bad dimension value"); + std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; + if (dimSize < 0) { + if (foundNegativeDimension) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator."); + } + foundNegativeDimension = true; + dimSize = 1; + negativeIndex = static_cast<DimIdx_t>(i); } - outDims.push_back(dimSize); - outSize *= dimSize; + outDims.push_back(static_cast<DimSize_t>(dimSize)); + outSize *= static_cast<DimSize_t>(dimSize); } - if (getInput(0)->size() != outSize){ - AIDGE_THROW_OR_ABORT(std::runtime_error, "Output shape must give the same size as input"); + if (foundNegativeDimension) { + outDims[negativeIndex] = (getInput(0) -> size()) / outSize; } mOutputs[0]->resize(outDims);