diff --git a/aidge_core/unit_tests/test_forward_dims_constant_shape.py b/aidge_core/unit_tests/test_forward_dims_constant_shape.py index ecab2664a25023d2919324898fc6d84d63d8853d..aea0260c8beb94cc880136f9338b16ca400ac25c 100644 --- a/aidge_core/unit_tests/test_forward_dims_constant_shape.py +++ b/aidge_core/unit_tests/test_forward_dims_constant_shape.py @@ -93,8 +93,8 @@ class test_forward_dims_constant_shape(unittest.TestCase): # Note: Except Div every operator are backend independent self.graph.set_backend("cpu") self.graph.set_datatype(aidge_core.dtype.float32) - self.assertTrue(self.graph.forward_dims([[5, 12, 24, 24]], allow_data_dependency = True, shape_as_constant = True), - "Failed to forward dimensions.") + + aidge_core.constant_shape_folding(self.graph, [[5, 12, 24, 24]]) self.assertEqual(len(self.graph.get_nodes()), 6, "After forward dims with constant folding we don't have the expected number of nodes.") diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index f6297f790e58a2d70db7158c9385e0fc2b9b6fc4..be325cb9697348031fbc16ff0fcda0c506519c03 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -291,10 +291,9 @@ public: * * @param dims Vector of dimension vectors for graph inputs. Empty by default. * @param allowDataDependency Whether to allow data-dependent dimension computation. False by default. - * @param shapeAsConstant If true treat shape as constant and fold the graph, this implies that the graph may change if part of the graph are foldable. False by default. * @return true if dimension propagation succeeded, false otherwise. */ - bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false, bool shapeAsConstant = false); + bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false); /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const; diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index cf428e558cd51bfde02a78cb7951f7acd86ea483..2d04fc4268ace79eb374cdbe8f874fb71b718ce3 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -31,6 +31,14 @@ namespace Aidge { */ bool constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false); +/** + * @brief Retrieve part of the graph that can be pre-computed by setting Shape as constant and replace them by a Producer. + * + * @param graph Graph to fold the constant + * @return bool True if the graph has been modified + */ +bool constantShapeFolding(std::shared_ptr<GraphView> graph, const std::vector<std::vector<DimSize_t>>& dims = {}); + // FUSE MATMUL + ADD -> FC /** diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index ec4119e3267cb0cfac5667da038ac158dd4c9fe1..31e3a009953c8d5b938dfe6fb888d911bef1e066 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -128,7 +128,7 @@ void init_GraphView(py::module& m) { .def("clone", &GraphView::clone) .def("get_nodes", &GraphView::getNodes) .def("get_node", &GraphView::getNode, py::arg("node_name")) - .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false, py::arg("shape_as_constant") = false, + .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false, R"mydelimiter( Compute and propagate Tensor dimensions through the GraphView. @@ -167,8 +167,6 @@ void init_GraphView(py::module& m) { Vector of dimension vectors for graph inputs. Empty by default. allow_data_dependency : bool, optional Whether to allow data-dependent dimension computation, by default False - shape_as_constant : bool, optional - If True, treat shape as constant and fold the graph. This implies that the graph may change if part of the graph are foldable, by default False. Returns ------- diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index 0c9f86fe742eaa5d1b7489e150fbd194dedebbd9..68ad81b8b57062cb863a75810e086900a4ce5a6d 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -36,6 +36,17 @@ void init_Recipes(py::module &m) :rtype: bool )mydelimiter"); + m.def("constant_shape_folding", static_cast<bool(*)(std::shared_ptr<GraphView>, const std::vector<std::vector<DimSize_t>>&)>(constantShapeFolding), py::arg("graph_view"), py::arg("dims") = std::vector<std::vector<DimSize_t>>(), R"mydelimiter( + Retrieve part of the graph that can be pre-computed by setting Shape as constant and replace them by a Producer. + + :param graph_view: Graph view on which we want to apply the recipe + :type graph_view: :py:class:`aidge_core.GraphView` + :param constant_shape: If true, ``Shape`` operator are considered constant, default=False + :type constant_shape: bool, optional + :return: True if the graph has been modified + :rtype: bool + )mydelimiter"); + m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter( Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 533c9dc255583f331b318c92adc19cb905e372a2..2fbc264e4ed5e198089b7629ddfbe218d93c6f78 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -445,7 +445,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType forwardDims(dims); } -bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>>& dims, bool allowDataDependency, bool shapeAsConstant) { +bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>>& dims, bool allowDataDependency) { Log::debug("Starting dimension forward propagation for GraphView"); // remove current Data connections and use dummy inputs to propagate dimensions // setInputs @@ -570,33 +570,6 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ Log::debug("Dimensions forwarded for node {} (of type {})", nodePtr->name(), nodePtr->type()); - if (shapeAsConstant && (nodePtr->type() == Shape_Op::Type)) { - Log::debug("Trying to constant fold the graph"); - // Shape are folded if we don't find back the node in the graph - if(constantFolding(shared_from_this(), true)){ - Log::notice("Constant folding worked, resetting list of nodes to graph inputs."); - // Graph was modified during constant folding - // We re-propagate dims starting from the entry of the graph - nextList = inputNodes(); - for (const auto& currentNodePtr : getNodes()) { - if (currentNodePtr->type() == Producer_Op::Type) { - // Producers are already dims forwarded! - dimsForwarded.insert(currentNodePtr); - // Producers childs are dims forwardable - for (const auto& child : currentNodePtr->getChildren()) { - if (inView(child)) { - nextList.insert(child); - } - } - } - } - - Log::debug("Breaking loop to restart from the beginning"); - break; - }else{ - Log::debug("Constant folding fail to fold any nodes."); - } - } // Recompute every time, even if it was already computed in a // previous call of forwardDims(), as the graph may have changed! dimsForwarded.insert(nodePtr); diff --git a/src/recipes/ShapeFolding.cpp b/src/recipes/ShapeFolding.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f869e646bdbe9088ab83baa99b10e2359869fa76 --- /dev/null +++ b/src/recipes/ShapeFolding.cpp @@ -0,0 +1,40 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +// #include <cassert> +#include <memory> +#include <set> +#include <string> + +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/Shape.hpp" +#include "aidge/recipes/Recipes.hpp" +#include "aidge/utils/Log.hpp" +// #include "aidge/utils/Types.h" + +bool Aidge::constantShapeFolding(std::shared_ptr<GraphView> graph, const std::vector<std::vector<DimSize_t>>& dims) { + bool modified = false; + bool forwarded = false; + bool not_shape_present = true; + for (auto nodePtr: graph->getNodes()) + not_shape_present &= (nodePtr->type() != Shape_Op::Type); + if (not_shape_present) + return false; + do{ + forwarded = graph->forwardDims(dims, true); + modified = constantFolding(graph, true); + } while(modified); + if (!forwarded){ + Log::warn("Failed to forward GraphView."); + } + + return modified; +}