From e45c2c7e167d432e7729f111223624bfd8be761a Mon Sep 17 00:00:00 2001 From: Charles Villard <charles.villard@cea.fr> Date: Thu, 23 Jan 2025 16:24:39 +0100 Subject: [PATCH] WIP: filter input with bitwise category --- include/aidge/graph/GraphView.hpp | 2 +- include/aidge/operator/Operator.hpp | 18 ++++-- include/aidge/utils/BitwiseUtils.hpp | 87 ++++++++++++++++++++++++++++ src/graph/GraphView.cpp | 4 +- 4 files changed, 103 insertions(+), 8 deletions(-) create mode 100644 include/aidge/utils/BitwiseUtils.hpp diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 76f5dcdfc..ebb36b49a 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -136,7 +136,7 @@ public: /////////////////////////////////////////////////////// public: /** @brief Get reference to the set of input Nodes. */ - std::set<NodePtr> inputNodes() const; + std::set<NodePtr> inputNodes(InputCategory filter = InputCategory::All) const; /** @brief Get reference to the set of output Nodes. */ std::set<NodePtr> outputNodes() const; diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 95698b751..6b1f366b6 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -16,6 +16,7 @@ #include <string> #include <vector> #include <utility> +#include <type_traits> #include <cstddef> #ifdef PYBIND @@ -28,6 +29,7 @@ #include "aidge/data/Data.hpp" #include "aidge/utils/Attributes.hpp" #include "aidge/utils/Types.h" +#include "aidge/utils/BitwiseUtils.hpp" #ifdef PYBIND @@ -40,13 +42,18 @@ enum class OperatorType { Tensor }; -enum class InputCategory { - Data, - Param, - OptionalData, - OptionalParam +enum class InputCategory : unsigned int { + Optional = 1 << 0, // First bit indicate if optional + Data = 1 << 1, //Second bit is a Data + OptionalData = (1 << 1) | Optional, + Param = 1 << 2, // Third bit is for param + OptionalParam = (1 << 2) | Optional, + All = static_cast<unsigned int>(-1) }; +template <> +constexpr bool enable_bitmask_operators<InputCategory> = true; + class Operator : public std::enable_shared_from_this<Operator> { protected: std::shared_ptr<OperatorImpl> mImpl; // implementation of the operator @@ -243,6 +250,7 @@ public: } #endif }; + } // namespace Aidge #endif /* AIDGE_CORE_OPERATOR_OPERATOR_H_ */ diff --git a/include/aidge/utils/BitwiseUtils.hpp b/include/aidge/utils/BitwiseUtils.hpp new file mode 100644 index 000000000..208115fa1 --- /dev/null +++ b/include/aidge/utils/BitwiseUtils.hpp @@ -0,0 +1,87 @@ +/******************************************************************************** + * Copyright (c) 2024 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#ifndef AIDGE_BITWISEUTILS_H_ +#define AIDGE_BITWISEUTILS_H_ + +#include <type_traits> + +namespace Aidge +{ + +// Define a trait to enable bitwise operators for an enum class +template <typename E> +constexpr bool enable_bitmask_operators = false; + +// Define bitwise OR +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E> +operator|(E lhs, E rhs) { + using underlying = typename std::underlying_type_t<E>; + return static_cast<E>(static_cast<underlying>(lhs) | static_cast<underlying>(rhs)); +} + +// Define bitwise AND +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E> +operator&(E lhs, E rhs) { + using underlying = typename std::underlying_type_t<E>; + return static_cast<E>(static_cast<underlying>(lhs) & static_cast<underlying>(rhs)); +} + +// Define bitwise XOR +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E> +operator^(E lhs, E rhs) { + using underlying = typename std::underlying_type_t<E>; + return static_cast<E>(static_cast<underlying>(lhs) ^ static_cast<underlying>(rhs)); +} + +// Define bitwise NOT +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E> +operator~(E value) { + using underlying = typename std::underlying_type_t<E>; + return static_cast<E>(~static_cast<underlying>(value)); +} + +// Define compound OR assignment +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E&> +operator|=(E& lhs, E rhs) { + lhs = lhs | rhs; + return lhs; +} + +// Define compound AND assignment +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E&> +operator&=(E& lhs, E rhs) { + lhs = lhs & rhs; + return lhs; +} + +// Define compound XOR assignment +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E&> +operator^=(E& lhs, E rhs) { + lhs = lhs ^ rhs; + return lhs; +} + +template <typename E> +constexpr auto to_underlying(E e) noexcept +{ + return static_cast<std::underlying_type_t<E>>(e); +} + +} // namespace Aidge + +#endif // diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index 465359757..9bd142d14 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -275,11 +275,11 @@ void Aidge::GraphView::setRootNode(NodePtr node) { // TENSOR MANAGEMENT /////////////////////////////////////////////////////// -std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::inputNodes() const { +std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::inputNodes(InputCategory filter) const { std::set<std::shared_ptr<Aidge::Node>> nodes; for (const auto& node : mInputNodes) { // Do not include dummy inputs - if (node.first) { + if (node.first && to_underlying(node.first->inputCategory(node.second) & filter)) { nodes.insert(node.first); } } -- GitLab