Skip to content
Snippets Groups Projects
Commit a5622da9 authored by Noam Zerah's avatar Noam Zerah Committed by Olivier BICHLER
Browse files

Updating with new attributes Macro

parent 0e89cde0
No related branches found
No related tags found
1 merge request!353Adding Rounding attribute to the bitshift operator
......@@ -23,7 +23,7 @@
#include "aidge/utils/Types.h"
#include "aidge/utils/StaticAttributes.hpp"
#define LIST_BITSHIFT_ATTR(X) X(BitShiftdirection, "bit_shift_direction", BitShiftDirection)
#define LIST_BITSHIFT_ATTR(X) X(BitShiftdirection, "bit_shift_direction", BitShiftDirection), X(Rounding,"rounding",bool)
namespace Aidge {
enum class BitShiftAttr {
......@@ -87,10 +87,12 @@ public:
* @brief Constructor to initialize the `BitShift_Op` with a shift direction.
* @param[in] direction The direction of the bitwise shift (left or right).
*/
BitShift_Op(BitShiftDirection direction)
BitShift_Op(BitShiftDirection direction, bool rounding = false)
: OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<BitShiftAttr::BitShiftdirection>(direction))) {}
attr<BitShiftAttr::BitShiftdirection>(direction),
attr<BitShiftAttr::Rounding>(rounding)))
{}
/**
* @brief Copy-constructor. Copies operator attributes and output tensors but not input tensors.
......@@ -143,6 +145,13 @@ public:
inline BitShiftDirection& direction() const noexcept {
return mAttributes->template getAttr<BitShiftAttr::BitShiftdirection>();
}
/**
* @brief Retrieve the rounding flag.
* @return A boolean (True: Apply bitshift rounding).
*/
inline bool rounding() const noexcept {
return mAttributes->template getAttr<BitShiftAttr::Rounding>();
}
/**
* @brief Get the names of the input tensors.
......@@ -172,11 +181,12 @@ public:
/**
* @brief Factory function to create a `BitShift` node.
* @param[in] direction The direction of the bitwise shift (`left` or `right`).
* @param[in] rounding Apply rounding
* @param[in] name (Optional) Name of the node.
* @return A shared pointer to the created node.
*/
inline std::shared_ptr<Node> BitShift(const BitShift_Op::BitShiftDirection direction, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<BitShift_Op>(direction), name);
inline std::shared_ptr<Node> BitShift(const BitShift_Op::BitShiftDirection direction,bool rounding = false, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<BitShift_Op>(direction,rounding), name);
}
} // namespace Aidge
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment