From 959c165559b1390c476afaae94ac8ebfb8077b1c Mon Sep 17 00:00:00 2001
From: Olivier BICHLER <olivier.bichler@cea.fr>
Date: Wed, 20 Mar 2024 23:51:01 +0100
Subject: [PATCH] Added constantFolding recipe

---
 include/aidge/recipes/Recipes.hpp |  2 +
 src/recipes/ConstantFolding.cpp   | 86 +++++++++++++++++++++++++++++++
 2 files changed, 88 insertions(+)
 create mode 100644 src/recipes/ConstantFolding.cpp

diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp
index 2f77ae707..97c608cd3 100644
--- a/include/aidge/recipes/Recipes.hpp
+++ b/include/aidge/recipes/Recipes.hpp
@@ -22,6 +22,8 @@
 
 namespace Aidge {
 
+void constantFolding(std::shared_ptr<GraphView> graph);
+
 // FUSE MATMUL + ADD -> FC
 
 /**
diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp
new file mode 100644
index 000000000..42fb45224
--- /dev/null
+++ b/src/recipes/ConstantFolding.cpp
@@ -0,0 +1,86 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+#include <cassert>
+#include <memory>
+#include <set>
+#include <string>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/recipes/Recipes.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Types.h"
+
+void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
+    bool folded;
+    do {
+        folded = false;
+        std::set<std::shared_ptr<Node>> candidates;
+        for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) {
+            if (nodePtr->type() == Producer_Op::Type) {
+                const auto& childs = nodePtr->getChildren();
+                candidates.insert(childs.begin(), childs.end());
+            }
+        }
+
+        for (const auto& node : candidates) {
+            bool foldable = true;
+            auto replaceGraph = std::make_shared<GraphView>();
+            for (const auto& input : node->inputs()) {
+                if (input.first) {
+                    if (input.first->type() != Producer_Op::Type) {
+                        foldable = false;
+                        break;
+                    }
+
+                    const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator());
+                    if (!producer->getAttr<bool>("Constant")) {
+                        Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant",
+                            node->name(), node->type(), input.first->name());
+                        foldable = false;
+                        break;
+                    }
+
+                    replaceGraph->add(input.first, false);
+                }
+            }
+
+            if (foldable) {
+                Log::info("Folding node {} (of type {})", node->name(), node->type());
+                replaceGraph->add(node, false);
+
+                node->forward();
+
+                auto prodGraph = std::make_shared<GraphView>();
+                const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
+
+                for (IOIndex_t output = 0; output < node->nbOutputs(); ++output) {
+                    const auto computedOutput = std::make_shared<Tensor>(op->getOutput(output)->clone());
+                    const auto newProd = Producer(computedOutput, node->name() + "_" + std::to_string(output), true);
+
+                    // Add output in right order
+                    prodGraph->add(newProd);
+                }
+
+                if (GraphView::replace(replaceGraph, prodGraph)) {
+                    folded = true;
+                }
+                else {
+                    Log::warn("Error with replace when folding node {} (of type {})",
+                        node->name(), node->type());
+                }
+            }
+        }
+    }
+    while (folded);
+}
-- 
GitLab