From cfc1e933ac185418acc120781c4541ce471c721f Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Fri, 10 Jan 2025 16:15:38 +0000 Subject: [PATCH] Update ConstantFolding recipes with arg to consider Shape as constant. --- include/aidge/recipes/Recipes.hpp | 8 +++++++- python_binding/recipes/pybind_Recipes.cpp | 8 ++++++++ src/recipes/ConstantFolding.cpp | 22 ++++++++++++---------- 3 files changed, 27 insertions(+), 11 deletions(-) diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index b0bc6dcef..f019eb51c 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -22,7 +22,13 @@ namespace Aidge { -void constantFolding(std::shared_ptr<GraphView> graph); +/** + * @brief Retrieve part of the graph that can be pre-computed and replace them by a Producer. + * + * @param graph Graph to fold the constant + * @param constant_shape If true Shape operators are considered to be constant + */ +void constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false); // FUSE MATMUL + ADD -> FC diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index 500367cb8..06f98ad4f 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -25,6 +25,14 @@ namespace Aidge { void init_Recipes(py::module &m) { + m.def("constant_folding", static_cast<void(*)(std::shared_ptr<GraphView>, bool)>(constantFolding), py::arg("graph_view"), py::arg("constant_shape") = false, R"mydelimiter( + Retrieve part of the graph that can be pre-computed and replace them by a Producer. + + :param graph_view: Graph view on which we want to apply the recipe + :type graph_view: :py:class:`aidge_core.GraphView` + :param constant_shape: If true, ``Shape`` operator are considered constant, default=False + :type constant_shape: bool, optional + )mydelimiter"); m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter( Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp index 613393756..031e448cf 100644 --- a/src/recipes/ConstantFolding.cpp +++ b/src/recipes/ConstantFolding.cpp @@ -17,17 +17,18 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" #include "aidge/operator/Producer.hpp" +#include "aidge/operator/Shape.hpp" #include "aidge/recipes/Recipes.hpp" #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" -void Aidge::constantFolding(std::shared_ptr<GraphView> graph) { +void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape) { bool folded; do { folded = false; std::set<std::shared_ptr<Node>> candidates; for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) { - if (nodePtr->type() == Producer_Op::Type) { + if (nodePtr->type() == Producer_Op::Type || (constantShape && (nodePtr->type() != Shape_Op::Type))) { const auto& childs = nodePtr->getChildren(); candidates.insert(childs.begin(), childs.end()); } @@ -39,17 +40,18 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph) { size_t i = 0; for (const auto& input : node->inputs()) { if (input.first) { - if (input.first->type() != Producer_Op::Type) { + if (input.first->type() != Producer_Op::Type || (constantShape && (input.first->type() != Shape_Op::Type))) { foldable = false; break; } - - const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator()); - if (!producer->constant()) { - Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant", - node->name(), node->type(), input.first->name()); - foldable = false; - break; + if (input.first->type() == Producer_Op::Type){ + const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator()); + if (!producer->constant()) { + Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant", + node->name(), node->type(), input.first->name()); + foldable = false; + break; + } } replaceGraph->add(input.first, false); -- GitLab