diff --git a/include/aidge/operator/BitShift.hpp b/include/aidge/operator/BitShift.hpp index c54d6a99fc2f945a02396af446d356004e94efc1..afe0745869e828ed9004c0ff3856f5ffef5c23dc 100644 --- a/include/aidge/operator/BitShift.hpp +++ b/include/aidge/operator/BitShift.hpp @@ -23,7 +23,7 @@ #include "aidge/utils/Types.h" #include "aidge/utils/StaticAttributes.hpp" -#define LIST_BITSHIFT_ATTR(X) X(BitShiftdirection, "bit_shift_direction", BitShiftDirection) +#define LIST_BITSHIFT_ATTR(X) X(BitShiftdirection, "bit_shift_direction", BitShiftDirection), X(Rounding,"rounding",bool) namespace Aidge { enum class BitShiftAttr { @@ -87,10 +87,12 @@ public: * @brief Constructor to initialize the `BitShift_Op` with a shift direction. * @param[in] direction The direction of the bitwise shift (left or right). */ - BitShift_Op(BitShiftDirection direction) + BitShift_Op(BitShiftDirection direction, bool rounding = false) : OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1), mAttributes(std::make_shared<Attributes_>( - attr<BitShiftAttr::BitShiftdirection>(direction))) {} + attr<BitShiftAttr::BitShiftdirection>(direction), + attr<BitShiftAttr::Rounding>(rounding))) + {} /** * @brief Copy-constructor. Copies operator attributes and output tensors but not input tensors. @@ -143,6 +145,13 @@ public: inline BitShiftDirection& direction() const noexcept { return mAttributes->template getAttr<BitShiftAttr::BitShiftdirection>(); } + /** + * @brief Retrieve the rounding flag. + * @return A boolean (True: Apply bitshift rounding). + */ + inline bool rounding() const noexcept { + return mAttributes->template getAttr<BitShiftAttr::Rounding>(); + } /** * @brief Get the names of the input tensors. @@ -172,11 +181,12 @@ public: /** * @brief Factory function to create a `BitShift` node. * @param[in] direction The direction of the bitwise shift (`left` or `right`). + * @param[in] rounding Apply rounding * @param[in] name (Optional) Name of the node. * @return A shared pointer to the created node. */ -inline std::shared_ptr<Node> BitShift(const BitShift_Op::BitShiftDirection direction, const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<BitShift_Op>(direction), name); +inline std::shared_ptr<Node> BitShift(const BitShift_Op::BitShiftDirection direction,bool rounding = false, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<BitShift_Op>(direction,rounding), name); } } // namespace Aidge