diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py index 754907443530f7e73d1e10ed9549d0c8eb78a011..353b5131038bc11bd5279300fe7e4da8bb3f5664 100644 --- a/aidge_core/unit_tests/test_recipies.py +++ b/aidge_core/unit_tests/test_recipies.py @@ -20,6 +20,18 @@ class test_recipies(unittest.TestCase): def tearDown(self): pass + def test_remove_dropout(self): + graph_view = aidge_core.sequential([ + aidge_core.GenericOperator("Conv", 1, 1, 1, "Conv0"); + aidge_core.GenericOperator("Dropout", 1, 1, 1, name="Dropout0") + ]) + old_nodes = graph_view.get_nodes() + aidge_core.remove_dropout(graph_view) + self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 1) + self.assertTrue("Dropout0" not in [i.name for i in graph_view.get_nodes()]) + + self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()])) + def test_remove_flatten(self): graph_view = aidge_core.sequential([ aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"), diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp index 7a090e2cd084d552de0fe71e62b1f2e1b23a0d4a..13543c67b1d7b632692961786ef4e951d7758100 100644 --- a/include/aidge/operator/Concat.hpp +++ b/include/aidge/operator/Concat.hpp @@ -170,7 +170,7 @@ public: inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; } inline IOIndex_t nbOutputs() const noexcept override final { return 1; } static const std::vector<std::string> getInputsName(){ - return {"data_input"}; //TODO fix input names cannot access mNbIn bacause of static type + return {"data_input_0", "data_input_n"}; } static const std::vector<std::string> getOutputsName(){ return {"data_output"}; diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp index bf6683d342f2b140aef41459bd6633340de3e93d..197e959d01156b840e5a86489c056deb06a37d4d 100644 --- a/include/aidge/utils/Recipies.hpp +++ b/include/aidge/utils/Recipies.hpp @@ -42,6 +42,24 @@ void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add); */ void fuseMulAdd(std::shared_ptr<GraphView> graphView); +// REMOVE Dropout + +/** + * @brief Remove ``Dropout`` Node. + * + * @param nodes Node to remove. + */ +void removeDropout(std::shared_ptr<Node> dropout); + + +void removeDropout(std::shared_ptr<MatchSolution> solution); + +/** + * @brief Remove ``Dropout`` Node. + * + * @param graphView Graph view to use graph matching on, in order to apply transfomrations. + */ +void removeDropout(std::shared_ptr<GraphView> graphView); // REMOVE FLATTEN + FC -> FC diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp index 87abf32073734b37803e4330d56888388c63b9af..0bc89e7d428181dac0fe45e935f59433cca70b89 100644 --- a/python_binding/recipies/pybind_Recipies.cpp +++ b/python_binding/recipies/pybind_Recipies.cpp @@ -36,6 +36,13 @@ void init_Recipies(py::module &m) { // :type nodes: list of :py:class:`aidge_core.Node` // )mydelimiter"); + m.def("remove_dropout",static_cast<void(*)(std::shared_ptr<GraphView>)>(removeDropout), py::arg("graph_view"), R"mydelimiter( + Recipie to remove a dropout operator. + + :param graph_view: Graph view on which we want to apply the recipie + :type graph_view: :py:class:`aidge_core.GraphView` + )mydelimiter"); + m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter( Recipie to remove a flatten operator. diff --git a/src/recipies/RemoveDropout.cpp b/src/recipies/RemoveDropout.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c1b3da4a54601a73a2d29deb7aceec8f893040e0 --- /dev/null +++ b/src/recipies/RemoveDropout.cpp @@ -0,0 +1,56 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <memory> + +#include "aidge/graph/Node.hpp" +#include "aidge/graph/GraphView.hpp" +#include "aidge/utils/Recipies.hpp" + +//Graph Regex +#include "aidge/graphRegex/GraphRegex.hpp" + + +namespace Aidge { + void removeDropout(std::shared_ptr<Node> dropout) { + + std::set<NodePtr> nodesToRemove; + for (auto nodePtr: dropout->getParents()) + { + if(nodePtr->type() == "Producer") + { + nodesToRemove.insert(nodePtr); + } + } + nodesToRemove.insert(dropout); + GraphView::replace(nodesToRemove, {}); + } + + void removeDropout(std::shared_ptr<MatchSolution> solution){ + + assert(solution->at("Dropout").size() == 1 && "Wrong number of nodes Dropout to replace\n"); + + for (const auto& dropout : solution->at("Dropout")) { + + removeDropout(dropout); + } + } + + void removeDropout(std::shared_ptr<GraphView> graphView){ + std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); + regex->setNodeKey("Dropout","getType($) =='Dropout'"); + regex->addQuery("Dropout"); + + for (const auto& solution : regex->match(graphView)) { + removeDropout(solution); + } + } +}