Skip to content
Snippets Groups Projects
Commit 8494bb90 authored by Noam Zerah's avatar Noam Zerah Committed by Cyril Moineau
Browse files

Adding Rounding attribute to the bitshift operator

parent 0e89cde0
No related branches found
No related tags found
1 merge request!353Adding Rounding attribute to the bitshift operator
Pipeline #67676 failed
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
#include "aidge/utils/StaticAttributes.hpp" #include "aidge/utils/StaticAttributes.hpp"
#define LIST_BITSHIFT_ATTR(X) X(BitShiftdirection, "bit_shift_direction", BitShiftDirection) #define LIST_BITSHIFT_ATTR(X) X(BitShiftdirection, "bit_shift_direction", BitShiftDirection), X(Rounding,"rounding",bool)
namespace Aidge { namespace Aidge {
enum class BitShiftAttr { enum class BitShiftAttr {
...@@ -87,10 +87,12 @@ public: ...@@ -87,10 +87,12 @@ public:
* @brief Constructor to initialize the `BitShift_Op` with a shift direction. * @brief Constructor to initialize the `BitShift_Op` with a shift direction.
* @param[in] direction The direction of the bitwise shift (left or right). * @param[in] direction The direction of the bitwise shift (left or right).
*/ */
BitShift_Op(BitShiftDirection direction) BitShift_Op(BitShiftDirection direction, bool rounding = false)
: OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1), : OperatorTensor(Type, {InputCategory::Data, InputCategory::Data}, 1),
mAttributes(std::make_shared<Attributes_>( mAttributes(std::make_shared<Attributes_>(
attr<BitShiftAttr::BitShiftdirection>(direction))) {} attr<BitShiftAttr::BitShiftdirection>(direction),
attr<BitShiftAttr::Rounding>(rounding)))
{}
/** /**
* @brief Copy-constructor. Copies operator attributes and output tensors but not input tensors. * @brief Copy-constructor. Copies operator attributes and output tensors but not input tensors.
...@@ -143,6 +145,13 @@ public: ...@@ -143,6 +145,13 @@ public:
inline BitShiftDirection& direction() const noexcept { inline BitShiftDirection& direction() const noexcept {
return mAttributes->template getAttr<BitShiftAttr::BitShiftdirection>(); return mAttributes->template getAttr<BitShiftAttr::BitShiftdirection>();
} }
/**
* @brief Retrieve the rounding flag.
* @return A boolean (True: Apply bitshift rounding).
*/
inline bool rounding() const noexcept {
return mAttributes->template getAttr<BitShiftAttr::Rounding>();
}
/** /**
* @brief Get the names of the input tensors. * @brief Get the names of the input tensors.
...@@ -172,11 +181,12 @@ public: ...@@ -172,11 +181,12 @@ public:
/** /**
* @brief Factory function to create a `BitShift` node. * @brief Factory function to create a `BitShift` node.
* @param[in] direction The direction of the bitwise shift (`left` or `right`). * @param[in] direction The direction of the bitwise shift (`left` or `right`).
* @param[in] rounding Apply rounding
* @param[in] name (Optional) Name of the node. * @param[in] name (Optional) Name of the node.
* @return A shared pointer to the created node. * @return A shared pointer to the created node.
*/ */
inline std::shared_ptr<Node> BitShift(const BitShift_Op::BitShiftDirection direction, const std::string& name = "") { inline std::shared_ptr<Node> BitShift(const BitShift_Op::BitShiftDirection direction,bool rounding = false, const std::string& name = "") {
return std::make_shared<Node>(std::make_shared<BitShift_Op>(direction), name); return std::make_shared<Node>(std::make_shared<BitShift_Op>(direction,rounding), name);
} }
} // namespace Aidge } // namespace Aidge
......
...@@ -9,58 +9,61 @@ ...@@ -9,58 +9,61 @@
* *
********************************************************************************/ ********************************************************************************/
#include <pybind11/pybind11.h> #include <pybind11/pybind11.h>
#include <string> #include <string>
#include "aidge/backend/OperatorImpl.hpp" #include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp" #include "aidge/data/Tensor.hpp"
#include "aidge/operator/BitShift.hpp" #include "aidge/operator/BitShift.hpp"
#include "aidge/operator/OperatorTensor.hpp" #include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h" #include "aidge/utils/Types.h"
namespace py = pybind11; namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_BitShift(py::module &m) { void init_BitShift(py::module &m) {
// Binding for BitShiftOp class // Binding for BitShiftOp class
auto pyBitShiftOp = py::class_<BitShift_Op, std::shared_ptr<BitShift_Op>, OperatorTensor>(m, "BitShiftOp", py::multiple_inheritance(),R"mydelimiter( auto pyBitShiftOp = py::class_<BitShift_Op, std::shared_ptr<BitShift_Op>, OperatorTensor>(m, "BitShiftOp", py::multiple_inheritance(),R"mydelimiter(
BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements. BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements.
This class allows shifting tensor values either to the left or right based on the This class allows shifting tensor values either to the left or right based on the
specified direction. The direction can be accessed and controlled using the specified direction. The direction can be accessed and controlled using the
BitShiftDirection enum. BitShiftDirection enum.
:param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right) :param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right)
:type direction: BitShiftDirection :type direction: BitShiftDirection
:param name: name of the node. :param rounding: flag to apply bitshift rounding
)mydelimiter") :type rounding: boolean
.def(py::init<BitShift_Op::BitShiftDirection>(), py::arg("direction")) :param name: name of the node.
.def("direction", &BitShift_Op::direction, "Get the direction of the bit shift (left or right).") )mydelimiter")
.def_static("get_inputs_name", &BitShift_Op::getInputsName, "Get the names of the input tensors.") .def(py::init<BitShift_Op::BitShiftDirection,bool>(), py::arg("direction"),py::arg("rounding"))
.def_static("get_outputs_name", &BitShift_Op::getOutputsName, "Get the names of the output tensors.") .def_static("get_inputs_name", &BitShift_Op::getInputsName, "Get the names of the input tensors.")
.def_static("attributes_name", []() { .def_static("get_outputs_name", &BitShift_Op::getOutputsName, "Get the names of the output tensors.")
std::vector<std::string> result; .def_static("attributes_name", []() {
auto attributes = BitShift_Op::attributesName(); std::vector<std::string> result;
for (size_t i = 0; i < size(EnumStrings<BitShiftAttr>::data); ++i) { auto attributes = BitShift_Op::attributesName();
result.emplace_back(attributes[i]); for (size_t i = 0; i < size(EnumStrings<BitShiftAttr>::data); ++i) {
} result.emplace_back(attributes[i]);
return result; }
}); return result;
});
// Enum binding under BitShiftOp class
py::enum_<BitShift_Op::BitShiftDirection>(pyBitShiftOp, "BitShiftDirection") // Enum binding under BitShiftOp class
.value("Right", BitShift_Op::BitShiftDirection::right) py::enum_<BitShift_Op::BitShiftDirection>(pyBitShiftOp, "BitShiftDirection")
.value("Left", BitShift_Op::BitShiftDirection::left) .value("Right", BitShift_Op::BitShiftDirection::right)
.export_values(); .value("Left", BitShift_Op::BitShiftDirection::left)
.export_values();
// Binding for the BitShift function
m.def("BitShift", &BitShift, py::arg("direction") = BitShift_Op::BitShiftDirection::right, py::arg("name") = "", // Binding for the BitShift function
R"mydelimiter( m.def("BitShift", &BitShift, py::arg("direction") = BitShift_Op::BitShiftDirection::right,py::arg("rounding") = false, py::arg("name") = "",
BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements. R"mydelimiter(
This class allows shifting tensor values either to the left or right based on the BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements.
specified direction. The direction can be accessed and controlled using the This class allows shifting tensor values either to the left or right based on the
BitShiftDirection enum. specified direction. The direction can be accessed and controlled using the
:param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right) BitShiftDirection enum.
:type direction: BitShiftDirection :param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right)
:param name: name of the node. :type direction: BitShiftDirection
)mydelimiter"); :param rounding: flag to apply bitshift rounding
} :type rounding: boolean
} // namespace Aidge :param name: name of the node.
\ No newline at end of file )mydelimiter");
}
} // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment