Skip to content
Snippets Groups Projects
Commit fc39b17a authored by Houssem ROUIS's avatar Houssem ROUIS Committed by Maxence Naud
Browse files

avoid connecting empty shape tensor

parent e1d961d9
No related branches found
No related tags found
No related merge requests found
......@@ -32,10 +32,8 @@ const std::string Aidge::Reshape_Op::Type = "Reshape";
bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
// check input has been associated
for (size_t i = 0; i < 2; ++i) {
if (!getInput(i)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #{} should be associated with a Tensor", type(), i);
}
if (!getInput(0)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #0 should be associated with a Tensor", type());
}
if (!getInput(0)->empty()) {
std::vector<DimSize_t> outDims;
......@@ -46,6 +44,9 @@ bool Aidge::Reshape_Op::forwardDims(bool /*allowDataDependency*/) {
// Fill shape attr if empty
if (this->template getAttr<ReshapeAttr::Shape>().empty()) {
if (!getInput(1)) {
AIDGE_THROW_OR_ABORT(std::runtime_error, "{}: input #1 should be associated with a Tensor", type());
}
if(!getInput(1)->empty()) {
this->template getAttr<ReshapeAttr::Shape>().clear(); // If both are provided input would override attrs
this->template getAttr<ReshapeAttr::Shape>().reserve(getInput(1)->size());
......
......@@ -58,8 +58,6 @@ TEST_CASE("[cpu/operator] Reshape(forward)") {
std::shared_ptr<Node> myReshape = Reshape({2, 3});
auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
op->associateInput(0, input);
// TODO find a way to avoid connecting an empty tensor
op->associateInput(1, std::make_shared<Tensor>(Tensor({})));
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myReshape->forward();
......@@ -116,8 +114,6 @@ TEST_CASE("[cpu/operator] Reshape(forward)") {
std::shared_ptr<Node> myReshape = Reshape({3, 2});
auto op = std::static_pointer_cast<OperatorTensor>(myReshape -> getOperator());
op->associateInput(0, input);
// TODO find a way to avoid connecting an empty tensor
op->associateInput(1, std::make_shared<Tensor>(Tensor({})));
op->setDataType(DataType::Float32);
op->setBackend("cpu");
myReshape->forward();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment