From fc39b17a0d3c6cd7d93ad825e9c8b3120ff7c889 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Tue, 7 May 2024 10:03:08 +0200 Subject: [PATCH] avoid connecting empty shape tensor --- src/operator/Reshape.cpp | 9 +++++---- unit_tests/operator/Test_ReshapeImpl.cpp | 4 ---- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/src/operator/Reshape.cpp b/src/operator/Reshape.cpp index b78346b25..0cce7a5b9 100644 --- a/src/operator/Reshape.cpp +++ b/src/operator/Reshape.cpp @@ -32,10 +32,8 @@ const std::string Aidge::Reshape_Op::Type = "Reshape"; bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { // check input has been associated - for (size_t i = 0; i < 2; ++i) { - if (!getInput(i)) { - AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i); - } + if (!getInput(0)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type()); } if (!getInput(0)->empty()) { std::vector<DimSize_t> outDims; @@ -46,6 +44,9 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) { // Fill shape attr if empty if (this->template getAttr<ReshapeAttr::Shape>().empty()) { + if (!getInput(1)) { + AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type()); + } if(!getInput(1)->empty()) { this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size()); diff --git a/unit_tests/operator/Test_ReshapeImpl.cpp b/unit_tests/operator/Test_ReshapeImpl.cpp index 5b6f5f3fb..2685518b5 100644 --- a/unit_tests/operator/Test_ReshapeImpl.cpp +++ b/unit_tests/operator/Test_ReshapeImpl.cpp @@ -58,8 +58,6 @@ TEST_CASE("[cpu/operator] Reshape(forward)") { std::shared_ptr<Node> myReshape = Reshape({2, 3}); auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator()); op->associateInput(0, input); - // TODO find a way to avoid connecting an empty tensor - op->associateInput(1, std::make_shared<Tensor>(Tensor({}))); op->setDataType(DataType::Float32); op->setBackend("cpu"); myReshape->forward(); @@ -116,8 +114,6 @@ TEST_CASE("[cpu/operator] Reshape(forward)") { std::shared_ptr<Node> myReshape = Reshape({3, 2}); auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator()); op->associateInput(0, input); - // TODO find a way to avoid connecting an empty tensor - op->associateInput(1, std::make_shared<Tensor>(Tensor({}))); op->setDataType(DataType::Float32); op->setBackend("cpu"); myReshape->forward(); -- GitLab