Skip to content
Snippets Groups Projects

Adding Rounding attribute to the bitshift operator

Merged Noam Zerah requested to merge noamzerah/aidge_core:bitshift_rounding into dev
1 file
+ 57
54
Compare changes
  • Side-by-side
  • Inline
@@ -9,58 +9,61 @@
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include <pybind11/pybind11.h>
#include <string>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/BitShift.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"
namespace py = pybind11;
namespace Aidge {
void init_BitShift(py::module &m) {
// Binding for BitShiftOp class
auto pyBitShiftOp = py::class_<BitShift_Op, std::shared_ptr<BitShift_Op>, OperatorTensor>(m, "BitShiftOp", py::multiple_inheritance(),R"mydelimiter(
BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements.
This class allows shifting tensor values either to the left or right based on the
specified direction. The direction can be accessed and controlled using the
BitShiftDirection enum.
:param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right)
:type direction: BitShiftDirection
:param name: name of the node.
)mydelimiter")
.def(py::init<BitShift_Op::BitShiftDirection>(), py::arg("direction"))
.def("direction", &BitShift_Op::direction, "Get the direction of the bit shift (left or right).")
.def_static("get_inputs_name", &BitShift_Op::getInputsName, "Get the names of the input tensors.")
.def_static("get_outputs_name", &BitShift_Op::getOutputsName, "Get the names of the output tensors.")
.def_static("attributes_name", []() {
std::vector<std::string> result;
auto attributes = BitShift_Op::attributesName();
for (size_t i = 0; i < size(EnumStrings<BitShiftAttr>::data); ++i) {
result.emplace_back(attributes[i]);
}
return result;
});
// Enum binding under BitShiftOp class
py::enum_<BitShift_Op::BitShiftDirection>(pyBitShiftOp, "BitShiftDirection")
.value("Right", BitShift_Op::BitShiftDirection::right)
.value("Left", BitShift_Op::BitShiftDirection::left)
.export_values();
// Binding for the BitShift function
m.def("BitShift", &BitShift, py::arg("direction") = BitShift_Op::BitShiftDirection::right, py::arg("name") = "",
R"mydelimiter(
BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements.
This class allows shifting tensor values either to the left or right based on the
specified direction. The direction can be accessed and controlled using the
BitShiftDirection enum.
:param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right)
:type direction: BitShiftDirection
:param name: name of the node.
)mydelimiter");
}
} // namespace Aidge
\ No newline at end of file
#include <string>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/BitShift.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Types.h"
namespace py = pybind11;
namespace Aidge {
void init_BitShift(py::module &m) {
// Binding for BitShiftOp class
auto pyBitShiftOp = py::class_<BitShift_Op, std::shared_ptr<BitShift_Op>, OperatorTensor>(m, "BitShiftOp", py::multiple_inheritance(),R"mydelimiter(
BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements.
This class allows shifting tensor values either to the left or right based on the
specified direction. The direction can be accessed and controlled using the
BitShiftDirection enum.
:param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right)
:type direction: BitShiftDirection
:param name: name of the node.
)mydelimiter")
.def(py::init<BitShift_Op::BitShiftDirection>(), py::arg("direction"))
.def("direction", &BitShift_Op::direction, "Get the direction of the bit shift (left or right).")
.def("rounding", &BitShift_Op::rounding, "Apply bitshift rounding")
.def_static("get_inputs_name", &BitShift_Op::getInputsName, "Get the names of the input tensors.")
.def_static("get_outputs_name", &BitShift_Op::getOutputsName, "Get the names of the output tensors.")
.def_static("attributes_name", []() {
std::vector<std::string> result;
auto attributes = BitShift_Op::attributesName();
for (size_t i = 0; i < size(EnumStrings<BitShiftAttr>::data); ++i) {
result.emplace_back(attributes[i]);
}
return result;
});
// Enum binding under BitShiftOp class
py::enum_<BitShift_Op::BitShiftDirection>(pyBitShiftOp, "BitShiftDirection")
.value("Right", BitShift_Op::BitShiftDirection::right)
.value("Left", BitShift_Op::BitShiftDirection::left)
.export_values();
// Binding for the BitShift function
m.def("BitShift", &BitShift, py::arg("direction") = BitShift_Op::BitShiftDirection::right,py::arg("rounding") = false, py::arg("name") = "",
R"mydelimiter(
BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements.
This class allows shifting tensor values either to the left or right based on the
specified direction. The direction can be accessed and controlled using the
BitShiftDirection enum.
:param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right)
:type direction: BitShiftDirection
:param rounding: flag to apply bitshift rounding
:type rounding: boolean
:param name: name of the node.
)mydelimiter");
}
} // namespace Aidge
\ No newline at end of file
Loading