Skip to content
Snippets Groups Projects

Adding Rounding attribute to the bitshift operator

Merged Noam Zerah requested to merge noamzerah/aidge_core:bitshift_rounding into dev
All threads resolved!
@@ -27,10 +27,14 @@ namespace Aidge {
enum class BitShiftAttr {
/**
*
*/
BitShiftdirection
/*
Direction of the Bitshift [Right/Left]
*/
BitShiftdirection,
/*
Apply BitShift Rounding
*/
Rounding
};
}
namespace {
@@ -38,7 +42,7 @@ namespace {
* @brief Specialization of `EnumStrings` for `BitShiftAttr`.
*/
template <>
const char* const EnumStrings<Aidge::BitShiftAttr>::data[] = {"bit_shift_direction"};
const char* const EnumStrings<Aidge::BitShiftAttr>::data[] = {"bit_shift_direction","rounding"};
}
namespace Aidge {
/**
@@ -71,7 +75,7 @@ public:
static const std::string Type;
private:
using Attributes_ = StaticAttributes<BitShiftAttr, BitShiftDirection>;
using Attributes_ = StaticAttributes<BitShiftAttr, BitShiftDirection, bool>;
template <BitShiftAttr e>
using attr = typename Attributes_::template attr<e>;
@@ -83,10 +87,12 @@ public:
* @brief Constructor to initialize the `BitShift_Op` with a shift direction.
* @param[in] direction The direction of the bitwise shift (left or right).
*/
BitShift_Op(BitShiftDirection direction)
BitShift_Op(BitShiftDirection direction, bool rounding = false)
: OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1),
mAttributes(std::make_shared<Attributes_>(
attr<BitShiftAttr::BitShiftdirection>(direction))) {}
attr<BitShiftAttr::BitShiftdirection>(direction),
attr<BitShiftAttr::Rounding>(rounding)))
{}
/**
* @brief Copy-constructor. Copies operator attributes and output tensors but not input tensors.
@@ -139,6 +145,13 @@ public:
inline BitShiftDirection& direction() const noexcept {
return mAttributes->template getAttr<BitShiftAttr::BitShiftdirection>();
}
/**
* @brief Retrieve the rounding flag.
* @return A boolean (True: Apply bitshift rounding).
*/
inline bool rounding() const noexcept {
return mAttributes->template getAttr<BitShiftAttr::Rounding>();
}
/**
* @brief Get the names of the input tensors.
@@ -168,11 +181,12 @@ public:
/**
* @brief Factory function to create a `BitShift` node.
* @param[in] direction The direction of the bitwise shift (`left` or `right`).
* @param[in] rounding Apply rounding
* @param[in] name (Optional) Name of the node.
* @return A shared pointer to the created node.
*/
inline std::shared_ptr<Node> BitShift(const BitShift_Op::BitShiftDirection direction, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<BitShift_Op>(direction), name);
inline std::shared_ptr<Node> BitShift(const BitShift_Op::BitShiftDirection direction,bool rounding = false, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<BitShift_Op>(direction,rounding), name);
}
} // namespace Aidge
Loading