diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index 3e3d5e61695b1c755dac5aad9b2696f3bdbbcea6..b78346b25be7cb31a049c56a42a12f9314a1f959 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -32,10 +32,11 @@ const std::string Aidge::Reshape_Op::Type = "Reshape"; bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { // check input has been associated - if (!getInput(0)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "Input was not connected"); + for (size_t i = 0; i < 2; ++i) { + if (!getInput(i)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i); + } } - if (!getInput(0)->empty()) { std::vector<DimSize_t> outDims; // variables to handle a negative dimension @@ -43,11 +44,11 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { std::size_t outSize = 1; DimIdx_t negativeIndex = 0; - - if (this->template getAttr<ReshapeAttr::Shape>().empty() && getInput(1)) { + // Fill shape attr if empty + if (this->template getAttr<ReshapeAttr::Shape>().empty()) { if(!getInput(1)->empty()) { + this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size()); - // Fill shape attr switch (mInputs[1]->dataType()) { case DataType::Float64: std::copy_n(static_cast<double*>(mInputs[1]->getImpl()->rawPtr()),