diff --git a/src/recipes/ShapeFolding.cpp b/src/recipes/ShapeFolding.cpp index f869e646bdbe9088ab83baa99b10e2359869fa76..ca6ef5087a62262bcf40244c2e1dea2bc15fff81 100644 --- a/src/recipes/ShapeFolding.cpp +++ b/src/recipes/ShapeFolding.cpp @@ -24,6 +24,7 @@ bool Aidge::constantShapeFolding(std::shared_ptr<GraphView> graph, const std::ve bool modified = false; bool forwarded = false; bool not_shape_present = true; + bool was_modified = false; for (auto nodePtr: graph->getNodes()) not_shape_present &= (nodePtr->type() != Shape_Op::Type); if (not_shape_present) @@ -31,10 +32,11 @@ bool Aidge::constantShapeFolding(std::shared_ptr<GraphView> graph, const std::ve do{ forwarded = graph->forwardDims(dims, true); modified = constantFolding(graph, true); + was_modified = true; } while(modified); if (!forwarded){ Log::warn("Failed to forward GraphView."); } - return modified; + return was_modified; }