Skip to content
Snippets Groups Projects
Commit b05a6d63 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add RemoveDropout recipe

parent f0acf0de
No related branches found
No related tags found
2 merge requests!59Improvements and fixes,!47Vit operators
......@@ -20,6 +20,18 @@ class test_recipies(unittest.TestCase):
def tearDown(self):
pass
def test_remove_dropout(self):
graph_view = aidge_core.sequential([
aidge_core.GenericOperator("Conv", 1, 1, 1, "Conv0");
aidge_core.GenericOperator("Dropout", 1, 1, 1, name="Dropout0")
])
old_nodes = graph_view.get_nodes()
aidge_core.remove_dropout(graph_view)
self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 1)
self.assertTrue("Dropout0" not in [i.name for i in graph_view.get_nodes()])
self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()]))
def test_remove_flatten(self):
graph_view = aidge_core.sequential([
aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"),
......
......@@ -170,7 +170,7 @@ public:
inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"}; //TODO fix input names cannot access mNbIn bacause of static type
return {"data_input_0", "data_input_n"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
......
......@@ -42,6 +42,24 @@ void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add);
*/
void fuseMulAdd(std::shared_ptr<GraphView> graphView);
// REMOVE Dropout
/**
* @brief Remove ``Dropout`` Node.
*
* @param nodes Node to remove.
*/
void removeDropout(std::shared_ptr<Node> dropout);
void removeDropout(std::shared_ptr<MatchSolution> solution);
/**
* @brief Remove ``Dropout`` Node.
*
* @param graphView Graph view to use graph matching on, in order to apply transfomrations.
*/
void removeDropout(std::shared_ptr<GraphView> graphView);
// REMOVE FLATTEN + FC -> FC
......
......@@ -36,6 +36,13 @@ void init_Recipies(py::module &m) {
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
m.def("remove_dropout",static_cast<void(*)(std::shared_ptr<GraphView>)>(removeDropout), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a dropout operator.
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator.
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/utils/Recipies.hpp"
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
namespace Aidge {
void removeDropout(std::shared_ptr<Node> dropout) {
std::set<NodePtr> nodesToRemove;
for (auto nodePtr: dropout->getParents())
{
if(nodePtr->type() == "Producer")
{
nodesToRemove.insert(nodePtr);
}
}
nodesToRemove.insert(dropout);
GraphView::replace(nodesToRemove, {});
}
void removeDropout(std::shared_ptr<MatchSolution> solution){
assert(solution->at("Dropout").size() == 1 && "Wrong number of nodes Dropout to replace\n");
for (const auto& dropout : solution->at("Dropout")) {
removeDropout(dropout);
}
}
void removeDropout(std::shared_ptr<GraphView> graphView){
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("Dropout","getType($) =='Dropout'");
regex->addQuery("Dropout");
for (const auto& solution : regex->match(graphView)) {
removeDropout(solution);
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment