diff --git a/include/aidge/operator/BitShift.hpp b/include/aidge/operator/BitShift.hpp index c54d6a99fc2f945a02396af446d356004e94efc1..afe0745869e828ed9004c0ff3856f5ffef5c23dc 100644 --- a/include/aidge/operator/BitShift.hpp +++ b/include/aidge/operator/BitShift.hpp @@ -23,7 +23,7 @@ #include "aidge/utils/Types.h" #include "aidge/utils/StaticAttributes.hpp" -#define LIST_BITSHIFT_ATTR(X) X(BitShiftdirection, "bit_shift_direction", BitShiftDirection) +#define LIST_BITSHIFT_ATTR(X) X(BitShiftdirection, "bit_shift_direction", BitShiftDirection), X(Rounding,"rounding",bool) namespace Aidge { enum class BitShiftAttr { @@ -87,10 +87,12 @@ public: * @brief Constructor to initialize the `BitShift_Op` with a shift direction. * @param[in] direction The direction of the bitwise shift (left or right). */ - BitShift_Op(BitShiftDirection direction) + BitShift_Op(BitShiftDirection direction, bool rounding = false) : OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1), mAttributes(std::make_shared<Attributes_>( - attr<BitShiftAttr::BitShiftdirection>(direction))) {} + attr<BitShiftAttr::BitShiftdirection>(direction), + attr<BitShiftAttr::Rounding>(rounding))) + {} /** * @brief Copy-constructor. Copies operator attributes and output tensors but not input tensors. @@ -143,6 +145,13 @@ public: inline BitShiftDirection& direction() const noexcept { return mAttributes->template getAttr<BitShiftAttr::BitShiftdirection>(); } + /** + * @brief Retrieve the rounding flag. + * @return A boolean (True: Apply bitshift rounding). + */ + inline bool rounding() const noexcept { + return mAttributes->template getAttr<BitShiftAttr::Rounding>(); + } /** * @brief Get the names of the input tensors. @@ -172,11 +181,12 @@ public: /** * @brief Factory function to create a `BitShift` node. * @param[in] direction The direction of the bitwise shift (`left` or `right`). + * @param[in] rounding Apply rounding * @param[in] name (Optional) Name of the node. * @return A shared pointer to the created node. */ -inline std::shared_ptr<Node> BitShift(const BitShift_Op::BitShiftDirection direction, const std::string& name = "") { - return std::make_shared<Node>(std::make_shared<BitShift_Op>(direction), name); +inline std::shared_ptr<Node> BitShift(const BitShift_Op::BitShiftDirection direction,bool rounding = false, const std::string& name = "") { + return std::make_shared<Node>(std::make_shared<BitShift_Op>(direction,rounding), name); } } // namespace Aidge diff --git a/python_binding/operator/pybind_BitShift.cpp b/python_binding/operator/pybind_BitShift.cpp index f2f4b223df788c27dc1378d8564c881b907901c4..4efb2c96fc683a34923accf14238005104eb5132 100644 --- a/python_binding/operator/pybind_BitShift.cpp +++ b/python_binding/operator/pybind_BitShift.cpp @@ -9,58 +9,61 @@ * ********************************************************************************/ -#include <pybind11/pybind11.h> + #include <pybind11/pybind11.h> -#include <string> -#include "aidge/backend/OperatorImpl.hpp" -#include "aidge/data/Tensor.hpp" -#include "aidge/operator/BitShift.hpp" -#include "aidge/operator/OperatorTensor.hpp" -#include "aidge/utils/Types.h" - -namespace py = pybind11; -namespace Aidge { - -void init_BitShift(py::module &m) { - // Binding for BitShiftOp class - auto pyBitShiftOp = py::class_<BitShift_Op, std::shared_ptr<BitShift_Op>, OperatorTensor>(m, "BitShiftOp", py::multiple_inheritance(),R"mydelimiter( - BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements. - This class allows shifting tensor values either to the left or right based on the - specified direction. The direction can be accessed and controlled using the - BitShiftDirection enum. - :param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right) - :type direction: BitShiftDirection - :param name: name of the node. - )mydelimiter") - .def(py::init<BitShift_Op::BitShiftDirection>(), py::arg("direction")) - .def("direction", &BitShift_Op::direction, "Get the direction of the bit shift (left or right).") - .def_static("get_inputs_name", &BitShift_Op::getInputsName, "Get the names of the input tensors.") - .def_static("get_outputs_name", &BitShift_Op::getOutputsName, "Get the names of the output tensors.") - .def_static("attributes_name", []() { - std::vector<std::string> result; - auto attributes = BitShift_Op::attributesName(); - for (size_t i = 0; i < size(EnumStrings<BitShiftAttr>::data); ++i) { - result.emplace_back(attributes[i]); - } - return result; - }); - - // Enum binding under BitShiftOp class - py::enum_<BitShift_Op::BitShiftDirection>(pyBitShiftOp, "BitShiftDirection") - .value("Right", BitShift_Op::BitShiftDirection::right) - .value("Left", BitShift_Op::BitShiftDirection::left) - .export_values(); - - // Binding for the BitShift function - m.def("BitShift", &BitShift, py::arg("direction") = BitShift_Op::BitShiftDirection::right, py::arg("name") = "", - R"mydelimiter( - BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements. - This class allows shifting tensor values either to the left or right based on the - specified direction. The direction can be accessed and controlled using the - BitShiftDirection enum. - :param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right) - :type direction: BitShiftDirection - :param name: name of the node. - )mydelimiter"); -} -} // namespace Aidge \ No newline at end of file + #include <string> + #include "aidge/backend/OperatorImpl.hpp" + #include "aidge/data/Tensor.hpp" + #include "aidge/operator/BitShift.hpp" + #include "aidge/operator/OperatorTensor.hpp" + #include "aidge/utils/Types.h" + + namespace py = pybind11; + namespace Aidge { + + void init_BitShift(py::module &m) { + // Binding for BitShiftOp class + auto pyBitShiftOp = py::class_<BitShift_Op, std::shared_ptr<BitShift_Op>, OperatorTensor>(m, "BitShiftOp", py::multiple_inheritance(),R"mydelimiter( + BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements. + This class allows shifting tensor values either to the left or right based on the + specified direction. The direction can be accessed and controlled using the + BitShiftDirection enum. + :param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right) + :type direction: BitShiftDirection + :param rounding: flag to apply bitshift rounding + :type rounding: boolean + :param name: name of the node. + )mydelimiter") + .def(py::init<BitShift_Op::BitShiftDirection,bool>(), py::arg("direction"),py::arg("rounding")) + .def_static("get_inputs_name", &BitShift_Op::getInputsName, "Get the names of the input tensors.") + .def_static("get_outputs_name", &BitShift_Op::getOutputsName, "Get the names of the output tensors.") + .def_static("attributes_name", []() { + std::vector<std::string> result; + auto attributes = BitShift_Op::attributesName(); + for (size_t i = 0; i < size(EnumStrings<BitShiftAttr>::data); ++i) { + result.emplace_back(attributes[i]); + } + return result; + }); + + // Enum binding under BitShiftOp class + py::enum_<BitShift_Op::BitShiftDirection>(pyBitShiftOp, "BitShiftDirection") + .value("Right", BitShift_Op::BitShiftDirection::right) + .value("Left", BitShift_Op::BitShiftDirection::left) + .export_values(); + + // Binding for the BitShift function + m.def("BitShift", &BitShift, py::arg("direction") = BitShift_Op::BitShiftDirection::right,py::arg("rounding") = false, py::arg("name") = "", + R"mydelimiter( + BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements. + This class allows shifting tensor values either to the left or right based on the + specified direction. The direction can be accessed and controlled using the + BitShiftDirection enum. + :param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right) + :type direction: BitShiftDirection + :param rounding: flag to apply bitshift rounding + :type rounding: boolean + :param name: name of the node. + )mydelimiter"); + } + } // namespace Aidge \ No newline at end of file