diff --git a/python_binding/operator/pybind_PTQMetaOps.cpp b/python_binding/operator/pybind_PTQMetaOps.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c8df2992702da05d9336208d7dc09ac7510dc4f5 --- /dev/null +++ b/python_binding/operator/pybind_PTQMetaOps.cpp @@ -0,0 +1,121 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + + #include <pybind11/pybind11.h> + #include <pybind11/stl.h> + #include <pybind11/functional.h> + + #include "aidge/operator/PTQMetaOps.hpp" + #include "aidge/graph/Node.hpp" + #include "aidge/utils/Types.h" + + namespace py = pybind11; + + namespace Aidge { + + void init_PTQMetaOps(py::module &m) { + // Quantizer creation and manipulation + m.def("quantizer", &Quantizer, + py::arg("scaling_factor"), + py::arg("name") = "", + R"doc( + Create a quantizer node with specified scaling factor. + + Args: + scaling_factor (float): The scaling factor to apply + name (str): Optional name for the quantizer node + + Returns: + Node: The created quantizer node + )doc"); + + m.def("multiply_scaling_factor", &multiplyScalingFactor, + py::arg("quantizer"), + py::arg("coefficient"), + R"doc( + Multiply the scaling factor of a quantizer by a coefficient. + + Args: + quantizer (Node): The quantizer node to modify + coefficient (float): The multiplication factor + )doc"); + + m.def("get_scaling_factor", &getScalingFactor, + py::arg("quantizer"), + R"doc( + Get the current scaling factor of a quantizer. + + Args: + quantizer (Node): The quantizer node to query + + Returns: + float: The current scaling factor + )doc"); + + // Quantizer modification functions + m.def("append_round_clip", &appendRoundClip, + py::arg("quantizer"), + py::arg("clip_min"), + py::arg("clip_max"), + R"doc( + Append round and clip operations to a quantizer. + + Args: + quantizer (Node): The quantizer node to modify + clip_min (float): Minimum clipping value + clip_max (float): Maximum clipping value + )doc"); + + m.def("set_clip_range", &setClipRange, + py::arg("quantizer"), + py::arg("min"), + py::arg("max"), + R"doc( + Set the clipping range of a quantizer that already has clip operations. + + Args: + quantizer (Node): The quantizer node to modify + min (float): New minimum clipping value + max (float): New maximum clipping value + )doc"); + + m.def("remove_round", &removeRound, + py::arg("quantizer"), + R"doc( + Remove the round operation from a quantizer. + + Args: + quantizer (Node): The quantizer node to modify + )doc"); + + // Advanced quantization operations + m.def("replace_scaling_with_bitshift", &replaceScalingWithBitShift, + py::arg("quantizer"), + R"doc( + Replace multiplicative scaling with bit-shift operations. + + Args: + quantizer (Node): The quantizer node to modify + )doc"); + + m.def("cast_quantizer_ios", &castQuantizerIOs, + py::arg("quantizer"), + py::arg("external_type"), + R"doc( + Cast the input/output of a quantizer to specified data type. + + Args: + quantizer (Node): The quantizer node to modify + external_type (DataType): Target data type for I/O + )doc"); + } + + } // namespace Aidge \ No newline at end of file diff --git a/python_binding/pybind_Quantization.cpp b/python_binding/pybind_Quantization.cpp index 7ac344dcfcd4fc93e3bba1dcd19c1413f5a29d0c..27f6885d48fcafc81d8679deebced6f36ce67e4a 100644 --- a/python_binding/pybind_Quantization.cpp +++ b/python_binding/pybind_Quantization.cpp @@ -31,6 +31,7 @@ void init_DoReFa(py::module& m); // quantization routines void init_PTQ(py::module &m); +void init_PTQMetaOps(py::module &m); void init_QAT_FixedQ(py::module &m); void init_QAT_LSQ(py::module &m); void init_QuantRecipes(py::module &m); @@ -45,6 +46,7 @@ PYBIND11_MODULE(aidge_quantization, m) init_DoReFa(m); init_PTQ(m); + init_PTQMetaOps(m); init_QAT_FixedQ(m); init_QAT_LSQ(m); init_QuantRecipes(m);