diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 8a78d8bfc0c2a39e82a483298659a88540e0db2e..8ac05b5b81709f22969f9cf885920bd107780d4b 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -291,9 +291,10 @@ public: * * @param dims Vector of dimension vectors for graph inputs. Empty by default. * @param allowDataDependency Whether to allow data-dependent dimension computation. False by default. + * @param shapeAsConstant If true treat shape as constant and fold the graph, this implies that the graph may change if part of the graph are foldable. False by default. * @return true if dimension propagation succeeded, false otherwise. */ - bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false); + bool forwardDims(const std::vector<std::vector<DimSize_t>>& dims = {}, bool allowDataDependency = false, bool shapeAsConstant = false); /** @brief Set the same backend for each Operator of the GraphView object's Nodes. */ void setBackend(const std::string& backend, const DeviceIdx_t device = 0) const; diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index 31e3a009953c8d5b938dfe6fb888d911bef1e066..ec4119e3267cb0cfac5667da038ac158dd4c9fe1 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -128,7 +128,7 @@ void init_GraphView(py::module& m) { .def("clone", &GraphView::clone) .def("get_nodes", &GraphView::getNodes) .def("get_node", &GraphView::getNode, py::arg("node_name")) - .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false, + .def("forward_dims", &GraphView::forwardDims, py::arg("dims")=std::vector<std::vector<DimSize_t>>(), py::arg("allow_data_dependency") = false, py::arg("shape_as_constant") = false, R"mydelimiter( Compute and propagate Tensor dimensions through the GraphView. @@ -167,6 +167,8 @@ void init_GraphView(py::module& m) { Vector of dimension vectors for graph inputs. Empty by default. allow_data_dependency : bool, optional Whether to allow data-dependent dimension computation, by default False + shape_as_constant : bool, optional + If True, treat shape as constant and fold the graph. This implies that the graph may change if part of the graph are foldable, by default False. Returns ------- diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index b372fd3929dd3abca12956d00e2b2effcc49ce0e..0e7951fe8874bb999500bd686cc5ca6e3686c8f7 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -28,10 +28,12 @@ #include "aidge/data/Tensor.hpp" #include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Memorize.hpp" #include "aidge/operator/MetaOperator.hpp" #include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/Producer.hpp" -#include "aidge/operator/Memorize.hpp" +#include "aidge/operator/Shape.hpp" +#include "aidge/recipes/Recipes.hpp" // constantFolding #include "aidge/utils/Directories.hpp" #include "aidge/utils/FileManagement.hpp" #include "aidge/utils/ErrorHandling.hpp" @@ -443,7 +445,7 @@ void Aidge::GraphView::compile(const std::string& backend, const Aidge::DataType forwardDims(dims); } -bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>>& dims, bool allowDataDependency) { +bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_t>>& dims, bool allowDataDependency, bool shapeAsConstant) { Log::debug("Starting dimension forward propagation for GraphView"); // remove current Data connections and use dummy inputs to propagate dimensions // setInputs @@ -521,7 +523,7 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ Log::debug("Initializing dimension propagation"); // Establish initial list of dims forwardable nodes: graph input node + Producers childs std::set<std::shared_ptr<Node>> dimsForwarded; ///< List of nodes that are already dims forwarded - std::set<std::shared_ptr<Node>> listNodes = inputNodes(); + std::set<std::shared_ptr<Node>> listNodes = inputNodes(); // list of node to forward dims for (const auto& nodePtr : getNodes()) { if (nodePtr->type() == Producer_Op::Type) { // Producers are already dims forwarded! @@ -534,13 +536,17 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ } } } - do { - std::set<std::shared_ptr<Node>> nextList; + Log::debug("List of node to forward dimensions:"); + for(auto node : listNodes){ + Log::debug("\t- Node {} (of type {})", node->name(), node->type()); + } + std::set<std::shared_ptr<Node>> nextList; // future listNodes for (const auto& nodePtr : listNodes) { + Log::debug("Trying to forward dims of node {} (of type {})", nodePtr->name(), nodePtr->type()); + if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) { const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator()); - bool anyParent = false; bool parentsForwarded = true; for (const auto& parent : nodePtr->getParents()) { @@ -564,6 +570,34 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ Log::debug("Dimensions forwarded for node {} (of type {})", nodePtr->name(), nodePtr->type()); + if (shapeAsConstant && (nodePtr->type() == Shape_Op::Type)) { + Log::debug("Trying to constant fold the graph"); + // Shape are folded if we don't find back the node in the graph + if(constantFolding(shared_from_this(), true)){ + Log::notice("Shape node {} (of type {}) was folded.", nodePtr->name(), nodePtr->type()); + Log::debug("Resetting list of nodes to graph inputs"); + // Graph was modified during constant folding + // We re-propagate dims starting from the entry of the graph + nextList = inputNodes(); + for (const auto& currentNodePtr : getNodes()) { + if (currentNodePtr->type() == Producer_Op::Type) { + // Producers are already dims forwarded! + dimsForwarded.insert(currentNodePtr); + // Producers childs are dims forwardable + for (const auto& child : currentNodePtr->getChildren()) { + if (inView(child)) { + nextList.insert(child); + } + } + } + } + + Log::debug("Breaking loop to restart from the beginning"); + break; + }else{ + Log::debug("Shape node {} (of type {}) was not folded.", nodePtr->name(), nodePtr->type()); + } + } // Recompute every time, even if it was already computed in a // previous call of forwardDims(), as the graph may have changed! dimsForwarded.insert(nodePtr); @@ -577,12 +611,17 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_ if (parentsForwarded) { Log::debug("Unable to forward dimensions for node {} (of type {})", nodePtr->name(), nodePtr->type()); } + Log::debug("Adding back node {} (of type {}) to the list of nodes to forward dimensions", nodePtr->name(), nodePtr->type()); nextList.insert(nodePtr); } } + else { + AIDGE_THROW_OR_ABORT(std::runtime_error, "Node {} (of type {}) as it is not an OperatorTensor. ForwardDims is currently only supported for OperatorTensor.", nodePtr->name(), nodePtr->type()); + } + Log::debug("- - - - -"); } - Log::debug("********************"); + Log::debug("Finished treating current list of nodes ..."); // Internal check to make sure we won't enter in an infinite loop! if (nextList == listNodes) {