From d73de8080f43100db336d55cffbdd4ce17cd7f7e Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Fri, 28 Feb 2025 15:23:02 +0000
Subject: [PATCH] Fix constant folding with shape as constant.

---
 src/recipes/ConstantFolding.cpp | 15 +++++++++++----
 1 file changed, 11 insertions(+), 4 deletions(-)

diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp
index 05c92afb3..40cd30a7d 100644
--- a/src/recipes/ConstantFolding.cpp
+++ b/src/recipes/ConstantFolding.cpp
@@ -30,7 +30,7 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape
         folded = false;
         std::set<std::shared_ptr<Node>> candidates;
         for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) {
-            if (nodePtr->type() == Producer_Op::Type || (constantShape && (nodePtr->type() != Shape_Op::Type))) {
+            if (nodePtr->type() == Producer_Op::Type || (constantShape && (nodePtr->type() == Shape_Op::Type))) {
                 const auto& childs = nodePtr->getChildren();
                 candidates.insert(childs.begin(), childs.end());
             }
@@ -44,8 +44,8 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape
             for (const auto& input : node->inputs()) {
                 if (input.first) {
                     if (!(input.first->type() == Producer_Op::Type || (constantShape && (input.first->type() == Shape_Op::Type)))) {
-                        Log::debug("Input {} of node {} (of type {}) not foldable, because {} (of type {}) is not a constant. With constant = {}",
-                            i, node->name(), node->type(), input.first->name(), input.first->type(), constantShape);
+                        Log::debug("Input {} of node {} (of type {}) not foldable, because {} (of type {}) is not a constant.",
+                            i, node->name(), node->type(), input.first->name(), input.first->type());
                         foldable = false;
                         break;
                     }
@@ -98,7 +98,14 @@ bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape
                     // Add output in right order
                     prodGraph->add(newProd);
                 }
-
+                Log::debug("Trying to replace:");
+                for(auto nodeToReplace: replaceGraph->getNodes()){
+                    Log::debug("\t- {} ({})", nodeToReplace->name(), nodeToReplace->type());
+                }
+                Log::debug("With:");
+                for(auto nodeReplacing: prodGraph->getNodes()){
+                    Log::debug("\t- {} ({})", nodeReplacing->name(), nodeReplacing->type());
+                }
                 if (GraphView::replace(replaceGraph, prodGraph)) {
                     folded = true;
                     modified = true;
-- 
GitLab