From c7da2a046684956b7cd73fa1420da12c9f476efe Mon Sep 17 00:00:00 2001
From: cmoineau <cyril.moineau@cea.fr>
Date: Mon, 3 Mar 2025 14:36:06 +0000
Subject: [PATCH] Remove constant shape folding from forwardDims and create a
 recipes.

---
 .../test_forward_dims_constant_shape.py       |  4 +-
 include/aidge/graph/GraphView.hpp             |  3 +-
 include/aidge/recipes/Recipes.hpp             |  8 ++++
 python_binding/graph/pybind_GraphView.cpp     |  4 +-
 python_binding/recipes/pybind_Recipes.cpp     | 11 +++++
 src/graph/GraphView.cpp                       | 29 +-------------
 src/recipes/ShapeFolding.cpp                  | 40 +++++++++++++++++++
 7 files changed, 64 insertions(+), 35 deletions(-)
 create mode 100644 src/recipes/ShapeFolding.cpp

diff --git a/aidge_core/unit_tests/test_forward_dims_constant_shape.py b/aidge_core/unit_tests/test_forward_dims_constant_shape.py
index ecab2664a..aea0260c8 100644
--- a/aidge_core/unit_tests/test_forward_dims_constant_shape.py
+++ b/aidge_core/unit_tests/test_forward_dims_constant_shape.py
@@ -93,8 +93,8 @@ class test_forward_dims_constant_shape(unittest.TestCase):
         # Note: Except Div every operator are backend independent
         self.graph.set_backend("cpu")
         self.graph.set_datatype(aidge_core.dtype.float32)
-        self.assertTrue(self.graph.forward_dims([[5, 12, 24, 24]], allow_data_dependency = True, shape_as_constant = True),
-                        "Failed to forward dimensions.")
+
+        aidge_core.constant_shape_folding(self.graph, [[5, 12, 24, 24]])
         self.assertEqual(len(self.graph.get_nodes()), 6, "After forward dims with constant folding we don't have the expected number of nodes.")
 
 
diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index f6297f790..be325cb96 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -291,10 +291,9 @@ public:
      *
      * @param dims Vector of dimension vectors for graph inputs. Empty by default.
      * @param allowDataDependency Whether to allow data-dependent dimension computation. False by default.
-     * @param shapeAsConstant If true treat shape as constant and fold the graph, this implies that the graph may change if part of the graph are foldable. False by default.
      * @return true if dimension propagation succeeded, false otherwise.
      */
-    bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false, bool shapeAsConstant = false);
+    bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false);
 
     /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
     void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const;
diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp
index cf428e558..2d04fc426 100644
--- a/include/aidge/recipes/Recipes.hpp
+++ b/include/aidge/recipes/Recipes.hpp
@@ -31,6 +31,14 @@ namespace Aidge {
  */
 bool constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false);
 
+/**
+ * @brief Retrieve part of the graph that can be pre-computed by setting Shape as constant and replace them by a Producer.
+ *
+ * @param graph Graph to fold the constant
+ * @return bool True if the graph has been modified
+ */
+bool constantShapeFolding(std::shared_ptr<GraphView> graph, const std::vector<std::vector<DimSize_t>>& dims = {});
+
 // FUSE MATMUL + ADD -> FC
 
 /**
diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp
index ec4119e32..31e3a0099 100644
--- a/python_binding/graph/pybind_GraphView.cpp
+++ b/python_binding/graph/pybind_GraphView.cpp
@@ -128,7 +128,7 @@ void init_GraphView(py::module& m) {
           .def("clone", &GraphView::clone)
           .def("get_nodes", &GraphView::getNodes)
           .def("get_node", &GraphView::getNode, py::arg("node_name"))
-          .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false, py::arg("shape_as_constant") = false,
+          .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false,
           R"mydelimiter(
             Compute and propagate Tensor dimensions through the GraphView.
 
@@ -167,8 +167,6 @@ void init_GraphView(py::module& m) {
                 Vector of dimension vectors for graph inputs. Empty by default.
             allow_data_dependency : bool, optional
                 Whether to allow data-dependent dimension computation, by default False
-            shape_as_constant : bool, optional
-                If True, treat shape as constant and fold the graph. This implies that the graph may change if part of the graph are foldable, by default False.
 
             Returns
             -------
diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp
index 0c9f86fe7..68ad81b8b 100644
--- a/python_binding/recipes/pybind_Recipes.cpp
+++ b/python_binding/recipes/pybind_Recipes.cpp
@@ -36,6 +36,17 @@ void init_Recipes(py::module &m)
     :rtype: bool
     )mydelimiter");
 
+  m.def("constant_shape_folding", static_cast<bool(*)(std::shared_ptr<GraphView>, const std::vector<std::vector<DimSize_t>>&)>(constantShapeFolding), py::arg("graph_view"), py::arg("dims") = std::vector<std::vector<DimSize_t>>(), R"mydelimiter(
+      Retrieve part of the graph that can be pre-computed by setting Shape as constant and replace them by a Producer.
+
+      :param graph_view: Graph view on which we want to apply the recipe
+      :type graph_view: :py:class:`aidge_core.GraphView`
+      :param constant_shape: If true, ``Shape`` operator are considered constant, default=False
+      :type constant_shape: bool, optional
+      :return: True if the graph has been modified
+      :rtype: bool
+      )mydelimiter");
+
   m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter(
     Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
 
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index 533c9dc25..2fbc264e4 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -445,7 +445,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType
     forwardDims(dims);
 }
 
-bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>>& dims, bool allowDataDependency, bool shapeAsConstant) {
+bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>>& dims, bool allowDataDependency) {
     Log::debug("Starting dimension forward propagation for GraphView");
     // remove current Data connections and use dummy inputs to propagate dimensions
     // setInputs
@@ -570,33 +570,6 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
                     Log::debug("Dimensions forwarded for node {} (of type {})",
                         nodePtr->name(), nodePtr->type());
 
-                    if (shapeAsConstant && (nodePtr->type() == Shape_Op::Type)) {
-                        Log::debug("Trying to constant fold the graph");
-                        // Shape are folded if we don't find back the node in the graph
-                        if(constantFolding(shared_from_this(), true)){
-                            Log::notice("Constant folding worked, resetting list of nodes to graph inputs.");
-                            // Graph was modified during constant folding
-                            // We re-propagate dims starting from the entry of the graph
-                            nextList = inputNodes();
-                            for (const auto& currentNodePtr : getNodes()) {
-                                if (currentNodePtr->type() == Producer_Op::Type) {
-                                    // Producers are already dims forwarded!
-                                    dimsForwarded.insert(currentNodePtr);
-                                    // Producers childs are dims forwardable
-                                    for (const auto& child : currentNodePtr->getChildren()) {
-                                        if (inView(child)) {
-                                            nextList.insert(child);
-                                        }
-                                    }
-                                }
-                            }
-
-                            Log::debug("Breaking loop to restart from the beginning");
-                            break;
-                        }else{
-                            Log::debug("Constant folding fail to fold any nodes.");
-                        }
-                    }
                     // Recompute every time, even if it was already computed in a
                     // previous call of forwardDims(), as the graph may have changed!
                     dimsForwarded.insert(nodePtr);
diff --git a/src/recipes/ShapeFolding.cpp b/src/recipes/ShapeFolding.cpp
new file mode 100644
index 000000000..f869e646b
--- /dev/null
+++ b/src/recipes/ShapeFolding.cpp
@@ -0,0 +1,40 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+// #include <cassert>
+#include <memory>
+#include <set>
+#include <string>
+
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/operator/Shape.hpp"
+#include "aidge/recipes/Recipes.hpp"
+#include "aidge/utils/Log.hpp"
+// #include "aidge/utils/Types.h"
+
+bool Aidge::constantShapeFolding(std::shared_ptr<GraphView> graph, const std::vector<std::vector<DimSize_t>>& dims) {
+    bool modified      = false;
+    bool forwarded     = false;
+    bool not_shape_present = true;
+    for (auto nodePtr: graph->getNodes())
+        not_shape_present &= (nodePtr->type() != Shape_Op::Type);
+    if (not_shape_present)
+        return false;
+    do{
+        forwarded = graph->forwardDims(dims, true);
+        modified = constantFolding(graph, true);
+    } while(modified);
+    if (!forwarded){
+        Log::warn("Failed to forward GraphView.");
+    }
+
+    return modified;
+}
-- 
GitLab