diff --git a/include/aidge/graph/GraphView.hpp b/include/aidge/graph/GraphView.hpp index 081c429e869a7897d9a24b4633f87a7f6efd68e3..507d34f157b4b73d3bc96748d5afb6a20217e395 100644 --- a/include/aidge/graph/GraphView.hpp +++ b/include/aidge/graph/GraphView.hpp @@ -136,7 +136,7 @@ public: /////////////////////////////////////////////////////// public: /** @brief Get reference to the set of input Nodes. */ - std::set<NodePtr> inputNodes() const; + std::set<NodePtr> inputNodes(InputCategory filter = InputCategory::All) const; /** @brief Get reference to the set of output Nodes. */ std::set<NodePtr> outputNodes() const; diff --git a/include/aidge/graph/Node.hpp b/include/aidge/graph/Node.hpp index 7319db245442e1d37fb3489ab289794ec9091f47..c06cf0c8b791394423cd96926db448fd1938ed7e 100644 --- a/include/aidge/graph/Node.hpp +++ b/include/aidge/graph/Node.hpp @@ -283,13 +283,13 @@ public: inline IOIndex_t getFirstFreeDataInput() const { IOIndex_t i = 0; for (; i < nbInputs(); ++i) { - if ((inputCategory(i) == InputCategory::Data || inputCategory(i) == InputCategory::OptionalData) + if (to_underlying(inputCategory(i) & InputCategory::Data) && input(i).second == gk_IODefaultIndex) { - break; + return i; } } - return (i < nbInputs()) ? i : gk_IODefaultIndex; + return gk_IODefaultIndex; } /** diff --git a/include/aidge/operator/Operator.hpp b/include/aidge/operator/Operator.hpp index 0392dcfa23af7510b6db96743267fbe56eb18a63..b53c47c18176f1050cf2c2e317de8d49fb2a319e 100644 --- a/include/aidge/operator/Operator.hpp +++ b/include/aidge/operator/Operator.hpp @@ -17,6 +17,8 @@ #include <string> #include <vector> #include <utility> +#include <type_traits> +#include <cstddef> #ifdef PYBIND #include <pybind11/pybind11.h> @@ -27,6 +29,7 @@ #include "aidge/data/Data.hpp" #include "aidge/utils/Attributes.hpp" #include "aidge/utils/Types.h" +#include "aidge/utils/BitwiseUtils.hpp" #ifdef PYBIND namespace py = pybind11; @@ -56,13 +59,25 @@ enum class OperatorType { * @enum InputCategory * @brief Describes the category of an input for an operator. */ -enum class InputCategory { - Data, /**< Regular data input. */ - Param, /**< Parameter input. */ - OptionalData, /**< Optional data input. */ - OptionalParam /**< Optional parameter input. */ +enum class InputCategory : unsigned int { + Optional = 1 << 0, // First bit indicate if optional + Data = 1 << 1, /**< Regular data input. */ + OptionalData = (1 << 1) | Optional, /**< Optional data input. */ + Param = 1 << 2, /**< Parameter input. */ + OptionalParam = (1 << 2) | Optional, /**< Optional parameter input. */ + All = static_cast<unsigned int>(-1) }; + +#ifdef _MSC_VER +template <> +constexpr bool enable_bitmask_operators<InputCategory> = true; +#else +template <> +inline constexpr bool enable_bitmask_operators<InputCategory> = true; +#endif + + /** * @class Operator * @brief Base class for all operator types in the Aidge framework. diff --git a/include/aidge/utils/BitwiseUtils.hpp b/include/aidge/utils/BitwiseUtils.hpp new file mode 100644 index 0000000000000000000000000000000000000000..208115fa14fc4ed1e91c7d54e7a7341cd733d5e1 --- /dev/null +++ b/include/aidge/utils/BitwiseUtils.hpp @@ -0,0 +1,87 @@ +/******************************************************************************** + * Copyright (c) 2024 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ +#ifndef AIDGE_BITWISEUTILS_H_ +#define AIDGE_BITWISEUTILS_H_ + +#include <type_traits> + +namespace Aidge +{ + +// Define a trait to enable bitwise operators for an enum class +template <typename E> +constexpr bool enable_bitmask_operators = false; + +// Define bitwise OR +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E> +operator|(E lhs, E rhs) { + using underlying = typename std::underlying_type_t<E>; + return static_cast<E>(static_cast<underlying>(lhs) | static_cast<underlying>(rhs)); +} + +// Define bitwise AND +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E> +operator&(E lhs, E rhs) { + using underlying = typename std::underlying_type_t<E>; + return static_cast<E>(static_cast<underlying>(lhs) & static_cast<underlying>(rhs)); +} + +// Define bitwise XOR +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E> +operator^(E lhs, E rhs) { + using underlying = typename std::underlying_type_t<E>; + return static_cast<E>(static_cast<underlying>(lhs) ^ static_cast<underlying>(rhs)); +} + +// Define bitwise NOT +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E> +operator~(E value) { + using underlying = typename std::underlying_type_t<E>; + return static_cast<E>(~static_cast<underlying>(value)); +} + +// Define compound OR assignment +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E&> +operator|=(E& lhs, E rhs) { + lhs = lhs | rhs; + return lhs; +} + +// Define compound AND assignment +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E&> +operator&=(E& lhs, E rhs) { + lhs = lhs & rhs; + return lhs; +} + +// Define compound XOR assignment +template <typename E> +constexpr typename std::enable_if_t<enable_bitmask_operators<E>, E&> +operator^=(E& lhs, E rhs) { + lhs = lhs ^ rhs; + return lhs; +} + +template <typename E> +constexpr auto to_underlying(E e) noexcept +{ + return static_cast<std::underlying_type_t<E>>(e); +} + +} // namespace Aidge + +#endif // diff --git a/python_binding/graph/pybind_GraphView.cpp b/python_binding/graph/pybind_GraphView.cpp index d1b99c305d0e067a74c13a33cde062b2c6f2ddfa..4abb07fac3eee886c1db787ad4bdbee7a4f5bde7 100644 --- a/python_binding/graph/pybind_GraphView.cpp +++ b/python_binding/graph/pybind_GraphView.cpp @@ -50,6 +50,7 @@ void init_GraphView(py::module& m) { )mydelimiter") .def("get_input_nodes", &GraphView::inputNodes, + py::arg("filter") = Aidge::InputCategory::All, R"mydelimiter( Get set of input Nodes. diff --git a/python_binding/operator/pybind_Operator.cpp b/python_binding/operator/pybind_Operator.cpp index 18708adee1f9644af7d6824cfb7896e786da2108..94b41f691051101871626b907f2a8c81a49c5f13 100644 --- a/python_binding/operator/pybind_Operator.cpp +++ b/python_binding/operator/pybind_Operator.cpp @@ -28,11 +28,13 @@ void init_Operator(py::module& m){ .value("Data", OperatorType::Data) .value("Tensor", OperatorType::Tensor); - py::enum_<InputCategory>(m, "InputCategory") - .value("Data", InputCategory::Data) - .value("Param", InputCategory::Param) - .value("OptionalData", InputCategory::OptionalData) - .value("OptionalParam", InputCategory::OptionalParam); + py::enum_<Aidge::InputCategory>(m, "InputCategory", py::arithmetic()) + .value("Optional", Aidge::InputCategory::Optional) + .value("Data", Aidge::InputCategory::Data) + .value("OptionalData", Aidge::InputCategory::OptionalData) + .value("Param", Aidge::InputCategory::Param) + .value("OptionalParam", Aidge::InputCategory::OptionalParam) + .value("All", Aidge::InputCategory::All); py::class_<Operator, std::shared_ptr<Operator>>(m, "Operator") .def("__repr__", &Operator::repr) diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 61a0a271c6dd23f30065f31a711d0383395f5d9d..ea738cceaff22201aff58c782d099b12b06cc5a4 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -135,12 +135,6 @@ void init_Aidge(py::module& m) { init_Attributes(m); init_Spikegen(m); - init_Node(m); - init_GraphView(m); - init_OpArgs(m); - init_Connector(m); - init_SinglePassGraphMatching(m); - init_OperatorImpl(m); init_Log(m); init_Operator(m); @@ -149,6 +143,12 @@ void init_Aidge(py::module& m) { init_StaticAnalysis(m); init_DynamicAnalysis(m); + init_Node(m); + init_GraphView(m); + init_OpArgs(m); + init_Connector(m); + init_SinglePassGraphMatching(m); + init_Abs(m); init_Add(m); init_And(m); diff --git a/src/graph/GraphView.cpp b/src/graph/GraphView.cpp index c45bacf46958779102c9a09efe958d5de2012ea6..9d6557054f2f584e4dc4137ea7fdf544aba8fff3 100644 --- a/src/graph/GraphView.cpp +++ b/src/graph/GraphView.cpp @@ -299,11 +299,11 @@ void Aidge::GraphView::setRootNode(NodePtr node) { // TENSOR MANAGEMENT /////////////////////////////////////////////////////// -std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::inputNodes() const { +std::set<std::shared_ptr<Aidge::Node>> Aidge::GraphView::inputNodes(InputCategory filter) const { std::set<std::shared_ptr<Aidge::Node>> nodes; for (const auto& node : mInputNodes) { // Do not include dummy inputs - if (node.first) { + if (node.first && to_underlying(node.first->inputCategory(node.second) & filter)) { nodes.insert(node.first); } } diff --git a/src/graph/Node.cpp b/src/graph/Node.cpp index 74e0cab37489c275512f5ba53290bdb5eac065e0..19eaa6d14cbe9def21d29e2cbcd2f5f47adbc628 100644 --- a/src/graph/Node.cpp +++ b/src/graph/Node.cpp @@ -147,7 +147,9 @@ bool Aidge::Node::valid() const { Aidge::IOIndex_t Aidge::Node::getNbFreeDataInputs() const { IOIndex_t nbFreeDataIn = 0; for (IOIndex_t i = 0; i < nbInputs(); ++i) { - if (input(i).second == gk_IODefaultIndex) { + if ((inputCategory(i) == InputCategory::Data + || inputCategory(i) == InputCategory::OptionalData) + && input(i).second == gk_IODefaultIndex) { ++nbFreeDataIn; } } @@ -389,10 +391,17 @@ void Aidge::Node::addChild(const std::shared_ptr<Node>& otherNode, const IOIndex void Aidge::Node::addChild(std::shared_ptr<GraphView> otherView, const IOIndex_t outId, std::pair<std::shared_ptr<Node>, IOIndex_t> otherInId) { if (!otherInId.first) { - AIDGE_ASSERT(otherView->inputNodes().size() == 1U, + AIDGE_ASSERT(otherView->getNbFreeDataInputs() == 1U, "Input node of GraphView {} need to be specified, because it has more than one input ({} inputs), when trying to add it as a child of node {} (of type {})", otherView->name(), otherView->inputNodes().size(), name(), type()); - otherInId.first = *(otherView->inputNodes().begin()); + + otherInId.first = *(std::find_if( + otherView->inputNodes().begin(), + otherView->inputNodes().end(), + [](const auto &node) { + return node->getFirstFreeDataInput() != gk_IODefaultIndex; + } + )); } otherInId.second = (otherInId.second != gk_IODefaultIndex) ? otherInId.second