From 959c165559b1390c476afaae94ac8ebfb8077b1c Mon Sep 17 00:00:00 2001 From: Olivier BICHLER <olivier.bichler@cea.fr> Date: Wed, 20 Mar 2024 23:51:01 +0100 Subject: [PATCH] Added constantFolding recipe --- include/aidge/recipes/Recipes.hpp | 2 + src/recipes/ConstantFolding.cpp | 86 +++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+) create mode 100644 src/recipes/ConstantFolding.cpp diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index 2f77ae707..97c608cd3 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -22,6 +22,8 @@ namespace Aidge { +void constantFolding(std::shared_ptr<GraphView> graph); + // FUSE MATMUL + ADD -> FC /** diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp new file mode 100644 index 000000000..42fb45224 --- /dev/null +++ b/src/recipes/ConstantFolding.cpp @@ -0,0 +1,86 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#include <cassert> +#include <memory> +#include <set> +#include <string> + +#include "aidge/data/Tensor.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/recipes/Recipes.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" + +void Aidge::constantFolding(std::shared_ptr<GraphView> graph) { + bool folded; + do { + folded = false; + std::set<std::shared_ptr<Node>> candidates; + for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) { + if (nodePtr->type() == Producer_Op::Type) { + const auto& childs = nodePtr->getChildren(); + candidates.insert(childs.begin(), childs.end()); + } + } + + for (const auto& node : candidates) { + bool foldable = true; + auto replaceGraph = std::make_shared<GraphView>(); + for (const auto& input : node->inputs()) { + if (input.first) { + if (input.first->type() != Producer_Op::Type) { + foldable = false; + break; + } + + const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator()); + if (!producer->getAttr<bool>("Constant")) { + Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant", + node->name(), node->type(), input.first->name()); + foldable = false; + break; + } + + replaceGraph->add(input.first, false); + } + } + + if (foldable) { + Log::info("Folding node {} (of type {})", node->name(), node->type()); + replaceGraph->add(node, false); + + node->forward(); + + auto prodGraph = std::make_shared<GraphView>(); + const auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator()); + + for (IOIndex_t output = 0; output < node->nbOutputs(); ++output) { + const auto computedOutput = std::make_shared<Tensor>(op->getOutput(output)->clone()); + const auto newProd = Producer(computedOutput, node->name() + "_" + std::to_string(output), true); + + // Add output in right order + prodGraph->add(newProd); + } + + if (GraphView::replace(replaceGraph, prodGraph)) { + folded = true; + } + else { + Log::warn("Error with replace when folding node {} (of type {})", + node->name(), node->type()); + } + } + } + } + while (folded); +} -- GitLab