From 10eba2d39e483ed2c5fa5dba4923bfe223ea6d1d Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Mon, 5 Feb 2024 13:06:02 +0000 Subject: [PATCH] [Upd] Reshape.cpp and Gather.cpp computeOutputDims() function to check input emptyness --- src/operator/Gather.cpp | 30 ++++++++++++++------------ src/operator/Reshape.cpp | 46 +++++++++++++++++++++------------------- 2 files changed, 41 insertions(+), 35 deletions(-) diff --git a/src/operator/Gather.cpp b/src/operator/Gather.cpp index 3eafe99ef..b5f9d738a 100644 --- a/src/operator/Gather.cpp +++ b/src/operator/Gather.cpp @@ -9,8 +9,8 @@ * ********************************************************************************/ -#include <cassert> #include <cstddef> +#include <cstdint> #include <string> #include <vector> @@ -26,18 +26,22 @@ void Aidge::Gather_Op::computeOutputDims() { AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); } - std::vector<DimSize_t> outDims = getInput(0)->dims(); - const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>(); - // TODO: check indices and gatheredShape + if (!getInput(0)->empty()) { + std::vector<DimSize_t> outDims = getInput(0)->dims(); + const std::vector<DimSize_t> gatheredShape = this->template getAttr<GatherAttr::GatheredShape>(); + // TODO: check indices and gatheredShape - const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ? - this->template getAttr<GatherAttr::Axis>() : - this->template getAttr<GatherAttr::Axis>() + outDims.size(); - outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); - if (!gatheredShape.empty()) - { - outDims.insert(outDims.begin() + static_cast<std::size_t>(axisIdx), gatheredShape.begin(),gatheredShape.end()); - } + const std::int64_t axisIdx = this->template getAttr<GatherAttr::Axis>() >= 0 ? + this->template getAttr<GatherAttr::Axis>() : + this->template getAttr<GatherAttr::Axis>() + outDims.size(); + outDims.erase(outDims.begin() + static_cast<std::size_t>(axisIdx)); + if (!gatheredShape.empty()) + { + outDims.insert(outDims.cbegin() + static_cast<std::size_t>(axisIdx), + gatheredShape.cbegin(), + gatheredShape.cend()); + } - mOutputs[0]->resize(outDims); + mOutputs[0]->resize(outDims); + } } \ No newline at end of file diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index 7032c8110..30b060cd2 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -27,30 +27,32 @@ void Aidge::Reshape_Op::computeOutputDims() { AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); } - std::vector<DimSize_t> outDims; - // variables to handle a negative dimension - bool foundNegativeDimension = false; - std::size_t outSize = 1; - DimIdx_t negativeIndex = 0; - - for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) - { - std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; - if (dimSize < 0) { - if (foundNegativeDimension) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator."); + if (!getInput(0)->empty()) { + std::vector<DimSize_t> outDims; + // variables to handle a negative dimension + bool foundNegativeDimension = false; + std::size_t outSize = 1; + DimIdx_t negativeIndex = 0; + + for(std::size_t i = 0; i < this->template getAttr<ReshapeAttr::Shape>().size(); ++i) + { + std::int64_t dimSize = this->template getAttr<ReshapeAttr::Shape>()[i]; + if (dimSize < 0) { + if (foundNegativeDimension) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Found more than one negative dimension in Reshape Operator."); + } + foundNegativeDimension = true; + dimSize = 1; + negativeIndex = static_cast<DimIdx_t>(i); } - foundNegativeDimension = true; - dimSize = 1; - negativeIndex = static_cast<DimIdx_t>(i); + outDims.push_back(static_cast<DimSize_t>(dimSize)); + outSize *= static_cast<DimSize_t>(dimSize); } - outDims.push_back(static_cast<DimSize_t>(dimSize)); - outSize *= static_cast<DimSize_t>(dimSize); - } - if (foundNegativeDimension) { - outDims[negativeIndex] = (getInput(0) -> size()) / outSize; - } + if (foundNegativeDimension) { + outDims[negativeIndex] = (getInput(0) -> size()) / outSize; + } - mOutputs[0]->resize(outDims); + mOutputs[0]->resize(outDims); + } } \ No newline at end of file -- GitLab