diff --git a/aidge_core/unit_tests/test_forward_dims_constant_shape.py b/aidge_core/unit_tests/test_forward_dims_constant_shape.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd3fee04b792027e8c33607d1c2156878262d548
--- /dev/null
+++ b/aidge_core/unit_tests/test_forward_dims_constant_shape.py
@@ -0,0 +1,102 @@
+"""
+Copyright (c) 2023 CEA-List
+
+This program and the accompanying materials are made available under the
+terms of the Eclipse Public License 2.0 which is available at
+http://www.eclipse.org/legal/epl-2.0.
+
+SPDX-License-Identifier: EPL-2.0
+"""
+
+import unittest
+import aidge_core
+import numpy as np
+
+class DivImpl(aidge_core.OperatorImpl):
+    """Div operator implementation to avoid a dependency on backend_cpu"""
+
+    def __init__(self, op: aidge_core.Operator):
+        aidge_core.OperatorImpl.__init__(self, op, "div")
+        self.op = op
+        print("Creating DivImpl")
+
+    def forward(self):
+        data_input_0 = np.array(self.op.get_input(0))
+        data_input_1 = np.array(self.op.get_input(1))
+        output = (data_input_0 / data_input_1)
+        self.op.set_output(0, aidge_core.Tensor(output))  # setting operator output
+
+# Note: in this test, every operator except Div is backend independent
+aidge_core.register_DivOp("cpu", DivImpl)
+
+class test_forward_dims_constant_shape(unittest.TestCase):
+    """Test forwardDims with constant shape folding (Shape operators treated as constant).
+    """
+    def setUp(self):
+        # Declaring constant values
+        prod_two_a = aidge_core.Producer(aidge_core.Tensor(np.array(2, dtype=np.int64)), "two_a", constant=True)
+        prod_two_b = aidge_core.Producer(aidge_core.Tensor(np.array(2, dtype=np.int64)), "two_b", constant=True)
+
+        # Declaring operators
+        shape_op_1 = aidge_core.Shape(name="shape_op_1")
+        shape_op_2 = aidge_core.Shape(name="shape_op_2")
+        shape_op_3 = aidge_core.Shape(name="shape_op_3")
+        shape_op_4 = aidge_core.Shape(name="shape_op_4")
+        gather_op_1 = aidge_core.Gather(axis=0, indices=[0], name="gather_op_1")
+        gather_op_2 = aidge_core.Gather(axis=0, indices=[1], name="gather_op_2")
+        gather_op_3 = aidge_core.Gather(axis=0, indices=[2], name="gather_op_3")
+        gather_op_4 = aidge_core.Gather(axis=0, indices=[3], name="gather_op_4")
+        div_op = aidge_core.Div(name="div_op")
+
+        u_op_1 = aidge_core.Unsqueeze(axes=[0], name="unsqueeze_op_1")
+        u_op_2 = aidge_core.Unsqueeze(axes=[0], name="unsqueeze_op_2")
+        u_op_3 = aidge_core.Unsqueeze(axes=[0], name="unsqueeze_op_3")
+        u_op_4 = aidge_core.Unsqueeze(axes=[0], name="unsqueeze_op_4")
+        u_op_5 = aidge_core.Unsqueeze(axes=[0], name="unsqueeze_op_5")
+        u_op_6 = aidge_core.Unsqueeze(axes=[0], name="unsqueeze_op_6")
+        u_op_7 = aidge_core.Unsqueeze(axes=[0], name="unsqueeze_op_7")
+        u_op_8 = aidge_core.Unsqueeze(axes=[0], name="unsqueeze_op_8")
+        u_op_9 = aidge_core.Unsqueeze(axes=[0], name="unsqueeze_op_9")
+        concat_op_1 = aidge_core.Concat(5, name="concat_op_1")
+        concat_op_2 = aidge_core.Concat(4, name="concat_op_2")
+        reshape_op_1 = aidge_core.Reshape(name="reshape_op_1")
+        reshape_op_2 = aidge_core.Reshape(name="reshape_op_2")
+        transpose_op_1 = aidge_core.Transpose([0, 2, 1, 3, 4], name="transpose_op_1")
+
+        # Declaring Connectors
+        x = aidge_core.Connector(aidge_core.Identity("Input"))
+        a = aidge_core.Connector(prod_two_a)
+        b = aidge_core.Connector(prod_two_b)
+
+        # Graph creation using functional declaration
+        x1 = shape_op_1(x)
+        x2 = shape_op_2(x)
+        x3 = shape_op_3(x)
+        x4 = shape_op_4(x)
+        n = gather_op_1(x1)
+        c = gather_op_2(x2)
+        h = gather_op_3(x3)
+        w = gather_op_4(x4)
+
+        shape_1 = concat_op_1(u_op_1(n), u_op_2(a), u_op_3(div_op(c, b)), u_op_4(h), u_op_5(w))
+        shape_2 = concat_op_2(u_op_6(n), u_op_7(c), u_op_8(h), u_op_9(w))
+
+        y = reshape_op_2(transpose_op_1(reshape_op_1(x, shape_1)), shape_2)
+
+        self.graph = aidge_core.generate_graph([y])
+
+    def tearDown(self):
+        pass
+
+    def test_constant_shape_folding(self):
+        # Note: except for Div, every operator is backend independent
+        self.graph.set_backend("cpu")
+        self.graph.set_datatype(aidge_core.dtype.float32)
+
+        aidge_core.constant_shape_folding(self.graph, [[5, 12, 24, 24]])
+        self.assertEqual(len(self.graph.get_nodes()), 6, "After constant shape folding, the graph does not have the expected number of nodes.")
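+        # A sketch of the expected result (assuming the graph built in setUp):
+        # both Shape -> Gather -> ... -> Concat branches should fold into
+        # constant Producers, leaving Input (Identity), reshape_op_1,
+        # transpose_op_1, reshape_op_2 and the two shape Producers: 6 nodes.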
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 8a78d8bfc0c2a39e82a483298659a88540e0db2e..be325cb9697348031fbc16ff0fcda0c506519c03 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -522,6 +522,9 @@ public:
    * - same number of input/output connections in oldNodes, parents and children are linked according
    * to these connections IDs
    * - different number of input/output connections in oldNodes => return false
+   * Case 4: the newNodes set has no inputs and a single output; oldNodes may have any number of inputs and a single output
+   * - the output is reconnected
+   * - all inputs are disconnected
    * @param oldNodes
    * @param newNodes
    * @return true replacement has been performed
diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp
index 9b16f76d52e1a8d19a225d5ead2d1d47e465fd30..cd2ca38dfdf0f18391e239578bdad67fb61a3750 100644
--- a/include/aidge/graph/Node.hpp
+++ b/include/aidge/graph/Node.hpp
@@ -326,7 +326,7 @@ public:
    * @param otherInId ID of the other Node input to connect to the current Node.
    * Default to the first available data input.
    *
-   * @note otherNode shared_ptr is passed by refenrece in order to be able to detect
+   * @note otherNode shared_ptr is passed by reference in order to be able to detect
    * possible dangling connection situations in debug using ref counting.
    */
   void addChild(const NodePtr& otherNode,
@@ -507,7 +507,7 @@ private:
    * @param outId
    * @param otherInId
    *
-   * @note otherNode shared_ptr is passed by refenrece in order to be able to detect
+   * @note otherNode shared_ptr is passed by reference in order to be able to detect
    * possible dangling connection situations in debug using ref counting.
    */
   void addChildOp(const NodePtr& otherNode, const IOIndex_t outId,
diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp
index b0bc6dcef823204bff248164d2b6bc13de9b35ec..2d04fc4268ace79eb374cdbe8f874fb71b718ce3 100644
--- a/include/aidge/recipes/Recipes.hpp
+++ b/include/aidge/recipes/Recipes.hpp
@@ -22,7 +22,22 @@
 namespace Aidge {
 
-void constantFolding(std::shared_ptr<GraphView> graph);
+/**
+ * @brief Retrieve the parts of the graph that can be pre-computed and replace them with Producers.
+ *
+ * @param graph Graph in which to fold the constants
+ * @param constantShape If true, Shape operators are considered constant
+ * @return bool True if the graph has been modified
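+ *
+ * Usage sketch (illustrative only; assumes an already-built GraphView):
+ * @code
+ * std::shared_ptr<GraphView> graph = ...; // graph containing constant Producers
+ * bool modified = constantFolding(graph, /*constantShape=*/true);
+ * @endcode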
+ */
+bool constantFolding(std::shared_ptr<GraphView> graph, bool constantShape=false);
+
+/**
+ * @brief Retrieve the parts of the graph that can be pre-computed by treating Shape operators as constant, and replace them with Producers.
+ *
+ * @param graph Graph in which to fold the constants
+ * @param dims Optional input dimensions to propagate before folding
+ * @return bool True if the graph has been modified
+ */
+bool constantShapeFolding(std::shared_ptr<GraphView> graph, const std::vector<std::vector<DimSize_t>>& dims = {});
 
 // FUSE MATMUL + ADD -> FC
diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp
index 500367cb8f58bbfbf76394b0c83cd8ff848ff8cb..68ad81b8b57062cb863a75810e086900a4ce5a6d 100644
--- a/python_binding/recipes/pybind_Recipes.cpp
+++ b/python_binding/recipes/pybind_Recipes.cpp
@@ -25,6 +25,27 @@ namespace Aidge {
 void init_Recipes(py::module &m)
 {
+  m.def("constant_folding", static_cast<bool(*)(std::shared_ptr<GraphView>, bool)>(constantFolding), py::arg("graph_view"), py::arg("constant_shape") = false, R"mydelimiter(
+    Retrieve the parts of the graph that can be pre-computed and replace them with Producers.
+
+    :param graph_view: Graph view on which we want to apply the recipe
+    :type graph_view: :py:class:`aidge_core.GraphView`
+    :param constant_shape: If True, ``Shape`` operators are considered constant, default=False
+    :type constant_shape: bool, optional
+    :return: True if the graph has been modified
+    :rtype: bool
+    )mydelimiter");
+
+  m.def("constant_shape_folding", static_cast<bool(*)(std::shared_ptr<GraphView>, const std::vector<std::vector<DimSize_t>>&)>(constantShapeFolding), py::arg("graph_view"), py::arg("dims") = std::vector<std::vector<DimSize_t>>(), R"mydelimiter(
+    Retrieve the parts of the graph that can be pre-computed by treating ``Shape`` operators as constant, and replace them with Producers.
+
+    :param graph_view: Graph view on which we want to apply the recipe
+    :type graph_view: :py:class:`aidge_core.GraphView`
+    :param dims: Input dimensions to propagate before folding, default=[]
+    :type dims: list of list of int, optional
+    :return: True if the graph has been modified
+    :rtype: bool
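+
+    Usage sketch (illustrative; dims mirror the unit test)::
+
+        modified = aidge_core.constant_shape_folding(graph_view, dims=[[5, 12, 24, 24]])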
+    )mydelimiter");
 
   m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter(
     Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index b372fd3929dd3abca12956d00e2b2effcc49ce0e..e193a0af4ccbc0fffe5abdc1a9921226f493e625 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -28,10 +28,10 @@
 
 #include "aidge/data/Tensor.hpp"
 #include "aidge/operator/GenericOperator.hpp"
+#include "aidge/operator/Memorize.hpp"
 #include "aidge/operator/MetaOperator.hpp"
 #include "aidge/operator/OperatorTensor.hpp"
 #include "aidge/operator/Producer.hpp"
-#include "aidge/operator/Memorize.hpp"
 #include "aidge/utils/Directories.hpp"
 #include "aidge/utils/FileManagement.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
@@ -521,7 +521,7 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
     Log::debug("Initializing dimension propagation");
     // Establish initial list of dims forwardable nodes: graph input node + Producers childs
     std::set<std::shared_ptr<Node>> dimsForwarded; ///< List of nodes that are already dims forwarded
-    std::set<std::shared_ptr<Node>> listNodes = inputNodes();
+    std::set<std::shared_ptr<Node>> listNodes = inputNodes(); // list of nodes whose dims remain to be forwarded
     for (const auto& nodePtr : getNodes()) {
         if (nodePtr->type() == Producer_Op::Type) {
             // Producers are already dims forwarded!
@@ -534,13 +534,17 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
             }
         }
     }
-    do {
-        std::set<std::shared_ptr<Node>> nextList;
+    Log::debug("List of nodes to forward dimensions:");
+    for (const auto& node : listNodes) {
+        Log::debug("\t- Node {} (of type {})", node->name(), node->type());
+    }
+    std::set<std::shared_ptr<Node>> nextList; // future listNodes
     for (const auto& nodePtr : listNodes) {
+        Log::debug("Trying to forward dims of node {} (of type {})", nodePtr->name(), nodePtr->type());
+
         if (nodePtr->getOperator()->operatorType() == OperatorType::Tensor) {
             const auto op = std::static_pointer_cast<OperatorTensor>(nodePtr->getOperator());
-
             bool anyParent = false;
             bool parentsForwarded = true;
             for (const auto& parent : nodePtr->getParents()) {
@@ -577,12 +581,17 @@ bool Aidge::GraphView::forwardDims(const std::vector<std::vector<Aidge::DimSize_
             if (parentsForwarded) {
                 Log::debug("Unable to forward dimensions for node {} (of type {})", nodePtr->name(), nodePtr->type());
             }
+            Log::debug("Adding back node {} (of type {}) to the list of nodes to forward dimensions", nodePtr->name(), nodePtr->type());
             nextList.insert(nodePtr);
         }
     }
+    else {
+        AIDGE_THROW_OR_ABORT(std::runtime_error, "Cannot forward dims of node {} (of type {}) as it is not an OperatorTensor. forwardDims is currently only supported for OperatorTensor.", nodePtr->name(), nodePtr->type());
+    }
+    Log::debug("- - - - -");
   }
-  Log::debug("********************");
+  Log::debug("Finished processing current list of nodes");
 
   // Internal check to make sure we won't enter in an infinite loop!
   if (nextList == listNodes) {
@@ -818,7 +827,7 @@ bool Aidge::GraphView::add(std::set<std::shared_ptr<Node>> otherNodes, bool incl
     if (mNodeRegistry.find(node->name()) != mNodeRegistry.end()) {
       std::string newName = node->createUniqueName(node->name());
       while (mNodeRegistry.find(newName) != mNodeRegistry.end()) {
-        newName = node->createUniqueName(newName + "_1");
+        newName = node->createUniqueName(newName + "_1");
       }
       Log::notice("node name \"{}\" is a duplicate, renaming to {}.\n", node->name(), newName);
       node->setName(newName);
@@ -1119,12 +1128,12 @@ void Aidge::GraphView::insertParent(NodePtr childNode,
 
 /**
  * Inputs conditions:
- * | old \ new           | 1 node, 1 input  | >1 node, 1 input  | 1 node, >1 inputs  | >1 node, >1 inputs |
- * | ------------------- | ---------------- | ----------------- | ------------------ | ------------------ |
- * | 1 node, 1 input     | trivial          | trivial           | broadcast          | broadcast          |
- * | >1 node, 1 input    | trivial          | trivial           | broadcast          | broadcast          |
- * | 1 node, >1 inputs   | (take first)     | (take first)      | same order         | X                  |
- * | >1 node, >1 inputs  | X                | X                 | X                  | X                  |
+ * | old \ new           | 1 node, 1 input  | >1 node, 1 input  | 1 node, >1 inputs  | >1 node, >1 inputs | >=1 node, 0 inputs |
+ * | ------------------- | ---------------- | ----------------- | ------------------ | ------------------ | ------------------ |
+ * | 1 node, 1 input     | trivial          | trivial           | broadcast          | broadcast          | trivial            |
+ * | >1 node, 1 input    | trivial          | trivial           | broadcast          | broadcast          | trivial            |
+ * | 1 node, >1 inputs   | (take first)     | (take first)      | same order         | X                  | trivial            |
+ * | >1 node, >1 inputs  | X                | X                 | X                  | X                  | trivial            |
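+ *
+ * The new last column covers the constant-folding case: replacing a pre-computed
+ * subgraph (any inputs, one output) with a single constant Producer (no inputs,
+ * one output). A sketch, assuming GraphViews foldedGraph and prodGraph:
+ *   GraphView::replace(foldedGraph, prodGraph); // inputs dropped, output rewired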
 *
 * Outputs conditions:
 * | old \ new           | 1 node, 1 output | >1 node, 1 output | 1 node, >1 outputs | >1 node, >1 outputs |
@@ -1218,6 +1227,7 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
     for (const auto& nodePtr : newNodes) {
       nodePtr->removeView(newGraph);
     }
+    Log::warn("Discrepancy between the number of inputs/outputs of the graphs to replace.\n\t- OLD NB INPUTS: {} - NEW NB INPUTS {}\n\t- OLD NB OUTPUTS: {} - NEW NB OUTPUTS {}", oldOIn.size(), newOIn.size(), oldOOut.size(), newOOut.size());
     return false;
   }
@@ -1287,14 +1297,27 @@ bool Aidge::GraphView::replace(const std::shared_ptr<GraphView>& oldGraph, const
             newOOut[o].first -> addChild(child.first, newOOut[o].second, child.second);
           }
         }
-      }
-      else {
+      } else if ( // folding case
+        (newOIn.size() == 0) &&
+        ((oldOOut.size() == newOOut.size()) && (newOOut.size() == 1))
+      ) {
+        // Case 4
+        // Replace any set of nodes with a Producer
+        // No need to remove the old inputs: they are removed later on.
+        for (std::size_t o = 0; o < oldOOut.size(); ++o) {
+          for (const auto& child : outputChildren[o]) {
+            newOOut[o].first -> addChild(child.first, newOOut[o].second, child.second);
+          }
+        }
+      } else {
        for (const auto& nodePtr : oldNodes) {
          nodePtr->removeView(oldGraph);
        }
        for (const auto& nodePtr : newNodes) {
          nodePtr->removeView(newGraph);
        }
+       Log::warn("Could not replace: unsupported input/output configuration between old and new nodes.");
        return false;
      }
   }
diff --git a/src/recipes/ConstantFolding.cpp b/src/recipes/ConstantFolding.cpp
index 613393756f7f6b7104118ea97593e3130055ceeb..40cd30a7d908ef04fe1b02a55106211f6153fa38 100644
--- a/src/recipes/ConstantFolding.cpp
+++ b/src/recipes/ConstantFolding.cpp
@@ -17,39 +17,56 @@
 #include "aidge/graph/GraphView.hpp"
 #include "aidge/graph/Node.hpp"
 #include "aidge/operator/Producer.hpp"
+#include "aidge/operator/Shape.hpp"
 #include "aidge/recipes/Recipes.hpp"
 #include "aidge/utils/ErrorHandling.hpp"
 #include "aidge/utils/Types.h"
 
-void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
+bool Aidge::constantFolding(std::shared_ptr<GraphView> graph, bool constantShape) {
+    bool modified = false;
+    Log::info("Running constant folding on graph {}", graph->name());
     bool folded;
     do {
         folded = false;
         std::set<std::shared_ptr<Node>> candidates;
         for (const std::shared_ptr<Node>& nodePtr : graph->getNodes()) {
-            if (nodePtr->type() == Producer_Op::Type) {
+            if (nodePtr->type() == Producer_Op::Type || (constantShape && (nodePtr->type() == Shape_Op::Type))) {
                 const auto& childs = nodePtr->getChildren();
                 candidates.insert(childs.begin(), childs.end());
            }
        }
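+        // With constantShape=true, the children of Shape nodes (e.g. the Gather
+        // nodes of the unit test) are also folding candidates, even though
+        // Shape is not a constant Producer.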
         for (const auto& node : candidates) {
+            Log::debug("Checking if node {} (of type {}) is foldable", node->name(), node->type());
             bool foldable = true;
             auto replaceGraph = std::make_shared<GraphView>();
             size_t i = 0;
             for (const auto& input : node->inputs()) {
                 if (input.first) {
-                    if (input.first->type() != Producer_Op::Type) {
+                    if (!(input.first->type() == Producer_Op::Type || (constantShape && (input.first->type() == Shape_Op::Type)))) {
+                        Log::debug("Input {} of node {} (of type {}) not foldable, because {} (of type {}) is not constant.",
+                            i, node->name(), node->type(), input.first->name(), input.first->type());
                         foldable = false;
                         break;
                     }
-                    const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator());
-                    if (!producer->constant()) {
-                        Log::info("Node {} (of type {}) not foldable because Producer input {} not Constant",
-                            node->name(), node->type(), input.first->name());
-                        foldable = false;
-                        break;
+                    if (constantShape && (input.first->type() == Shape_Op::Type)) {
+                        if (!std::static_pointer_cast<OperatorTensor>(input.first->getOperator())->dimsForwarded()) {
+                            Log::debug("Node {} (of type {}) not foldable because dims of Shape input [{}] {} have not been forwarded",
+                                node->name(), node->type(), i, input.first->name());
+                            foldable = false;
+                            break;
+                        }
+                    }
+
+                    if (input.first->type() == Producer_Op::Type) {
+                        const auto& producer = std::static_pointer_cast<Producer_Op>(input.first->getOperator());
+                        if (!producer->constant()) {
+                            Log::debug("Node {} (of type {}) not foldable because Producer input {} is not constant",
+                                node->name(), node->type(), input.first->name());
+                            foldable = false;
+                            break;
+                        }
                     }
 
                     replaceGraph->add(input.first, false);
@@ -57,6 +74,8 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
                 else if (node->inputCategory(i) != InputCategory::OptionalData
                     && node->inputCategory(i) != InputCategory::OptionalParam)
                 {
+                    Log::debug("Input {} of node {} (of type {}) is mandatory but not set, cannot fold.",
+                        i, node->name(), node->type());
                     foldable = false;
                     break;
                 }
@@ -79,9 +98,17 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
                 // Add output in right order
                 prodGraph->add(newProd);
             }
-
+            Log::debug("Trying to replace:");
+            for (const auto& nodeToReplace : replaceGraph->getNodes()) {
+                Log::debug("\t- {} ({})", nodeToReplace->name(), nodeToReplace->type());
+            }
+            Log::debug("With:");
+            for (const auto& nodeReplacing : prodGraph->getNodes()) {
+                Log::debug("\t- {} ({})", nodeReplacing->name(), nodeReplacing->type());
+            }
             if (GraphView::replace(replaceGraph, prodGraph)) {
                 folded = true;
+                modified = true;
             } else {
                 Log::warn("Error with replace when folding node {} (of type {})",
@@ -91,4 +118,5 @@ void Aidge::constantFolding(std::shared_ptr<GraphView> graph) {
         }
     } while (folded);
+    return modified;
 }
diff --git a/src/recipes/ShapeFolding.cpp b/src/recipes/ShapeFolding.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..72d13ab52ca4a9bc84b7b41c6e2349577f207732
--- /dev/null
+++ b/src/recipes/ShapeFolding.cpp
@@ -0,0 +1,42 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+#include <memory>
+#include <set>
+#include <string>
+
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/operator/Shape.hpp"
+#include "aidge/recipes/Recipes.hpp"
+#include "aidge/utils/Log.hpp"
+
+bool Aidge::constantShapeFolding(std::shared_ptr<GraphView> graph, const std::vector<std::vector<DimSize_t>>& dims) {
+    bool modified = false;
+    bool forwarded = false;
+    bool wasModified = false;
+    // Skip the recipe entirely if the graph contains no Shape operator
+    bool noShapePresent = true;
+    for (const auto& nodePtr : graph->getNodes())
+        noShapePresent &= (nodePtr->type() != Shape_Op::Type);
+    if (noShapePresent)
+        return false;
+    // Alternate dims propagation and folding until a fixed point is reached
+    do {
+        forwarded = graph->forwardDims(dims, false);
+        modified = constantFolding(graph, true);
+        wasModified = wasModified || modified;
+    } while (modified);
+    if (!forwarded) {
+        Log::warn("Failed to forward dimensions on the GraphView.");
+    }
+
+    return wasModified;
+}
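
A short sketch of how the two recipes relate (illustrative; `graph` is assumed to be an already-built aidge_core.GraphView such as the one in the unit test):

    import aidge_core

    # Fold only subgraphs fed by constant Producers:
    aidge_core.constant_folding(graph)

    # Additionally treat Shape operators as constant; dims are forwarded
    # first so that Shape outputs are known:
    modified = aidge_core.constant_shape_folding(graph, [[5, 12, 24, 24]])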