From e45c2c7e167d432e7729f111223624bfd8be761a Mon Sep 17 00:00:00 2001
From: Charles Villard <charles.villard@cea.fr>
Date: Thu, 23 Jan 2025 16:24:39 +0100
Subject: [PATCH] WIP: filter input with bitwise category

---
 include/aidge/graph/GraphView.hpp    |  2 +-
 include/aidge/operator/Operator.hpp  | 18 ++++--
 include/aidge/utils/BitwiseUtils.hpp | 87 ++++++++++++++++++++++++++++
 src/graph/GraphView.cpp              |  4 +-
 4 files changed, 103 insertions(+), 8 deletions(-)
 create mode 100644 include/aidge/utils/BitwiseUtils.hpp

diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp
index 76f5dcdfc..ebb36b49a 100644
--- a/include/aidge/graph/GraphView.hpp
+++ b/include/aidge/graph/GraphView.hpp
@@ -136,7 +136,7 @@ public:
 ///////////////////////////////////////////////////////
 public:
     /** @brief Get reference to the set of input Nodes. */
-    std::set<NodePtr> inputNodes() const;
+    std::set<NodePtr> inputNodes(InputCategory filter = InputCategory::All) const;
 
     /** @brief Get reference to the set of output Nodes. */
     std::set<NodePtr> outputNodes() const;
diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp
index 95698b751..6b1f366b6 100644
--- a/include/aidge/operator/Operator.hpp
+++ b/include/aidge/operator/Operator.hpp
@@ -16,6 +16,7 @@
 #include <string>
 #include <vector>
 #include <utility>
+#include <type_traits>
 #include <cstddef>
 
 #ifdef PYBIND
@@ -28,6 +29,7 @@
 #include "aidge/data/Data.hpp"
 #include "aidge/utils/Attributes.hpp"
 #include "aidge/utils/Types.h"
+#include "aidge/utils/BitwiseUtils.hpp"
 
 
 #ifdef PYBIND
@@ -40,13 +42,18 @@ enum class OperatorType {
     Tensor
 };
 
-enum class InputCategory {
-    Data,
-    Param,
-    OptionalData,
-    OptionalParam
+enum class InputCategory : unsigned int {
+    Optional = 1 << 0, // First bit indicate if optional
+    Data = 1 << 1, //Second bit is a Data
+    OptionalData = (1 << 1) | Optional,
+    Param = 1 << 2, // Third bit is for param
+    OptionalParam = (1 << 2) | Optional,
+    All = static_cast<unsigned int>(-1)
 };
 
+template <>
+constexpr bool enable_bitmask_operators<InputCategory> = true;
+
 class Operator : public std::enable_shared_from_this<Operator> {
 protected:
     std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator
@@ -243,6 +250,7 @@ public:
     }
 #endif
 };
+
 } // namespace Aidge
 
 #endif /* AIDGE_CORE_OPERATOR_OPERATOR_H_ */
diff --git a/include/aidge/utils/BitwiseUtils.hpp b/include/aidge/utils/BitwiseUtils.hpp
new file mode 100644
index 000000000..208115fa1
--- /dev/null
+++ b/include/aidge/utils/BitwiseUtils.hpp
@@ -0,0 +1,87 @@
+/********************************************************************************
+ * Copyright (c) 2024 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+#ifndef AIDGE_BITWISEUTILS_H_
+#define AIDGE_BITWISEUTILS_H_
+
+#include <type_traits>
+
+namespace Aidge
+{
+
+// Define a trait to enable bitwise operators for an enum class
+template <typename E>
+constexpr bool enable_bitmask_operators = false;
+
+// Define bitwise OR
+template <typename E>
+constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E>
+operator|(E lhs, E rhs) {
+    using underlying = typename std::underlying_type_t<E>;
+    return static_cast<E>(static_cast<underlying>(lhs) | static_cast<underlying>(rhs));
+}
+
+// Define bitwise AND
+template <typename E>
+constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E>
+operator&(E lhs, E rhs) {
+    using underlying = typename std::underlying_type_t<E>;
+    return static_cast<E>(static_cast<underlying>(lhs) & static_cast<underlying>(rhs));
+}
+
+// Define bitwise XOR
+template <typename E>
+constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E>
+operator^(E lhs, E rhs) {
+    using underlying = typename std::underlying_type_t<E>;
+    return static_cast<E>(static_cast<underlying>(lhs) ^ static_cast<underlying>(rhs));
+}
+
+// Define bitwise NOT
+template <typename E>
+constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E>
+operator~(E value) {
+    using underlying = typename std::underlying_type_t<E>;
+    return static_cast<E>(~static_cast<underlying>(value));
+}
+
+// Define compound OR assignment
+template <typename E>
+constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E&>
+operator|=(E& lhs, E rhs) {
+    lhs = lhs | rhs;
+    return lhs;
+}
+
+// Define compound AND assignment
+template <typename E>
+constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E&>
+operator&=(E& lhs, E rhs) {
+    lhs = lhs & rhs;
+    return lhs;
+}
+
+// Define compound XOR assignment
+template <typename E>
+constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E&>
+operator^=(E& lhs, E rhs) {
+    lhs = lhs ^ rhs;
+    return lhs;
+}
+
+template <typename E>
+constexpr auto to_underlying(E e) noexcept
+{
+    return static_cast<std::underlying_type_t<E>>(e);
+}
+
+} // namespace Aidge
+
+#endif //
diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp
index 465359757..9bd142d14 100644
--- a/src/graph/GraphView.cpp
+++ b/src/graph/GraphView.cpp
@@ -275,11 +275,11 @@ void Aidge::GraphView::setRootNode(NodePtr node) {
 //        TENSOR MANAGEMENT
 ///////////////////////////////////////////////////////
 
-std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::inputNodes() const {
+std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::inputNodes(InputCategory filter) const {
     std::set<std::shared_ptr<Aidge::Node>> nodes;
     for (const auto& node : mInputNodes) {
         // Do not include dummy inputs
-        if (node.first) {
+        if (node.first && to_underlying(node.first->inputCategory(node.second) & filter)) {
             nodes.insert(node.first);
         }
     }
-- 
GitLab