From d20e0aca9195c796e3d8c17d9f6f51b48ba432fb Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Thu, 27 Feb 2025 09:12:42 +0000
Subject: [PATCH] Constant folding now return true if graph has been modified +
 fix behavior when constantShape is True.

---
 include/aidge/recipes/Recipes.hpp         |  3 ++-
 python_binding/recipes/pybind_Recipes.cpp |  4 +++-
 src/recipes/ConstantFolding.cpp           | 25 ++++++++++++++++++++---
 3 files changed, 27 insertions(+), 5 deletions(-)

diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp
index f019eb51c..cf428e558 100644
--- a/include/aidge/recipes/Recipes.hpp
+++ b/include/aidge/recipes/Recipes.hpp
@@ -27,8 +27,9 @@ namespace Aidge {
  *
  * @param graph Graph to fold the constant
  * @param constant_shape If true Shape operators are considered to be constant
+ * @return bool True if the graph has been modified
  */
-void constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false);
+bool constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false);
 
 // FUSE MATMUL + ADD -> FC
 
diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp
index 06f98ad4f..0c9f86fe7 100644
--- a/python_binding/recipes/pybind_Recipes.cpp
+++ b/python_binding/recipes/pybind_Recipes.cpp
@@ -25,13 +25,15 @@ namespace Aidge {
 void init_Recipes(py::module &m)
 {
 
-  m.def("constant_folding", static_cast<void(*)(std::shared_ptr<GraphView>, bool)>(constantFolding), py::arg("graph_view"), py::arg("constant_shape") = false, R"mydelimiter(
+  m.def("constant_folding", static_cast<bool(*)(std::shared_ptr<GraphView>, bool)>(constantFolding), py::arg("graph_view"), py::arg("constant_shape") = false, R"mydelimiter(
     Retrieve part of the graph that can be pre-computed and replace them by a Producer.
 
     :param graph_view: Graph view on which we want to apply the recipe
     :type graph_view: :py:class:`aidge_core.GraphView`
     :param constant_shape: If true, ``Shape`` operator are considered constant, default=False
     :type constant_shape: bool, optional
+    :return: True if the graph has been modified
+    :rtype: bool
     )mydelimiter");
 
   m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter(
diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp
index 031e448cf..05c92afb3 100644
--- a/src/recipes/ConstantFolding.cpp
+++ b/src/recipes/ConstantFolding.cpp
@@ -22,7 +22,9 @@
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Types.h"
 
-void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape) {
+bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape) {
+    bool modified = false;
+    Log::info("Running constant folding on graph {}", graph->name());
     bool folded;
     do {
         folded = false;
@@ -35,19 +37,32 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape
         }
 
         for (const auto& node : candidates) {
+            Log::debug("Checking if node {} (of type {}) is foldable", node->name(), node->type());
             bool foldable = true;
             auto replaceGraph = std::make_shared<GraphView>();
             size_t i = 0;
             for (const auto& input : node->inputs()) {
                 if (input.first) {
-                    if (input.first->type() != Producer_Op::Type || (constantShape && (input.first->type() != Shape_Op::Type))) {
+                    if (!(input.first->type() == Producer_Op::Type || (constantShape && (input.first->type() == Shape_Op::Type)))) {
+                        Log::debug("Input {} of node {} (of type {}) not foldable, because {} (of type {}) is not a constant. With constant = {}",
+                            i, node->name(), node->type(), input.first->name(), input.first->type(), constantShape);
                         foldable = false;
                         break;
                     }
+
+                    if (constantShape && (input.first->type() == Shape_Op::Type)){
+                        if (!std::static_pointer_cast<OperatorTensor>(input.first->getOperator())->dimsForwarded()){
+                            Log::debug("Node {} (of type {}) not foldable because Shape input [{}] {} dims has not been forwarded",
+                                node->name(), node->type(), i, input.first->name());
+                            foldable = false;
+                            break;
+                        }
+                    }
+
                     if (input.first->type() == Producer_Op::Type){
                         const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator());
                         if (!producer->constant()) {
-                            Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant",
+                            Log::debug("Node {} (of type {}) not foldable because Producer input {} not Constant",
                                 node->name(), node->type(), input.first->name());
                             foldable = false;
                             break;
@@ -59,6 +74,8 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape
                 else if (node->inputCategory(i) != InputCategory::OptionalData
                     && node->inputCategory(i) != InputCategory::OptionalParam)
                 {
+                    Log::debug("Input {} of node {} (of type {}) is mandatory but not set, cannot fold.",
+                        i, node->name(), node->type());
                     foldable = false;
                     break;
                 }
@@ -84,6 +101,7 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape
 
                 if (GraphView::replace(replaceGraph, prodGraph)) {
                     folded = true;
+                    modified = true;
                 }
                 else {
                     Log::warn("Error with replace when folding node {} (of type {})",
@@ -93,4 +111,5 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape
         }
     }
     while (folded);
+    return modified;
 }
-- 
GitLab