Skip to content
Snippets Groups Projects
Commit 2900ebd5 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Added generic removeNode and removeIdentity recipes

parent dd5e7529
No related branches found
No related tags found
2 merge requests!152Update Aidge export to take a graph view has an argument instead of a...,!131New features to simplify exports
Pipeline #47345 failed
......@@ -44,24 +44,30 @@ void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add);
*/
void fuseMulAdd(std::shared_ptr<GraphView> graphView);
// REMOVE Dropout
/**
* @brief Remove ``Dropout`` Node.
* @brief Remove a node type.
*
* @param nodes Node to remove.
* @param graphView Graph view to use graph matching on, in order to apply transfomrations.
* @param type Type of the nodes to remove
* @param incProducers If true, also remove the producers attached to the removed nodes
* @return size_t Number of identity nodes removed
*/
void removeDropout(std::shared_ptr<Node> dropout);
void removeDropout(std::shared_ptr<MatchSolution> solution);
size_t removeNode(std::shared_ptr<GraphView> graphView, const std::string& type, bool incProducers = false);
/**
* @brief Remove ``Dropout`` Node.
*
* @param graphView Graph view to use graph matching on, in order to apply transfomrations.
* @return size_t Number of identity nodes removed
*/
void removeDropout(std::shared_ptr<GraphView> graphView);
size_t removeDropout(std::shared_ptr<GraphView> graphView);
/**
* Remove all identity nodes
* @param graph Graph to manipulate
* @return size_t Number of identity nodes removed
*/
size_t removeIdentity(std::shared_ptr<GraphView> graph);
// REMOVE FLATTEN + FC -> FC
......
......@@ -15,42 +15,37 @@
#include "aidge/graph/GraphView.hpp"
#include "aidge/recipes/Recipes.hpp"
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
namespace Aidge {
void removeDropout(std::shared_ptr<Node> dropout) {
std::set<NodePtr> nodesToRemove;
for (auto nodePtr: dropout->getParents())
{
if(nodePtr->type() == "Producer")
{
nodesToRemove.insert(nodePtr);
size_t Aidge::removeNode(std::shared_ptr<GraphView> graphView, const std::string& type, bool incProducers) {
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey(type, "getType($) =='" + type + "'");
regex->addQuery(type + "#");
const auto matches = regex->match(graphView);
for (const auto& solution : matches) {
assert(solution->at(type).size() == 1 && "Wrong number of nodes to replace\n");
std::set<NodePtr> nodesToRemove = solution->at(type);
if (incProducers) {
for (const auto& nodePtr: (*solution->at(type).begin())->getParents()) {
if (nodePtr->type() == "Producer") {
nodesToRemove.insert(nodePtr);
}
}
}
nodesToRemove.insert(dropout);
GraphView::replace(nodesToRemove, {});
}
void removeDropout(std::shared_ptr<MatchSolution> solution){
assert(solution->at("Dropout").size() == 1 && "Wrong number of nodes Dropout to replace\n");
for (const auto& dropout : solution->at("Dropout")) {
removeDropout(dropout);
}
}
return matches.size();
}
void removeDropout(std::shared_ptr<GraphView> graphView){
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("Dropout","getType($) =='Dropout'");
regex->addQuery("Dropout#");
size_t Aidge::removeDropout(std::shared_ptr<GraphView> graphView) {
return removeNode(graphView, "Dropout", true);
}
for (const auto& solution : regex->match(graphView)) {
removeDropout(solution);
}
}
size_t Aidge::removeIdentity(std::shared_ptr<GraphView> graphView) {
return removeNode(graphView, "Identity");
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment