From b05a6d6324149825bb3cbba21339f177c26e5965 Mon Sep 17 00:00:00 2001 From: hrouis <houssemeddine.rouis92@gmail.com> Date: Thu, 23 Nov 2023 15:26:20 +0100 Subject: [PATCH] add RemoveDropout recipe --- aidge_core/unit_tests/test_recipies.py | 12 +++++ include/aidge/operator/Concat.hpp | 2 +- include/aidge/utils/Recipies.hpp | 18 +++++++ python_binding/recipies/pybind_Recipies.cpp | 7 +++ src/recipies/RemoveDropout.cpp | 56 +++++++++++++++++++++ 5 files changed, 94 insertions(+), 1 deletion(-) create mode 100644 src/recipies/RemoveDropout.cpp diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py index 754907443..353b51310 100644 --- a/aidge_core/unit_tests/test_recipies.py +++ b/aidge_core/unit_tests/test_recipies.py @@ -20,6 +20,18 @@ class test_recipies(unittest.TestCase): def tearDown(self): pass + def test_remove_dropout(self): + graph_view = aidge_core.sequential([ + aidge_core.GenericOperator("Conv", 1, 1, 1, "Conv0"); + aidge_core.GenericOperator("Dropout", 1, 1, 1, name="Dropout0") + ]) + old_nodes = graph_view.get_nodes() + aidge_core.remove_dropout(graph_view) + self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 1) + self.assertTrue("Dropout0" not in [i.name for i in graph_view.get_nodes()]) + + self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()])) + def test_remove_flatten(self): graph_view = aidge_core.sequential([ aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"), diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 7a090e2cd..13543c67b 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -170,7 +170,7 @@ public: inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } static const std::vector<std::string> getInputsName(){ - return {"data_input"}; //TODO fix input names cannot access mNbIn bacause of static type + return {"data_input_0", "data_input_n"}; } static const std::vector<std::string> getOutputsName(){ return {"data_output"}; diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp index bf6683d34..197e959d0 100644 --- a/include/aidge/utils/Recipies.hpp +++ b/include/aidge/utils/Recipies.hpp @@ -42,6 +42,24 @@ void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add); */ void fuseMulAdd(std::shared_ptr<GraphView> graphView); +// REMOVE Dropout + +/** + * @brief Remove ``Dropout`` Node. + * + * @param nodes Node to remove. + */ +void removeDropout(std::shared_ptr<Node> dropout); + + +void removeDropout(std::shared_ptr<MatchSolution> solution); + +/** + * @brief Remove ``Dropout`` Node. + * + * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + */ +void removeDropout(std::shared_ptr<GraphView> graphView); // REMOVE FLATTEN + FC -> FC diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index 87abf3207..0bc89e7d4 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -36,6 +36,13 @@ void init_Recipies(py::module &m) { // :type nodes: list of :py:class:`aidge_core.Node` // )mydelimiter"); + m.def("remove_dropout",static_cast<void(*)(std::shared_ptr<GraphView>)>(removeDropout), py::arg("graph_view"), R"mydelimiter( + Recipie to remove a dropout operator. + + :param graph_view: Graph view on which we want to apply the recipie + :type graph_view: :py:class:`aidge_core.GraphView` + )mydelimiter"); + m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( Recipie to remove a flatten operator. diff --git a/src/recipies/RemoveDropout.cpp b/src/recipies/RemoveDropout.cpp new file mode 100644 index 000000000..c1b3da4a5 --- /dev/null +++ b/src/recipies/RemoveDropout.cpp @@ -0,0 +1,56 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> + +#include "aidge/graph/Node.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/utils/Recipies.hpp" + +//Graph Regex +#include "aidge/graphRegex/GraphRegex.hpp" + + +namespace Aidge { + void removeDropout(std::shared_ptr<Node> dropout) { + + std::set<NodePtr> nodesToRemove; + for (auto nodePtr: dropout->getParents()) + { + if(nodePtr->type() == "Producer") + { + nodesToRemove.insert(nodePtr); + } + } + nodesToRemove.insert(dropout); + GraphView::replace(nodesToRemove, {}); + } + + void removeDropout(std::shared_ptr<MatchSolution> solution){ + + assert(solution->at("Dropout").size() == 1 && "Wrong number of nodes Dropout to replace\n"); + + for (const auto& dropout : solution->at("Dropout")) { + + removeDropout(dropout); + } + } + + void removeDropout(std::shared_ptr<GraphView> graphView){ + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); + regex->setNodeKey("Dropout","getType($) =='Dropout'"); + regex->addQuery("Dropout"); + + for (const auto& solution : regex->match(graphView)) { + removeDropout(solution); + } + } +} -- GitLab