From d20e0aca9195c796e3d8c17d9f6f51b48ba432fb Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Thu, 27 Feb 2025 09:12:42 +0000 Subject: [PATCH] Constant folding now return true if graph has been modified + fix behavior when constantShape is True. --- include/aidge/recipes/Recipes.hpp | 3 ++- python_binding/recipes/pybind_Recipes.cpp | 4 +++- src/recipes/ConstantFolding.cpp | 25 ++++++++++++++++++++--- 3 files changed, 27 insertions(+), 5 deletions(-) diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index f019eb51c..cf428e558 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -27,8 +27,9 @@ namespace Aidge { * * @param graph Graph to fold the constant * @param constant_shape If true Shape operators are considered to be constant + * @return bool True if the graph has been modified */ -void constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false); +bool constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false); // FUSE MATMUL + ADD -> FC diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index 06f98ad4f..0c9f86fe7 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -25,13 +25,15 @@ namespace Aidge { void init_Recipes(py::module &m) { - m.def("constant_folding", static_cast<void(*)(std::shared_ptr<GraphView>, bool)>(constantFolding), py::arg("graph_view"), py::arg("constant_shape") = false, R"mydelimiter( + m.def("constant_folding", static_cast<bool(*)(std::shared_ptr<GraphView>, bool)>(constantFolding), py::arg("graph_view"), py::arg("constant_shape") = false, R"mydelimiter( Retrieve part of the graph that can be pre-computed and replace them by a Producer. :param graph_view: Graph view on which we want to apply the recipe :type graph_view: :py:class:`aidge_core.GraphView` :param constant_shape: If true, ``Shape`` operator are considered constant, default=False :type constant_shape: bool, optional + :return: True if the graph has been modified + :rtype: bool )mydelimiter"); m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter( diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp index 031e448cf..05c92afb3 100644 --- a/src/recipes/ConstantFolding.cpp +++ b/src/recipes/ConstantFolding.cpp @@ -22,7 +22,9 @@ #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" -void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape) { +bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape) { + bool modified = false; + Log::info("Running constant folding on graph {}", graph->name()); bool folded; do { folded = false; @@ -35,19 +37,32 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape } for (const auto& node : candidates) { + Log::debug("Checking if node {} (of type {}) is foldable", node->name(), node->type()); bool foldable = true; auto replaceGraph = std::make_shared<GraphView>(); size_t i = 0; for (const auto& input : node->inputs()) { if (input.first) { - if (input.first->type() != Producer_Op::Type || (constantShape && (input.first->type() != Shape_Op::Type))) { + if (!(input.first->type() == Producer_Op::Type || (constantShape && (input.first->type() == Shape_Op::Type)))) { + Log::debug("Input {} of node {} (of type {}) not foldable, because {} (of type {}) is not a constant. With constant = {}", + i, node->name(), node->type(), input.first->name(), input.first->type(), constantShape); foldable = false; break; } + + if (constantShape && (input.first->type() == Shape_Op::Type)){ + if (!std::static_pointer_cast<OperatorTensor>(input.first->getOperator())->dimsForwarded()){ + Log::debug("Node {} (of type {}) not foldable because Shape input [{}] {} dims has not been forwarded", + node->name(), node->type(), i, input.first->name()); + foldable = false; + break; + } + } + if (input.first->type() == Producer_Op::Type){ const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator()); if (!producer->constant()) { - Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant", + Log::debug("Node {} (of type {}) not foldable because Producer input {} not Constant", node->name(), node->type(), input.first->name()); foldable = false; break; @@ -59,6 +74,8 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape else if (node->inputCategory(i) != InputCategory::OptionalData && node->inputCategory(i) != InputCategory::OptionalParam) { + Log::debug("Input {} of node {} (of type {}) is mandatory but not set, cannot fold.", + i, node->name(), node->type()); foldable = false; break; } @@ -84,6 +101,7 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape if (GraphView::replace(replaceGraph, prodGraph)) { folded = true; + modified = true; } else { Log::warn("Error with replace when folding node {} (of type {})", @@ -93,4 +111,5 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape } } while (folded); + return modified; } -- GitLab