diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp index c42b285dacb6c59c5fa30388c268f1680152a5e0..aea39ded3e5f2547f6f47fbc5aa27d5f1ee4821f 100644 --- a/include/aidge/recipes/Recipes.hpp +++ b/include/aidge/recipes/Recipes.hpp @@ -50,6 +50,13 @@ void matMulToFC(std::shared_ptr<GraphView> graphView); */ size_t removeNode(std::shared_ptr<GraphView> graphView, const std::string& type, bool incProducers = false); +/** + * @brief Fuses constant => Generic | constantOfShape and transforms it into a Producer + * @param graph Graph to manipulate + * @return size_t Number of replacement + */ +size_t removeConstantOfShape(std::shared_ptr<GraphView> graph_view); + /** * @brief Remove ``Dropout`` Node. * diff --git a/python_binding/recipes/pybind_Recipes.cpp b/python_binding/recipes/pybind_Recipes.cpp index b68dfd035921a1dce4d12b9071a8df194e2ffdd5..a23b54e6f02f832fbc70482329966445f723b573 100644 --- a/python_binding/recipes/pybind_Recipes.cpp +++ b/python_binding/recipes/pybind_Recipes.cpp @@ -15,6 +15,7 @@ #include <cstddef> #include <string> +#include "aidge/graph/GraphView.hpp" #include "aidge/recipes/Recipes.hpp" #include "aidge/utils/Types.h" @@ -78,12 +79,12 @@ void init_Recipes(py::module &m) :type graph_view: :py:class:`aidge_core.GraphView` )mydelimiter"); - // m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter( - // Recipe to remove a flatten operator. + m.def("remove_constantOfShape", static_cast<size_t(*)(std::shared_ptr<GraphView>)>(removeConstantOfShape), py::arg("graph_view"), R"mydelimiter( + Fuses constant => Generic | constantOfShape and transforms it into a Producer - // :param nodes: The flatten operator to remove. - // :type nodes: list of :py:class:`aidge_core.Node` - // )mydelimiter"); + :param graph_view: Graph view on which we want to apply the recipe. + :type graph_view: :py:class:`aidge_core.GraphView` + )mydelimiter"); m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter( Recipe to remove a flatten operator. diff --git a/src/recipes/removeConstantOfShape.cpp b/src/recipes/removeConstantOfShape.cpp new file mode 100644 index 0000000000000000000000000000000000000000..5e84f7b494815ecb5a8937bb6f76ba1de80ad3f9 --- /dev/null +++ b/src/recipes/removeConstantOfShape.cpp @@ -0,0 +1,128 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#include "aidge/recipes/Recipes.hpp" + +#include <algorithm> +#include <cassert> +#include <cstddef> +#include <cstdint> +#include <cstdio> +#include <functional> +#include <memory> +#include <numeric> +#include <set> +#include <stdexcept> +#include <string> + +#include "aidge/data/Data.hpp" +#include "aidge/data/Tensor.hpp" +#include "aidge/filler/Filler.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Matching.hpp" +#include "aidge/graph/Node.hpp" +#include "aidge/operator/ConstantOfShape.hpp" +#include "aidge/operator/GenericOperator.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/utils/ErrorHandling.hpp" +#include "aidge/utils/Types.h" + +// Graph Regex +#include "aidge/graphRegex/GraphRegex.hpp" + +namespace Aidge { + +size_t removeConstantOfShape(std::shared_ptr<GraphView> graph_view) { + const auto matches = + SinglePassGraphMatching(graph_view).match("Producer->ConstantOfShape"); + + size_t nbReplaced = 0; + for (const auto &match : matches) { + const auto prod_node = match.graph->rootNode(); + const auto prod_op = + std::static_pointer_cast<Producer_Op>(prod_node->getOperator()); + + const NodePtr constantofshape_node = + prod_node->getOrderedChildren().at(0).at(0); + + const auto constantofshape_op = + std::static_pointer_cast<ConstantOfShape_Op>( + constantofshape_node->getOperator()); + + if (prod_op->getOutput(0)->nbDims() != 1) { + Log::debug("{} : Producer output dimension number is {} != 1 and {} " + "input has to have 1 dim, skipping match.", + __func__, prod_op->getOutput(0)->nbDims(), + ConstantOfShape_Op::Type); + continue; + } + if (!prod_op->constant()) { + Log::debug("{} : Producer is not constant, skipping match.", __func__); + continue; + } + if (prod_op->getOutput(0)->dataType() != DataType::Int64) { + AIDGE_THROW_OR_ABORT( + std::runtime_error, + "{} : Producer output dtype is {} != int64 and {} " + "input type is restricted to int64_t, this is an error." + "Fix your network. skipping match.", + __func__, prod_op->getOutput(0)->dataType(), + ConstantOfShape_Op::Type); + continue; + } + + auto graph_to_replace = std::make_shared<GraphView>(); + auto new_graph = std::make_shared<GraphView>(); + graph_to_replace->add(constantofshape_node); + if (prod_node->getChildren().size() == 1) { + graph_to_replace->add(prod_node); + } else { + Log::debug("{} : Producer node has multiple children, only" + "replacing the {} node.", + __func__, ConstantOfShape_Op::Type); + } + + prod_node->forward(); + std::shared_ptr<Tensor> prod_output = prod_op->getOutput(0); + std::vector<DimSize_t> new_input_dims; + new_input_dims.reserve(prod_output->dims()[0]); + for (DimSize_t i = 0; i < prod_output->size(); ++i) { + new_input_dims.push_back(prod_output->get<int64_t>(i)); + } + + auto new_input = std::make_shared<Tensor>(new_input_dims); + new_input->setBackend(prod_op->backend() == "" ? "cpu" + : prod_op->backend()); + new_input->setDataType(constantofshape_op->value().dataType()); + for (std::size_t i = 0; i < new_input->size(); ++i) { + new_input->getImpl()->copy( + constantofshape_op->value().getImpl()->rawPtr(), 1, i); + } + auto new_prod = + Producer(new_input, prod_node->name() + "_constant_of_shape", true); + new_graph->add(new_prod); + + const auto success = GraphView::replace(graph_to_replace, new_graph); + if (!success) { + Log::warn("Could not replace Producer({})->ConstantOfShape({}) with" + "Producer", + prod_node->name(), constantofshape_node->name()); + } else { + ++nbReplaced; + } + } + + Log::info("Replaced {} (out of {}) matching Producer->ConstantOfShape with " + "Producers", + nbReplaced, matches.size()); + return nbReplaced; +} +} // namespace Aidge + diff --git a/unit_tests/recipes/Test_removeConstantOfShape.cpp b/unit_tests/recipes/Test_removeConstantOfShape.cpp new file mode 100644 index 0000000000000000000000000000000000000000..247149a0fdb1087f14ac17d125659d677ccfb506 --- /dev/null +++ b/unit_tests/recipes/Test_removeConstantOfShape.cpp @@ -0,0 +1,50 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include "aidge/graph/GraphView.hpp" +#include "aidge/operator/Identity.hpp" +#include "aidge/recipes/Recipes.hpp" + +#include <cstddef> +#include <cstdint> +#include <memory> +#include <vector> + +#include <catch2/catch_test_macros.hpp> + +#include "aidge/graph/OpArgs.hpp" +#include "aidge/operator/Add.hpp" +#include "aidge/operator/ConstantOfShape.hpp" +#include "aidge/operator/Conv.hpp" +#include "aidge/operator/MatMul.hpp" +#include "aidge/operator/Producer.hpp" +#include "aidge/operator/ReLU.hpp" +#include "aidge/utils/ArrayHelpers.hpp" +#include "aidge/utils/Types.h" + +using namespace Aidge; + +TEST_CASE("[cpu/recipies] removeConstantOfShape", + "[ConstantOfShape][removeConstantOfShape][recipies]") { + auto input_T = std::make_shared<Tensor>(Array1D<int64_t, 4>({1, 1, 3, 3})); + + auto model = std::make_shared<GraphView>(); + SECTION("Sequential model") { + model = Sequential({Producer(input_T, "prod_0", true), + ConstantOfShape(3, "constantOfShape_0"), + Conv(1, 1, {3, 3}, "Conv_0"), ReLU("ReLU_1")}); + model->save("test_removeConstantOfShape_model_before_1"); + CHECK(removeConstantOfShape(model) == 1); + CHECK(model->forwardDims()); + model->save("test_removeConstantOfShape_model_after_1"); + } +} +