From b05a6d6324149825bb3cbba21339f177c26e5965 Mon Sep 17 00:00:00 2001
From: hrouis <houssemeddine.rouis92@gmail.com>
Date: Thu, 23 Nov 2023 15:26:20 +0100
Subject: [PATCH] add RemoveDropout recipe

---
 aidge_core/unit_tests/test_recipies.py      | 12 +++++
 include/aidge/operator/Concat.hpp           |  2 +-
 include/aidge/utils/Recipies.hpp            | 18 +++++++
 python_binding/recipies/pybind_Recipies.cpp |  7 +++
 src/recipies/RemoveDropout.cpp              | 56 +++++++++++++++++++++
 5 files changed, 94 insertions(+), 1 deletion(-)
 create mode 100644 src/recipies/RemoveDropout.cpp

diff --git a/aidge_core/unit_tests/test_recipies.py b/aidge_core/unit_tests/test_recipies.py
index 754907443..353b51310 100644
--- a/aidge_core/unit_tests/test_recipies.py
+++ b/aidge_core/unit_tests/test_recipies.py
@@ -20,6 +20,18 @@ class test_recipies(unittest.TestCase):
     def tearDown(self):
         pass
 
+    def test_remove_dropout(self):
+        graph_view = aidge_core.sequential([
+            aidge_core.GenericOperator("Conv", 1, 1, 1, "Conv0");
+            aidge_core.GenericOperator("Dropout", 1, 1, 1, name="Dropout0")
+        ])
+        old_nodes = graph_view.get_nodes()
+        aidge_core.remove_dropout(graph_view)
+        self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 1)
+        self.assertTrue("Dropout0" not in [i.name for i in graph_view.get_nodes()])
+
+        self.assertTrue(all([i in old_nodes for i in graph_view.get_nodes()]))
+
     def test_remove_flatten(self):
         graph_view = aidge_core.sequential([
             aidge_core.GenericOperator("Flatten", 1, 1, 1, name="Flatten0"),
diff --git a/include/aidge/operator/Concat.hpp b/include/aidge/operator/Concat.hpp
index 7a090e2cd..13543c67b 100644
--- a/include/aidge/operator/Concat.hpp
+++ b/include/aidge/operator/Concat.hpp
@@ -170,7 +170,7 @@ public:
     inline IOIndex_t nbDataInputs() const noexcept override final { return mNbIn; }
     inline IOIndex_t nbOutputs() const noexcept override final { return 1; }
     static const std::vector<std::string> getInputsName(){
-        return {"data_input"}; //TODO fix input names cannot access mNbIn bacause of static type
+        return {"data_input_0", "data_input_n"};
     }
     static const std::vector<std::string> getOutputsName(){
         return {"data_output"};
diff --git a/include/aidge/utils/Recipies.hpp b/include/aidge/utils/Recipies.hpp
index bf6683d34..197e959d0 100644
--- a/include/aidge/utils/Recipies.hpp
+++ b/include/aidge/utils/Recipies.hpp
@@ -42,6 +42,24 @@ void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add);
  */
 void fuseMulAdd(std::shared_ptr<GraphView> graphView);
 
+// REMOVE Dropout
+
+/**
+ * @brief Remove ``Dropout`` Node.
+ *
+ * @param nodes Node to remove.
+ */
+void removeDropout(std::shared_ptr<Node> dropout);
+
+
+void removeDropout(std::shared_ptr<MatchSolution> solution);
+
+/**
+ * @brief Remove ``Dropout`` Node.
+ *
+ * @param graphView Graph view to use graph matching on, in order to apply transfomrations.
+ */
+void removeDropout(std::shared_ptr<GraphView> graphView);
 
 // REMOVE FLATTEN + FC -> FC
 
diff --git a/python_binding/recipies/pybind_Recipies.cpp b/python_binding/recipies/pybind_Recipies.cpp
index 87abf3207..0bc89e7d4 100644
--- a/python_binding/recipies/pybind_Recipies.cpp
+++ b/python_binding/recipies/pybind_Recipies.cpp
@@ -36,6 +36,13 @@ void init_Recipies(py::module &m) {
   //   :type nodes: list of :py:class:`aidge_core.Node`
   //   )mydelimiter");
 
+  m.def("remove_dropout",static_cast<void(*)(std::shared_ptr<GraphView>)>(removeDropout), py::arg("graph_view"), R"mydelimiter(
+    Recipie to remove a dropout operator.
+
+    :param graph_view: Graph view on which we want to apply the recipie
+    :type graph_view: :py:class:`aidge_core.GraphView`
+    )mydelimiter");
+
   m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
     Recipie to remove a flatten operator.
 
diff --git a/src/recipies/RemoveDropout.cpp b/src/recipies/RemoveDropout.cpp
new file mode 100644
index 000000000..c1b3da4a5
--- /dev/null
+++ b/src/recipies/RemoveDropout.cpp
@@ -0,0 +1,56 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+
+#include <memory>
+
+#include "aidge/graph/Node.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/utils/Recipies.hpp"
+
+//Graph Regex
+#include "aidge/graphRegex/GraphRegex.hpp"
+
+
+namespace Aidge {
+    void removeDropout(std::shared_ptr<Node> dropout) {
+
+        std::set<NodePtr> nodesToRemove;
+        for (auto nodePtr: dropout->getParents())
+        {
+            if(nodePtr->type() == "Producer")
+            {
+                nodesToRemove.insert(nodePtr);
+            }
+        }
+        nodesToRemove.insert(dropout);
+        GraphView::replace(nodesToRemove, {});
+    }
+
+    void removeDropout(std::shared_ptr<MatchSolution> solution){
+
+        assert(solution->at("Dropout").size() == 1 && "Wrong number of nodes Dropout to replace\n");
+
+        for (const auto& dropout : solution->at("Dropout")) {
+
+            removeDropout(dropout);
+        }
+    }
+
+    void removeDropout(std::shared_ptr<GraphView> graphView){
+        std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
+        regex->setNodeKey("Dropout","getType($) =='Dropout'");
+        regex->addQuery("Dropout");
+
+        for (const auto& solution : regex->match(graphView)) {
+            removeDropout(solution);
+        }
+    }
+}
-- 
GitLab