Skip to content
Snippets Groups Projects
Commit c7da2a04 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Remove constant shape folding from forwardDims and create a recipes.

parent 456aa607
No related branches found
No related tags found
1 merge request!297Reshape forward dims
...@@ -93,8 +93,8 @@ class test_forward_dims_constant_shape(unittest.TestCase): ...@@ -93,8 +93,8 @@ class test_forward_dims_constant_shape(unittest.TestCase):
# Note: Except Div every operator are backend independent # Note: Except Div every operator are backend independent
self.graph.set_backend("cpu") self.graph.set_backend("cpu")
self.graph.set_datatype(aidge_core.dtype.float32) self.graph.set_datatype(aidge_core.dtype.float32)
self.assertTrue(self.graph.forward_dims([[5, 12, 24, 24]], allow_data_dependency = True, shape_as_constant = True),
"Failed to forward dimensions.") aidge_core.constant_shape_folding(self.graph, [[5, 12, 24, 24]])
self.assertEqual(len(self.graph.get_nodes()), 6, "After forward dims with constant folding we don't have the expected number of nodes.") self.assertEqual(len(self.graph.get_nodes()), 6, "After forward dims with constant folding we don't have the expected number of nodes.")
......
...@@ -291,10 +291,9 @@ public: ...@@ -291,10 +291,9 @@ public:
* *
* @param dims Vector of dimension vectors for graph inputs. Empty by default. * @param dims Vector of dimension vectors for graph inputs. Empty by default.
* @param allowDataDependency Whether to allow data-dependent dimension computation. False by default. * @param allowDataDependency Whether to allow data-dependent dimension computation. False by default.
* @param shapeAsConstant If true treat shape as constant and fold the graph, this implies that the graph may change if part of the graph are foldable. False by default.
* @return true if dimension propagation succeeded, false otherwise. * @return true if dimension propagation succeeded, false otherwise.
*/ */
bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false, bool shapeAsConstant = false); bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false);
/** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */
void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const; void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const;
......
...@@ -31,6 +31,14 @@ namespace Aidge { ...@@ -31,6 +31,14 @@ namespace Aidge {
*/ */
bool constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false); bool constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false);
/**
* @brief Retrieve part of the graph that can be pre-computed by setting Shape as constant and replace them by a Producer.
*
 * @param graph Graph in which constant Shape sub-graphs are folded.
 * @param dims Optional vector of dimension vectors for graph inputs, forwarded before folding. Empty by default.
 * @return bool True if the graph has been modified
*/
bool constantShapeFolding(std::shared_ptr<GraphView> graph, const std::vector<std::vector<DimSize_t>>& dims = {});
// FUSE MATMUL + ADD -> FC // FUSE MATMUL + ADD -> FC
/** /**
......
...@@ -128,7 +128,7 @@ void init_GraphView(py::module& m) { ...@@ -128,7 +128,7 @@ void init_GraphView(py::module& m) {
.def("clone", &GraphView::clone) .def("clone", &GraphView::clone)
.def("get_nodes", &GraphView::getNodes) .def("get_nodes", &GraphView::getNodes)
.def("get_node", &GraphView::getNode, py::arg("node_name")) .def("get_node", &GraphView::getNode, py::arg("node_name"))
.def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false, py::arg("shape_as_constant") = false, .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false,
R"mydelimiter( R"mydelimiter(
Compute and propagate Tensor dimensions through the GraphView. Compute and propagate Tensor dimensions through the GraphView.
...@@ -167,8 +167,6 @@ void init_GraphView(py::module& m) { ...@@ -167,8 +167,6 @@ void init_GraphView(py::module& m) {
Vector of dimension vectors for graph inputs. Empty by default. Vector of dimension vectors for graph inputs. Empty by default.
allow_data_dependency : bool, optional allow_data_dependency : bool, optional
Whether to allow data-dependent dimension computation, by default False Whether to allow data-dependent dimension computation, by default False
shape_as_constant : bool, optional
If True, treat shape as constant and fold the graph. This implies that the graph may change if part of the graph are foldable, by default False.
Returns Returns
------- -------
......
...@@ -36,6 +36,17 @@ void init_Recipes(py::module &m) ...@@ -36,6 +36,17 @@ void init_Recipes(py::module &m)
:rtype: bool :rtype: bool
)mydelimiter"); )mydelimiter");
m.def("constant_shape_folding", static_cast<bool(*)(std::shared_ptr<GraphView>, const std::vector<std::vector<DimSize_t>>&)>(constantShapeFolding), py::arg("graph_view"), py::arg("dims") = std::vector<std::vector<DimSize_t>>(), R"mydelimiter(
Retrieve part of the graph that can be pre-computed by setting Shape as constant and replace them by a Producer.
:param graph_view: Graph view on which we want to apply the recipe
:type graph_view: :py:class:`aidge_core.GraphView`
    :param dims: Vector of dimension vectors for graph inputs used to forward dimensions before folding, default=[]
    :type dims: List[List[int]], optional
:return: True if the graph has been modified
:rtype: bool
)mydelimiter");
m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter( m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter(
Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
......
...@@ -445,7 +445,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType ...@@ -445,7 +445,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType
forwardDims(dims); forwardDims(dims);
} }
bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>>& dims, bool allowDataDependency, bool shapeAsConstant) { bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>>& dims, bool allowDataDependency) {
Log::debug("Starting dimension forward propagation for GraphView"); Log::debug("Starting dimension forward propagation for GraphView");
// remove current Data connections and use dummy inputs to propagate dimensions // remove current Data connections and use dummy inputs to propagate dimensions
// setInputs // setInputs
...@@ -570,33 +570,6 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ ...@@ -570,33 +570,6 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
Log::debug("Dimensions forwarded for node {} (of type {})", Log::debug("Dimensions forwarded for node {} (of type {})",
nodePtr->name(), nodePtr->type()); nodePtr->name(), nodePtr->type());
if (shapeAsConstant && (nodePtr->type() == Shape_Op::Type)) {
Log::debug("Trying to constant fold the graph");
// Shape are folded if we don't find back the node in the graph
if(constantFolding(shared_from_this(), true)){
Log::notice("Constant folding worked, resetting list of nodes to graph inputs.");
// Graph was modified during constant folding
// We re-propagate dims starting from the entry of the graph
nextList = inputNodes();
for (const auto& currentNodePtr : getNodes()) {
if (currentNodePtr->type() == Producer_Op::Type) {
// Producers are already dims forwarded!
dimsForwarded.insert(currentNodePtr);
// Producers childs are dims forwardable
for (const auto& child : currentNodePtr->getChildren()) {
if (inView(child)) {
nextList.insert(child);
}
}
}
}
Log::debug("Breaking loop to restart from the beginning");
break;
}else{
Log::debug("Constant folding fail to fold any nodes.");
}
}
// Recompute every time, even if it was already computed in a // Recompute every time, even if it was already computed in a
// previous call of forwardDims(), as the graph may have changed! // previous call of forwardDims(), as the graph may have changed!
dimsForwarded.insert(nodePtr); dimsForwarded.insert(nodePtr);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
// #include <cassert>
#include <memory>
#include <set>
#include <string>
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Shape.hpp"
#include "aidge/recipes/Recipes.hpp"
#include "aidge/utils/Log.hpp"
// #include "aidge/utils/Types.h"
/**
 * @brief Fold the parts of the graph that become constant once Shape operators
 *        are treated as constants, replacing them by Producer nodes.
 *
 * Alternates dimension forwarding and constant folding until a fixed point is
 * reached (no more nodes can be folded).
 *
 * @param graph Graph in which constant Shape sub-graphs are folded.
 * @param dims Optional vector of dimension vectors for graph inputs. Empty by default.
 * @return true if at least one folding pass modified the graph.
 */
bool Aidge::constantShapeFolding(std::shared_ptr<GraphView> graph, const std::vector<std::vector<DimSize_t>>& dims) {
    // Fast exit: nothing to do if the graph contains no Shape operator.
    bool hasShape = false;
    for (const auto& nodePtr : graph->getNodes()) {
        if (nodePtr->type() == Shape_Op::Type) {
            hasShape = true;
            break;
        }
    }
    if (!hasShape) {
        return false;
    }

    // Iterate until constantFolding() reaches a fixed point: each successful
    // fold may enable further dimension propagation, hence the loop.
    bool anyModified = false;
    bool forwarded = false;
    bool modified = false;
    do {
        forwarded = graph->forwardDims(dims, true);
        modified = constantFolding(graph, true);
        // BUGFIX: accumulate across iterations. The loop exits exactly when
        // `modified` is false, so returning `modified` directly would always
        // yield false even after successful folds.
        anyModified |= modified;
    } while (modified);

    if (!forwarded) {
        Log::warn("Failed to forward GraphView.");
    }
    return anyModified;
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment