diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index b0bc6dcef823204bff248164d2b6bc13de9b35ec..f019eb51c2c487b48352aa4cab8a16e2e0d534cb 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -22,7 +22,13 @@ namespace Aidge { -void constantFolding(std::shared_ptr<GraphView> graph); +/** + * @brief Retrieve part of the graph that can be pre-computed and replace them by a Producer. + * + * @param graph Graph to fold the constant + * @param constant_shape If true Shape operators are considered to be constant + */ +void constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false); // FUSE MATMUL + ADD -> FC diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index 500367cb8f58bbfbf76394b0c83cd8ff848ff8cb..06f98ad4f743f57ea30b85a20d7e744a3f0aa661 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -25,6 +25,14 @@ namespace Aidge { void init_Recipes(py::module &m) { + m.def("constant_folding", static_cast<void(*)(std::shared_ptr<GraphView>, bool)>(constantFolding), py::arg("graph_view"), py::arg("constant_shape") = false, R"mydelimiter( + Retrieve part of the graph that can be pre-computed and replace them by a Producer. + + :param graph_view: Graph view on which we want to apply the recipe + :type graph_view: :py:class:`aidge_core.GraphView` + :param constant_shape: If true, ``Shape`` operator are considered constant, default=False + :type constant_shape: bool, optional + )mydelimiter"); m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter( Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp index 613393756f7f6b7104118ea97593e3130055ceeb..031e448cf92e21c2457bf23de67ec2f867b1d880 100644 --- a/src/recipes/ConstantFolding.cpp +++ b/src/recipes/ConstantFolding.cpp @@ -17,17 +17,18 @@ #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" #include "aidge/operator/Producer.hpp" +#include "aidge/operator/Shape.hpp" #include "aidge/recipes/Recipes.hpp" #include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/Types.h" -void Aidge::constantFolding(std::shared_ptr<GraphView> graph) { +void Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape) { bool folded; do { folded = false; std::set<std::shared_ptr<Node>> candidates; for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) { - if (nodePtr->type() == Producer_Op::Type) { + if (nodePtr->type() == Producer_Op::Type || (constantShape && (nodePtr->type() != Shape_Op::Type))) { const auto& childs = nodePtr->getChildren(); candidates.insert(childs.begin(), childs.end()); } @@ -39,17 +40,18 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph) { size_t i = 0; for (const auto& input : node->inputs()) { if (input.first) { - if (input.first->type() != Producer_Op::Type) { + if (input.first->type() != Producer_Op::Type || (constantShape && (input.first->type() != Shape_Op::Type))) { foldable = false; break; } - - const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator()); - if (!producer->constant()) { - Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant", - node->name(), node->type(), input.first->name()); - foldable = false; - break; + if (input.first->type() == Producer_Op::Type){ + const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator()); + if (!producer->constant()) { + Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant", + node->name(), node->type(), input.first->name()); + foldable = false; + break; + } } replaceGraph->add(input.first, false);