diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp
index b78346b25be7cb31a049c56a42a12f9314a1f959..0cce7a5b94692ba99923c1a866f88d9b5faee8a1 100644
--- a/src/operator/Reshape.cpp
+++ b/src/operator/Reshape.cpp
@@ -32,10 +32,8 @@ const std::string Aidge::Reshape_Op::Type = "Reshape";
 
 bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
     // check input has been associated
-    for (size_t i = 0; i < 2; ++i) {
-        if (!getInput(i)) {
-            AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i);
-        }
+    if (!getInput(0)) {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
     }
     if (!getInput(0)->empty()) {
         std::vector<DimSize_t> outDims;
@@ -46,6 +44,9 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
 
         // Fill shape attr if empty
         if (this->template getAttr<ReshapeAttr::Shape>().empty()) {
+            if (!getInput(1)) {
+                AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type());
+            }
             if(!getInput(1)->empty()) {
                 this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs
                 this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
diff --git a/unit_tests/operator/Test_ReshapeImpl.cpp b/unit_tests/operator/Test_ReshapeImpl.cpp
index 5b6f5f3fbff86cb2dacb5b136b6834b6e59835b5..2685518b5c49f13ff4aee4202163e3aa3267aa5f 100644
--- a/unit_tests/operator/Test_ReshapeImpl.cpp
+++ b/unit_tests/operator/Test_ReshapeImpl.cpp
@@ -58,8 +58,6 @@ TEST_CASE("[cpu/operator] Reshape(forward)") {
         std::shared_ptr<Node> myReshape = Reshape({2, 3});
         auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
         op->associateInput(0, input);
-        // TODO find a way to avoid connecting an empty tensor
-        op->associateInput(1, std::make_shared<Tensor>(Tensor({})));
         op->setDataType(DataType::Float32);
         op->setBackend("cpu");
         myReshape->forward();
@@ -116,8 +114,6 @@ TEST_CASE("[cpu/operator] Reshape(forward)") {
         std::shared_ptr<Node> myReshape = Reshape({3, 2});
         auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
        op->associateInput(0, input);
-        // TODO find a way to avoid connecting an empty tensor
-        op->associateInput(1, std::make_shared<Tensor>(Tensor({})));
         op->setDataType(DataType::Float32);
         op->setBackend("cpu");
         myReshape->forward();
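
Note: a minimal usage sketch of the behaviour this change enables, assuming the includes and an Array1D-style tensor initializer as used elsewhere in Test_ReshapeImpl.cpp (the initializer itself is not part of this diff). With the Shape attribute set on the operator, input #1 no longer has to be connected, which is why the empty-tensor workaround is dropped from the tests above.

    // Sketch only: the Array1D construction and include paths are assumed from
    // the existing unit test, not introduced by this diff.
    std::shared_ptr<Tensor> input = std::make_shared<Tensor>(
        Array1D<float, 6>{{1.0, 2.0, 3.0, 4.0, 5.0, 6.0}});

    std::shared_ptr<Node> myReshape = Reshape({2, 3});   // target shape given as attribute
    auto op = std::static_pointer_cast<OperatorTensor>(myReshape->getOperator());
    op->associateInput(0, input);                        // only input #0 is required
    // No associateInput(1, ...) with an empty tensor: forwardDims() now checks
    // input #1 only when the Shape attribute is empty.
    op->setDataType(DataType::Float32);
    op->setBackend("cpu");
    myReshape->forward();                                // output dims become {2, 3}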