From 608330434b03a9bb4b94d6f45d88fd18e0d5e0d9 Mon Sep 17 00:00:00 2001
From: Noam ZERAH <noam.zerah@cea.fr>
Date: Fri, 28 Feb 2025 14:35:26 +0000
Subject: [PATCH] Adding python binding

---
 python_binding/operator/pybind_BitShift.cpp | 111 ++++++++++----------
 1 file changed, 57 insertions(+), 54 deletions(-)

diff --git a/python_binding/operator/pybind_BitShift.cpp b/python_binding/operator/pybind_BitShift.cpp
index f2f4b223d..9313f46c3 100644
--- a/python_binding/operator/pybind_BitShift.cpp
+++ b/python_binding/operator/pybind_BitShift.cpp
@@ -9,58 +9,61 @@
  *
  ********************************************************************************/
 
-#include <pybind11/pybind11.h>
+ #include <pybind11/pybind11.h>
 
-#include <string>
-#include "aidge/backend/OperatorImpl.hpp"
-#include "aidge/data/Tensor.hpp"
-#include "aidge/operator/BitShift.hpp"
-#include "aidge/operator/OperatorTensor.hpp"
-#include "aidge/utils/Types.h"
-
-namespace py = pybind11;
-namespace Aidge {
-
-void init_BitShift(py::module &m) {
-    // Binding for BitShiftOp class
-    auto pyBitShiftOp = py::class_<BitShift_Op, std::shared_ptr<BitShift_Op>, OperatorTensor>(m, "BitShiftOp", py::multiple_inheritance(),R"mydelimiter(
-        BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements.
-        This class allows shifting tensor values either to the left or right based on the 
-        specified direction. The direction can be accessed and controlled using the 
-        BitShiftDirection enum.
-        :param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right)
-        :type direction: BitShiftDirection
-        :param name: name of the node.
-    )mydelimiter")
-        .def(py::init<BitShift_Op::BitShiftDirection>(), py::arg("direction"))
-        .def("direction", &BitShift_Op::direction, "Get the direction of the bit shift (left or right).")
-        .def_static("get_inputs_name", &BitShift_Op::getInputsName, "Get the names of the input tensors.")
-        .def_static("get_outputs_name", &BitShift_Op::getOutputsName, "Get the names of the output tensors.")
-		.def_static("attributes_name", []() {
-			std::vector<std::string> result;
-			auto attributes = BitShift_Op::attributesName();
-			for (size_t i = 0; i < size(EnumStrings<BitShiftAttr>::data); ++i) {
-				result.emplace_back(attributes[i]);
-			}
-			return result;
-		});
-
-    // Enum binding under BitShiftOp class
-    py::enum_<BitShift_Op::BitShiftDirection>(pyBitShiftOp, "BitShiftDirection")
-        .value("Right", BitShift_Op::BitShiftDirection::right)
-        .value("Left", BitShift_Op::BitShiftDirection::left)
-        .export_values();
-
-    // Binding for the BitShift function
-    m.def("BitShift", &BitShift, py::arg("direction") = BitShift_Op::BitShiftDirection::right, py::arg("name") = "",
-        R"mydelimiter(
-        BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements.
-        This class allows shifting tensor values either to the left or right based on the 
-        specified direction. The direction can be accessed and controlled using the 
-        BitShiftDirection enum.
-        :param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right)
-        :type direction: BitShiftDirection
-        :param name: name of the node.
-    )mydelimiter");
-}
-} // namespace Aidge
\ No newline at end of file
+ #include <string>
+ #include "aidge/backend/OperatorImpl.hpp"
+ #include "aidge/data/Tensor.hpp"
+ #include "aidge/operator/BitShift.hpp"
+ #include "aidge/operator/OperatorTensor.hpp"
+ #include "aidge/utils/Types.h"
+ 
+ namespace py = pybind11;
+ namespace Aidge {
+ 
+ void init_BitShift(py::module &m) {
+     // Binding for BitShiftOp class
+     auto pyBitShiftOp = py::class_<BitShift_Op, std::shared_ptr<BitShift_Op>, OperatorTensor>(m, "BitShiftOp", py::multiple_inheritance(),R"mydelimiter(
+         BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements.
+         This class allows shifting tensor values either to the left or right based on the 
+         specified direction. The direction can be accessed and controlled using the 
+         BitShiftDirection enum.
+         :param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right)
+         :type direction: BitShiftDirection
+         :param name: name of the node.
+     )mydelimiter")
+         .def(py::init<BitShift_Op::BitShiftDirection>(), py::arg("direction"))
+         .def("direction", &BitShift_Op::direction, "Get the direction of the bit shift (left or right).")
+         .def("rounding", &BitShift_Op::rounding, "Apply bitshift rounding")
+         .def_static("get_inputs_name", &BitShift_Op::getInputsName, "Get the names of the input tensors.")
+         .def_static("get_outputs_name", &BitShift_Op::getOutputsName, "Get the names of the output tensors.")
+         .def_static("attributes_name", []() {
+             std::vector<std::string> result;
+             auto attributes = BitShift_Op::attributesName();
+             for (size_t i = 0; i < size(EnumStrings<BitShiftAttr>::data); ++i) {
+                 result.emplace_back(attributes[i]);
+             }
+             return result;
+         });
+ 
+     // Enum binding under BitShiftOp class
+     py::enum_<BitShift_Op::BitShiftDirection>(pyBitShiftOp, "BitShiftDirection")
+         .value("Right", BitShift_Op::BitShiftDirection::right)
+         .value("Left", BitShift_Op::BitShiftDirection::left)
+         .export_values();
+ 
+     // Binding for the BitShift function
+     m.def("BitShift", &BitShift, py::arg("direction") = BitShift_Op::BitShiftDirection::right,py::arg("rounding") = false, py::arg("name") = "",
+         R"mydelimiter(
+         BitShiftOp is a tensor operator that performs bitwise shifts on tensor elements.
+         This class allows shifting tensor values either to the left or right based on the 
+         specified direction. The direction can be accessed and controlled using the 
+         BitShiftDirection enum.
+         :param direction: direction of the bit shift (BitShiftDirection.Left or BitShiftDirection.Right)
+         :type direction: BitShiftDirection
+         :param rounding: flag to apply bitshift rounding
+         :type rounding: boolean
+         :param name: name of the node.
+     )mydelimiter");
+ }
+ } // namespace Aidge
\ No newline at end of file
-- 
GitLab