From cfc1e933ac185418acc120781c4541ce471c721f Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Fri, 10 Jan 2025 16:15:38 +0000
Subject: [PATCH] Update ConstantFolding recipes with arg to consider Shape as
 constant.

---
 include/aidge/recipes/Recipes.hpp         |  8 +++++++-
 python_binding/recipes/pybind_Recipes.cpp |  8 ++++++++
 src/recipes/ConstantFolding.cpp           | 22 ++++++++++++----------
 3 files changed, 27 insertions(+), 11 deletions(-)

diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp
index b0bc6dcef..f019eb51c 100644
--- a/include/aidge/recipes/Recipes.hpp
+++ b/include/aidge/recipes/Recipes.hpp
@@ -22,7 +22,13 @@
 
 namespace Aidge {
 
-void constantFolding(std::shared_ptr<GraphView> graph);
+/**
+ * @brief Retrieve part of the graph that can be pre-computed and replace them by a Producer.
+ *
+ * @param graph Graph to fold the constant
+ * @param constant_shape If true Shape operators are considered to be constant
+ */
+void constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false);
 
 // FUSE MATMUL + ADD -> FC
 
diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp
index 500367cb8..06f98ad4f 100644
--- a/python_binding/recipes/pybind_Recipes.cpp
+++ b/python_binding/recipes/pybind_Recipes.cpp
@@ -25,6 +25,14 @@ namespace Aidge {
 void init_Recipes(py::module &m)
 {
 
+  m.def("constant_folding", static_cast<void(*)(std::shared_ptr<GraphView>, bool)>(constantFolding), py::arg("graph_view"), py::arg("constant_shape") = false, R"mydelimiter(
+    Retrieve part of the graph that can be pre-computed and replace them by a Producer.
+
+    :param graph_view: Graph view on which we want to apply the recipe
+    :type graph_view: :py:class:`aidge_core.GraphView`
+    :param constant_shape: If true, ``Shape`` operator are considered constant, default=False
+    :type constant_shape: bool, optional
+    )mydelimiter");
 
   m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter(
     Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp
index 613393756..031e448cf 100644
--- a/src/recipes/ConstantFolding.cpp
+++ b/src/recipes/ConstantFolding.cpp
@@ -17,17 +17,18 @@
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/operator/Producer.hpp"
+#include "aidge/operator/Shape.hpp"
 #include "aidge/recipes/Recipes.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Types.h"
 
-void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
+void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape) {
     bool folded;
     do {
         folded = false;
         std::set<std::shared_ptr<Node>> candidates;
         for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) {
-            if (nodePtr->type() == Producer_Op::Type) {
+            if (nodePtr->type() == Producer_Op::Type || (constantShape && (nodePtr->type() != Shape_Op::Type))) {
                 const auto& childs = nodePtr->getChildren();
                 candidates.insert(childs.begin(), childs.end());
             }
@@ -39,17 +40,18 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
             size_t i = 0;
             for (const auto& input : node->inputs()) {
                 if (input.first) {
-                    if (input.first->type() != Producer_Op::Type) {
+                    if (input.first->type() != Producer_Op::Type || (constantShape && (input.first->type() != Shape_Op::Type))) {
                         foldable = false;
                         break;
                     }
-
-                    const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator());
-                    if (!producer->constant()) {
-                        Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant",
-                            node->name(), node->type(), input.first->name());
-                        foldable = false;
-                        break;
+                    if (input.first->type() == Producer_Op::Type){
+                        const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator());
+                        if (!producer->constant()) {
+                            Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant",
+                                node->name(), node->type(), input.first->name());
+                            foldable = false;
+                            break;
+                        }
                     }
 
                     replaceGraph->add(input.first, false);
-- 
GitLab