Skip to content
Snippets Groups Projects
Commit b05a6d63 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add RemoveDropout recipe

parent f0acf0de
No related branches found
No related tags found
No related merge requests found
......@@ -20,6 +20,18 @@ class test_recipies(unittest.TestCase):
def tearDown(self):
pass
def test_remove_dropout(self):
graph_view = aidge_core.sequential([
aidge_core.GenericOperator("Conv", 1, 1, 1, "Conv0");
aidge_core.GenericOperator("Dropout", 1, 1, 1, name="Dropout0")
])
old_nodes = graph_view.get_nodes()
aidge_core.remove_dropout(graph_view)
self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 1)
self.assertTrue("Dropout0" not in [i.name for i in graph_view.get_nodes()])
self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()]))
def test_remove_flatten(self):
graph_view = aidge_core.sequential([
aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"),
......
......@@ -170,7 +170,7 @@ public:
inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; }
inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
static const std::vector<std::string> getInputsName(){
return {"data_input"}; //TODO fix input names cannot access mNbIn bacause of static type
return {"data_input_0", "data_input_n"};
}
static const std::vector<std::string> getOutputsName(){
return {"data_output"};
......
......@@ -42,6 +42,24 @@ void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add);
*/
void fuseMulAdd(std::shared_ptr<GraphView> graphView);
// REMOVE Dropout
/**
* @brief Remove ``Dropout`` Node.
*
* @param nodes Node to remove.
*/
void removeDropout(std::shared_ptr<Node> dropout);
void removeDropout(std::shared_ptr<MatchSolution> solution);
/**
* @brief Remove ``Dropout`` Node.
*
* @param graphView Graph view to use graph matching on, in order to apply transfomrations.
*/
void removeDropout(std::shared_ptr<GraphView> graphView);
// REMOVE FLATTEN + FC -> FC
......
......@@ -36,6 +36,13 @@ void init_Recipies(py::module &m) {
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
m.def("remove_dropout",static_cast<void(*)(std::shared_ptr<GraphView>)>(removeDropout), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a dropout operator.
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator.
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/utils/Recipies.hpp"
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
namespace Aidge {
void removeDropout(std::shared_ptr<Node> dropout) {
std::set<NodePtr> nodesToRemove;
for (auto nodePtr: dropout->getParents())
{
if(nodePtr->type() == "Producer")
{
nodesToRemove.insert(nodePtr);
}
}
nodesToRemove.insert(dropout);
GraphView::replace(nodesToRemove, {});
}
void removeDropout(std::shared_ptr<MatchSolution> solution){
assert(solution->at("Dropout").size() == 1 && "Wrong number of nodes Dropout to replace\n");
for (const auto& dropout : solution->at("Dropout")) {
removeDropout(dropout);
}
}
void removeDropout(std::shared_ptr<GraphView> graphView){
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("Dropout","getType($) =='Dropout'");
regex->addQuery("Dropout");
for (const auto& solution : regex->match(graphView)) {
removeDropout(solution);
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment