Forked from
Eclipse Projects / aidge / aidge_core
2183 commits behind the upstream repository.
-
Maxence Naud authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
Recipies.hpp 2.23 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#ifndef AIDGE_CORE_UTILS_RECIPIES_H_
#define AIDGE_CORE_UTILS_RECIPIES_H_
#include <memory>
#include <set>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
namespace Aidge {

// FUSE MATMUL + ADD -> FC
/**
 * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Nodes into a single
 * :cpp:function:`Aidge::FC` Node.
 *
 * @param nodes Strict set of Nodes to merge (the exact Nodes forming the
 *              MatMul + Add pattern).
 */
void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);

/**
 * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Nodes into a single
 * :cpp:function:`Aidge::FC` Node.
 *
 * @param graphView Graph view on which graph matching is run, in order to
 *                  find and apply the transformations.
 */
void fuseMulAdd(std::shared_ptr<GraphView> graphView);

// REMOVE FLATTEN + FC -> FC
/**
 * @brief Remove a ``Flatten`` Node placed before a :cpp:function:`Aidge::FC` Node.
 *
 * @param nodes Strict set of Nodes to merge (the exact Nodes forming the
 *              Flatten + FC pattern).
 */
void removeFlatten(std::set<std::shared_ptr<Node>> nodes);

/**
 * @brief Remove a ``Flatten`` Node placed before a :cpp:function:`Aidge::FC` Node.
 *
 * @param graphView Graph view on which graph matching is run, in order to
 *                  find and apply the transformations.
 */
void removeFlatten(std::shared_ptr<GraphView> graphView);

// FUSE BN + FC || CONV -> FC || CONV
/**
 * @brief Fuse a :cpp:function:`Aidge::BatchNorm` Node into the preceding
 * :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Node.
 * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
 *
 * @param nodes Strict set of Nodes to merge (the exact Nodes forming the
 *              Conv/FC + BatchNorm pattern).
 */
void fuseBatchNorm(std::set<std::shared_ptr<Node>> nodes);

/**
 * @brief Fuse a :cpp:function:`Aidge::BatchNorm` Node into the preceding
 * :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Node.
 * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
 *
 * @param graphView Graph view on which graph matching is run, in order to
 *                  find and apply the transformations.
 */
void fuseBatchNorm(std::shared_ptr<GraphView> graphView);

} // namespace Aidge
#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */