Skip to content
Snippets Groups Projects
Recipes.hpp 4.04 KiB
Newer Older
Cyril Moineau's avatar
Cyril Moineau committed
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

Olivier BICHLER's avatar
Olivier BICHLER committed
#ifndef AIDGE_CORE_UTILS_RECIPES_H_
#define AIDGE_CORE_UTILS_RECIPES_H_
Cyril Moineau's avatar
Cyril Moineau committed

#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graphRegex/matchFsm/MatchResult.hpp"

Cyril Moineau's avatar
Cyril Moineau committed

namespace Aidge {
void constantFolding(std::shared_ptr<GraphView> graph);

// FUSE MATMUL + ADD -> FC

/**
 * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node.
 *
 * @param nodes Strict set of Node to merge.
 */
//void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);

void fuseMulAdd(std::shared_ptr<MatchSolution> solution);

void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add);

/**
 * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node.
 *
 * @param graphView Graph view to use graph matching on, in order to apply transformations.
void fuseMulAdd(std::shared_ptr<GraphView> graphView);
Cyril Moineau's avatar
Cyril Moineau committed

// REMOVE Dropout

/**
 * @brief Remove ``Dropout`` Node.
 *
 * @param nodes Node to remove.
 */
void removeDropout(std::shared_ptr<Node> dropout);


void removeDropout(std::shared_ptr<MatchSolution> solution);

/**
 * @brief Remove ``Dropout`` Node.
 *
 * @param graphView Graph view to use graph matching on, in order to apply transfomrations.
 */
void removeDropout(std::shared_ptr<GraphView> graphView);

// REMOVE FLATTEN + FC -> FC

/**
 * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
 *
 * @param nodes Strict set of Node to merge.
 */
void removeFlatten(std::shared_ptr<Node> flatten);


void removeFlatten(std::shared_ptr<MatchSolution> solution);

/**
 * @brief Remove ``Flatten`` before :cpp:function:`Aidge::FC` Node.
 *
 * @param graphView Graph view to use graph matching on, in order to apply transformations.
void removeFlatten(std::shared_ptr<GraphView> graphView);
// FUSE BN + FC || CONV -> FC || CONV

/**
 * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
 * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
 *
 * @param nodes Strict set of Node to merge.
 */
void fuseBatchNorm(std::shared_ptr<Node> conv,std::shared_ptr<Node> batchnorm);



void fuseBatchNorm(std::shared_ptr<MatchSolution> solution);

/**
 * @brief Fuse :cpp:function:`Aidge::BatchNorm` with :cpp:function:`Aidge::Conv` or :cpp:function:`Aidge::FC` Nodes.
 * Ref: https://nenadmarkus.com/p/fusing-batchnorm-and-conv/
 *
 * @param graphView Graph view to use graph matching on, in order to apply transformations.
 */
void fuseBatchNorm(std::shared_ptr<GraphView> graphView);
Cyril Moineau's avatar
Cyril Moineau committed

std::set<std::shared_ptr<Node>> getConvHorizontalTiling(const std::shared_ptr<Node>& node, const DimIdx_t axis, const std::size_t nbSlices);
// void horizontalTiling(std::shared_ptr<Node> node, DimIdx_t dim, std::size_t nbSlices);
// std::set<std::shared_ptr<Node>> getHorizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);
// void horizontalTiling(std::set<std::shared_ptr<Node>> setOfNodes, DimIdx_t dim, std::size_t nbSlices);

/**
 * Add Convert operators where needed to ensure no conversion needs to be done
 * at the Operator level.
*/
void explicitCastMove(std::shared_ptr<GraphView> graphView);
/**
 * Flatten the graph by replacing the meta operators by their micro graph.
 * @param recursive If true, recursively replace meta operators until there is
 * no more meta operator in the graph.
*/
void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false);
Olivier BICHLER's avatar
Olivier BICHLER committed
void fuseToMetaOps(std::shared_ptr<GraphView> graph, const std::string& query, const std::string& name = "");

Maxence Naud's avatar
Maxence Naud committed
} // namespace Aidge
Cyril Moineau's avatar
Cyril Moineau committed

Olivier BICHLER's avatar
Olivier BICHLER committed
#endif /* AIDGE_CORE_UTILS_RECIPES_H_ */