Skip to content
Snippets Groups Projects

Adding Rounding attribute to the bitshift operator

Merged Noam Zerah requested to merge noamzerah/aidge_core:bitshift_rounding into dev
Files
2
@@ -23,7 +23,7 @@
@@ -23,7 +23,7 @@
#include "aidge/utils/Types.h"
#include "aidge/utils/Types.h"
#include "aidge/utils/StaticAttributes.hpp"
#include "aidge/utils/StaticAttributes.hpp"
#define LIST_BITSHIFT_ATTR(X) X(BitShiftdirection, "bit_shift_direction", BitShiftDirection)
#define LIST_BITSHIFT_ATTR(X) X(BitShiftdirection, "bit_shift_direction", BitShiftDirection), X(Rounding,"rounding",bool)
namespace Aidge {
namespace Aidge {
enum class BitShiftAttr {
enum class BitShiftAttr {
@@ -87,10 +87,12 @@ public:
@@ -87,10 +87,12 @@ public:
* @brief Constructor to initialize the `BitShift_Op` with a shift direction.
* @brief Constructor to initialize the `BitShift_Op` with a shift direction.
* @param[in] direction The direction of the bitwise shift (left or right).
* @param[in] direction The direction of the bitwise shift (left or right).
*/
*/
BitShift_Op(BitShiftDirection direction)
BitShift_Op(BitShiftDirection direction, bool rounding = false)
: OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1),
: OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1),
mAttributes(std::make_shared<Attributes_>(
mAttributes(std::make_shared<Attributes_>(
attr<BitShiftAttr::BitShiftdirection>(direction))) {}
attr<BitShiftAttr::BitShiftdirection>(direction),
 
attr<BitShiftAttr::Rounding>(rounding)))
 
{}
/**
/**
* @brief Copy-constructor. Copies operator attributes and output tensors but not input tensors.
* @brief Copy-constructor. Copies operator attributes and output tensors but not input tensors.
@@ -143,6 +145,13 @@ public:
@@ -143,6 +145,13 @@ public:
inline BitShiftDirection& direction() const noexcept {
inline BitShiftDirection& direction() const noexcept {
return mAttributes->template getAttr<BitShiftAttr::BitShiftdirection>();
return mAttributes->template getAttr<BitShiftAttr::BitShiftdirection>();
}
}
 
/**
 
* @brief Retrieve the rounding flag.
 
* @return A boolean (True: Apply bitshift rounding).
 
*/
 
inline bool rounding() const noexcept {
 
return mAttributes->template getAttr<BitShiftAttr::Rounding>();
 
}
/**
/**
* @brief Get the names of the input tensors.
* @brief Get the names of the input tensors.
@@ -172,11 +181,12 @@ public:
@@ -172,11 +181,12 @@ public:
/**
/**
* @brief Factory function to create a `BitShift` node.
* @brief Factory function to create a `BitShift` node.
* @param[in] direction The direction of the bitwise shift (`left` or `right`).
* @param[in] direction The direction of the bitwise shift (`left` or `right`).
 
* @param[in] rounding Apply rounding
* @param[in] name (Optional) Name of the node.
* @param[in] name (Optional) Name of the node.
* @return A shared pointer to the created node.
* @return A shared pointer to the created node.
*/
*/
inline std::shared_ptr<Node> BitShift(const BitShift_Op::BitShiftDirection direction, const std::string& name = "") {
inline std::shared_ptr<Node> BitShift(const BitShift_Op::BitShiftDirection direction,bool rounding = false, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<BitShift_Op>(direction), name);
return std::make_shared<Node>(std::make_shared<BitShift_Op>(direction,rounding), name);
}
}
} // namespace Aidge
} // namespace Aidge
Loading