Skip to content
Snippets Groups Projects
Commit b05a6d63 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add RemoveDropout recipe

parent f0acf0de
No related branches found
No related tags found
2 merge requests!59Improvements and fixes,!47Vit operators
...@@ -20,6 +20,18 @@ class test_recipies(unittest.TestCase): ...@@ -20,6 +20,18 @@ class test_recipies(unittest.TestCase):
def tearDown(self): def tearDown(self):
pass pass
def test_remove_dropout(self):
graph_view = aidge_core.sequential([
aidge_core.GenericOperator("Conv", 1, 1, 1, "Conv0");
aidge_core.GenericOperator("Dropout", 1, 1, 1, name="Dropout0")
])
old_nodes = graph_view.get_nodes()
aidge_core.remove_dropout(graph_view)
self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 1)
self.assertTrue("Dropout0" not in [i.name for i in graph_view.get_nodes()])
self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()]))
def test_remove_flatten(self): def test_remove_flatten(self):
graph_view = aidge_core.sequential([ graph_view = aidge_core.sequential([
aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"), aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"),
......
...@@ -170,7 +170,7 @@ public: ...@@ -170,7 +170,7 @@ public:
inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; } inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){ static const std::vector<std::string> getInputsName(){
return {"data_input"}; //TODO fix input names cannot access mNbIn bacause of static type return {"data_input_0", "data_input_n"};
} }
static const std::vector<std::string> getOutputsName(){ static const std::vector<std::string> getOutputsName(){
return {"data_output"}; return {"data_output"};
......
...@@ -42,6 +42,24 @@ void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add); ...@@ -42,6 +42,24 @@ void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add);
*/ */
void fuseMulAdd(std::shared_ptr<GraphView> graphView); void fuseMulAdd(std::shared_ptr<GraphView> graphView);
// REMOVE Dropout
/**
* @brief Remove ``Dropout`` Node.
*
* @param nodes Node to remove.
*/
void removeDropout(std::shared_ptr<Node> dropout);
void removeDropout(std::shared_ptr<MatchSolution> solution);
/**
* @brief Remove ``Dropout`` Node.
*
* @param graphView Graph view to use graph matching on, in order to apply transfomrations.
*/
void removeDropout(std::shared_ptr<GraphView> graphView);
// REMOVE FLATTEN + FC -> FC // REMOVE FLATTEN + FC -> FC
......
...@@ -36,6 +36,13 @@ void init_Recipies(py::module &m) { ...@@ -36,6 +36,13 @@ void init_Recipies(py::module &m) {
// :type nodes: list of :py:class:`aidge_core.Node` // :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter"); // )mydelimiter");
m.def("remove_dropout",static_cast<void(*)(std::shared_ptr<GraphView>)>(removeDropout), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a dropout operator.
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator. Recipie to remove a flatten operator.
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/utils/Recipies.hpp"
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
namespace Aidge {
void removeDropout(std::shared_ptr<Node> dropout) {
std::set<NodePtr> nodesToRemove;
for (auto nodePtr: dropout->getParents())
{
if(nodePtr->type() == "Producer")
{
nodesToRemove.insert(nodePtr);
}
}
nodesToRemove.insert(dropout);
GraphView::replace(nodesToRemove, {});
}
void removeDropout(std::shared_ptr<MatchSolution> solution){
assert(solution->at("Dropout").size() == 1 && "Wrong number of nodes Dropout to replace\n");
for (const auto& dropout : solution->at("Dropout")) {
removeDropout(dropout);
}
}
void removeDropout(std::shared_ptr<GraphView> graphView){
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("Dropout","getType($) =='Dropout'");
regex->addQuery("Dropout");
for (const auto& solution : regex->match(graphView)) {
removeDropout(solution);
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment