From c7da2a046684956b7cd73fa1420da12c9f476efe Mon Sep 17 00:00:00 2001 From: cmoineau <cyril.moineau@cea.fr> Date: Mon, 3 Mar 2025 14:36:06 +0000 Subject: [PATCH] Remove constant shape folding from forwardDims and create a recipes. --- .../test_forward_dims_constant_shape.py | 4 +- include/aidge/graph/GraphView.hpp | 3 +- include/aidge/recipes/Recipes.hpp | 8 ++++ python_binding/graph/pybind_GraphView.cpp | 4 +- python_binding/recipes/pybind_Recipes.cpp | 11 +++++ src/graph/GraphView.cpp | 29 +------------- src/recipes/ShapeFolding.cpp | 40 +++++++++++++++++++ 7 files changed, 64 insertions(+), 35 deletions(-) create mode 100644 src/recipes/ShapeFolding.cpp diff --git a/aidge_core/unit_tests/test_forward_dims_constant_shape.py b/aidge_core/unit_tests/test_forward_dims_constant_shape.py index ecab2664a..aea0260c8 100644 --- a/aidge_core/unit_tests/test_forward_dims_constant_shape.py +++ b/aidge_core/unit_tests/test_forward_dims_constant_shape.py @@ -93,8 +93,8 @@ class test_forward_dims_constant_shape(unittest.TestCase): # Note: Except Div every operator are backend independent self.graph.set_backend("cpu") self.graph.set_datatype(aidge_core.dtype.float32) - self.assertTrue(self.graph.forward_dims([[5, 12, 24, 24]], allow_data_dependency = True, shape_as_constant = True), - "Failed to forward dimensions.") + + aidge_core.constant_shape_folding(self.graph, [[5, 12, 24, 24]]) self.assertEqual(len(self.graph.get_nodes()), 6, "After forward dims with constant folding we don't have the expected number of nodes.") diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index f6297f790..be325cb96 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -291,10 +291,9 @@ public: * * @param dims Vector of dimension vectors for graph inputs. Empty by default. * @param allowDataDependency Whether to allow data-dependent dimension computation. False by default. - * @param shapeAsConstant If true treat shape as constant and fold the graph, this implies that the graph may change if part of the graph are foldable. False by default. * @return true if dimension propagation succeeded, false otherwise. */ - bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false, bool shapeAsConstant = false); + bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false); /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const; diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index cf428e558..2d04fc426 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -31,6 +31,14 @@ namespace Aidge { */ bool constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false); +/** + * @brief Retrieve part of the graph that can be pre-computed by setting Shape as constant and replace them by a Producer. + * + * @param graph Graph to fold the constant + * @return bool True if the graph has been modified + */ +bool constantShapeFolding(std::shared_ptr<GraphView> graph, const std::vector<std::vector<DimSize_t>>& dims = {}); + // FUSE MATMUL + ADD -> FC /** diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index ec4119e32..31e3a0099 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -128,7 +128,7 @@ void init_GraphView(py::module& m) { .def("clone", &GraphView::clone) .def("get_nodes", &GraphView::getNodes) .def("get_node", &GraphView::getNode, py::arg("node_name")) - .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false, py::arg("shape_as_constant") = false, + .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false, R"mydelimiter( Compute and propagate Tensor dimensions through the GraphView. @@ -167,8 +167,6 @@ void init_GraphView(py::module& m) { Vector of dimension vectors for graph inputs. Empty by default. allow_data_dependency : bool, optional Whether to allow data-dependent dimension computation, by default False - shape_as_constant : bool, optional - If True, treat shape as constant and fold the graph. This implies that the graph may change if part of the graph are foldable, by default False. Returns ------- diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index 0c9f86fe7..68ad81b8b 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -36,6 +36,17 @@ void init_Recipes(py::module &m) :rtype: bool )mydelimiter"); + m.def("constant_shape_folding", static_cast<bool(*)(std::shared_ptr<GraphView>, const std::vector<std::vector<DimSize_t>>&)>(constantShapeFolding), py::arg("graph_view"), py::arg("dims") = std::vector<std::vector<DimSize_t>>(), R"mydelimiter( + Retrieve part of the graph that can be pre-computed by setting Shape as constant and replace them by a Producer. + + :param graph_view: Graph view on which we want to apply the recipe + :type graph_view: :py:class:`aidge_core.GraphView` + :param constant_shape: If true, ``Shape`` operator are considered constant, default=False + :type constant_shape: bool, optional + :return: True if the graph has been modified + :rtype: bool + )mydelimiter"); + m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter( Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 533c9dc25..2fbc264e4 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -445,7 +445,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType forwardDims(dims); } -bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>>& dims, bool allowDataDependency, bool shapeAsConstant) { +bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>>& dims, bool allowDataDependency) { Log::debug("Starting dimension forward propagation for GraphView"); // remove current Data connections and use dummy inputs to propagate dimensions // setInputs @@ -570,33 +570,6 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ Log::debug("Dimensions forwarded for node {} (of type {})", nodePtr->name(), nodePtr->type()); - if (shapeAsConstant && (nodePtr->type() == Shape_Op::Type)) { - Log::debug("Trying to constant fold the graph"); - // Shape are folded if we don't find back the node in the graph - if(constantFolding(shared_from_this(), true)){ - Log::notice("Constant folding worked, resetting list of nodes to graph inputs."); - // Graph was modified during constant folding - // We re-propagate dims starting from the entry of the graph - nextList = inputNodes(); - for (const auto& currentNodePtr : getNodes()) { - if (currentNodePtr->type() == Producer_Op::Type) { - // Producers are already dims forwarded! - dimsForwarded.insert(currentNodePtr); - // Producers childs are dims forwardable - for (const auto& child : currentNodePtr->getChildren()) { - if (inView(child)) { - nextList.insert(child); - } - } - } - } - - Log::debug("Breaking loop to restart from the beginning"); - break; - }else{ - Log::debug("Constant folding fail to fold any nodes."); - } - } // Recompute every time, even if it was already computed in a // previous call of forwardDims(), as the graph may have changed! dimsForwarded.insert(nodePtr); diff --git a/src/recipes/ShapeFolding.cpp b/src/recipes/ShapeFolding.cpp new file mode 100644 index 000000000..f869e646b --- /dev/null +++ b/src/recipes/ShapeFolding.cpp @@ -0,0 +1,40 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +// #include <cassert> +#include <memory> +#include <set> +#include <string> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Shape.hpp" +#include "aidge/recipes/Recipes.hpp" +#include "aidge/utils/Log.hpp" +// #include "aidge/utils/Types.h" + +bool Aidge::constantShapeFolding(std::shared_ptr<GraphView> graph, const std::vector<std::vector<DimSize_t>>& dims) { + bool modified = false; + bool forwarded = false; + bool not_shape_present = true; + for (auto nodePtr: graph->getNodes()) + not_shape_present &= (nodePtr->type() != Shape_Op::Type); + if (not_shape_present) + return false; + do{ + forwarded = graph->forwardDims(dims, true); + modified = constantFolding(graph, true); + } while(modified); + if (!forwarded){ + Log::warn("Failed to forward GraphView."); + } + + return modified; +} -- GitLab