/******************************************************************************** * Copyright (c) 2023 CEA-List * * This program and the accompanying materials are made available under the * terms of the Eclipse Public License 2.0 which is available at * http://www.eclipse.org/legal/epl-2.0. * * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ #ifndef AIDGE_CORE_UTILS_RECIPIES_H_ #define AIDGE_CORE_UTILS_RECIPIES_H_ #include <memory> #include <set> #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" namespace Aidge{ // FUSE MATMUL + ADD -> FC /** * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. * * @param nodes Strict set of Node to merge. */ void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); /** * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. * * @param graphView Graph view to use graph matching on, in order to apply transfomrations. */ void fuseMulAdd(std::shared_ptr<GraphView> graphView); // REMOVE FLATTEN + FC -> FC /** * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * * @param nodes Strict set of Node to merge. */ void removeFlatten(std::set<std::shared_ptr<Node>> nodes); /** * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node. * * @param graphView Graph view to use graph matching on, in order to apply transfomrations. */ void removeFlatten(std::shared_ptr<GraphView> graphView); // FUSE BN + FC || CONV -> FC || CONV /** * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ * * @param nodes Strict set of Node to merge. */ void fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes); /** * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes. * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/ * * @param graphView Graph view to use graph matching on, in order to apply transfomrations. */ void fuseBatchNorm(std::shared_ptr<GraphView> graphView); } #endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */