diff --git a/include/aidge/operator/PTQMetaOps.hpp b/include/aidge/operator/PTQMetaOps.hpp
index 05e026ff38ef85e70c12b8cb7690f1f3ddf31862..4e51025d4a70ecd69c085b631170dfe88590c208
--- a/include/aidge/operator/PTQMetaOps.hpp
+++ b/include/aidge/operator/PTQMetaOps.hpp
@@ -16,6 +16,7 @@

 #include "aidge/graph/GraphView.hpp"
 #include "aidge/graph/Node.hpp"
+#include "aidge/data/Data.hpp"

 namespace Aidge {

@@ -25,6 +26,7 @@ namespace Aidge {
     void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double clipMax);
     void removeRound(std::shared_ptr<Node>& quantizer);
     void replaceScalingWithBitShift(std::shared_ptr<Node>& quantizer);
+    void castQuantizerIOs(std::shared_ptr<Node>& quantizer, Aidge::DataType externalType);

     /// @brief Quantizer acts as a meta-operator to handle scaling operations in the PTQ, replacing the Scaling Operator.
     /// This operator is composed of a sequence of [Mul] -> [Clip] -> [Round] operations.
diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp
index 7db4f2704fb7a912e45cdb7a8e31252da248e94c..5928bdda295561989e423287131e0f8f86ecde9b
--- a/include/aidge/quantization/PTQ/PTQ.hpp
+++ b/include/aidge/quantization/PTQ/PTQ.hpp
@@ -180,11 +180,15 @@ namespace Aidge {
     void quantizeNetwork(std::shared_ptr<GraphView> graphView, std::uint8_t nbBits, std::vector<std::shared_ptr<Tensor>> inputDataSet, Clipping clippingMode, bool noQuant, bool optimizeSigns, bool singleShift, bool useCuda, bool foldGraph, bool verbose);

     /**
-     * @brief Take a quantized graphview and cast it to true integer precision. Also optionally replaces the scaling
-     * nodes contained in the activation quantizers with bit-shift nodes.
-     * @param graphView The GraphView to edit.
-     */
-    void castQuantizedNetwork(std::shared_ptr<GraphView> graphView /*, Aidge::DataType targetType, bool singleShift, bool bitShiftRounding*/);
+     * @brief Take a quantized graphview and cast it to true integer precision. If the single-shift option is
+     * set to true, the scaling nodes contained in the activation quantizers are replaced with bit-shifts.
+     * @param graphView The GraphView to modify.
+     * @param targetType The desired precision of the cast.
+     * @param singleShift If set to true, replace the scaling factors with bit-shifts.
+     * @param bitShiftRounding If singleShift is true, specifies the kind of bit-shift rounding.
+     *
+     */
+    void castQuantizedNetwork(std::shared_ptr<GraphView> graphView, Aidge::DataType targetType, bool singleShift /*, bool bitShiftRounding*/);

     /**
      * @brief Compute the weight ranges of every affine node. Provided for debugging purposes.
diff --git a/python_binding/pybind_PTQ.cpp b/python_binding/pybind_PTQ.cpp
index e7ed07c3258d7cf57ffbb4e0bbd92c11f4f7b12b..b2038fbfdc9ad41f4e62f0991de6449609ad2e7c
--- a/python_binding/pybind_PTQ.cpp
+++ b/python_binding/pybind_PTQ.cpp
@@ -93,7 +93,8 @@ void init_PTQ(py::module &m) {
     :type verbose: bool
     )mydelimiter");

-    m.def("cast_quantized_network", &castQuantizedNetwork, py::arg("network"));
+    // TODO : add a docstring for this binding
+ m.def("cast_quantized_network", &castQuantizedNetwork, py::arg("network"), py::arg("target_type"), py::arg("single_shift") /*, py::arg("bitshift_rounding")*/); m.def("quantize_network", &quantizeNetwork, py::arg("network"), py::arg("nb_bits"), py::arg("input_dataset"), py::arg("clipping_mode") = Clipping::MAX , py::arg("no_quantization") = false, py::arg("optimize_signs") = false, py::arg("single_shift") = false, py::arg("use_cuda") = false, py::arg("fold_graph") = true, py::arg("verbose") = false, R"mydelimiter( diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index c2e8fb50c7715d9dc76e304be6da0ed785c527b1..d162421417e8e81fb6f42415137aca1777ff9a5f 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -33,7 +33,6 @@ #include "aidge/recipes/Recipes.hpp" #include "aidge/recipes/QuantRecipes.hpp" - #include "aidge/operator/MetaOperator.hpp" @@ -274,17 +273,43 @@ void foldProducerQuantizers(std::shared_ptr<GraphView> graphView) constantFolding(graphView); } -void castQuantizedNetwork(std::shared_ptr<GraphView> graphView /*, Aidge::DataType targetType, bool singleShift, bool bitShiftRounding*/) +void castQuantizedNetwork(std::shared_ptr<GraphView> graphView, Aidge::DataType targetType, bool singleShift /*, bool bitShiftRounding*/) { - std::set<std::shared_ptr<Node>> nodes = graphView->getNodes(); - for (std::shared_ptr<Node> node : nodes) - if (node->type() == "BaseQuantizer") - removeRound(node); - - nodes = graphView->getNodes(); // must be called again because of removeRound() - for (std::shared_ptr<Node> node : nodes) - if (node->type() == "BaseQuantizer" && hasAttr(node, "isActivationScaling")) - replaceScalingWithBitShift(node); + if (singleShift) + { + // Remove the round nodes (that cannot round integers) + + std::set<std::shared_ptr<Node>> nodes = graphView->getNodes(); + for (std::shared_ptr<Node> node : nodes) + if (node->type() == "BaseQuantizer") + removeRound(node); + + // Replace the scaling nodes with bit-shifts (activations only) + + nodes = graphView->getNodes(); // must be called again because of removeRound() + for (std::shared_ptr<Node> node : nodes) + if (node->type() == "BaseQuantizer" && hasAttr(node, "isActivationScaling")) + replaceScalingWithBitShift(node); + + // Cast all the graph tensors to integers + + graphView->setDataType(targetType); + } + else + { + // Set all the nodes, excepted the quantizers, to have integer IOs + + std::set<std::shared_ptr<Node>> nodes = graphView->getNodes(); + for (std::shared_ptr<Node> node : nodes) + if (node->type() != "BaseQuantizer") + node->getOperator()->setDataType(targetType); + + // Cast the quantizers input and outputs by inserting Cast nodes + + for (std::shared_ptr<Node> node : nodes) + if (node->type() == "BaseQuantizer") + castQuantizerIOs(node, targetType); + } } double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp index 58563a2e7231ede1dd3b097faeca084c2ac6c478..442a3c383bd18fb1883ce7aad2561a140a0836df 100644 --- a/src/operator/PTQMetaOps.cpp +++ b/src/operator/PTQMetaOps.cpp @@ -19,6 +19,7 @@ #include "aidge/operator/Mul.hpp" #include "aidge/operator/Round.hpp" #include "aidge/operator/BitShift.hpp" +#include "aidge/operator/Cast.hpp" #include "aidge/graph/Node.hpp" #include "aidge/graph/OpArgs.hpp" @@ -157,16 +158,10 @@ void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double cl microGraph->setDataType(dataType); microGraph->setBackend(backend); - // create the new meta-operator + // create the new quantizer and replace 

     std::shared_ptr<Node> newQuantizer = MetaOperator("BaseQuantizer", microGraph, {}, quantizer->name());
-
-    // Copy the flags
-
     copyDynamicAttributes(quantizer, newQuantizer);
-
-    // replace the previous quantizer with the new one
-
     GraphView::replace({quantizer}, {newQuantizer});

     // XXX : replace the old pointer with the new one (by reference)
@@ -251,7 +246,8 @@ void setClipRange(std::shared_ptr<Node> quantizerNode, double min, double max)
     maxProducer->getOperator()->setOutput(0, newMaxTensor);
 }

-// XXX
+// XXX TODO : manage the datatype / backend
+
 void removeRound(std::shared_ptr<Node>& quantizer)
 {
     // Retreive a clone of the microGraph
@@ -259,7 +255,7 @@ void removeRound(std::shared_ptr<Node>& quantizer)
     auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
     auto microGraph = quantizerOp->getMicroGraph()->clone();

-    // retreive the multiplicative (scaling) node
+    // retrieve the rounding node

     std::shared_ptr<Node> roundNode = nullptr;
     for (auto node : microGraph->getNodes())
@@ -273,16 +269,10 @@ void removeRound(std::shared_ptr<Node>& quantizer)

     microGraph->replace({roundNode}, {});

-    // Create the new meta-operator
+    // Create the new quantizer and replace the previous one

     std::shared_ptr<Node> newQuantizer = MetaOperator("BaseQuantizer", microGraph, {}, quantizer->name());
-
-    // Copy the flags
-
     copyDynamicAttributes(quantizer, newQuantizer);
-
-    // replace the previous quantizer with the new one
-
     GraphView::replace({quantizer}, {newQuantizer});

     // XXX : replace the old pointer with the new one (by reference)
@@ -311,7 +301,7 @@ void replaceScalingWithBitShift(std::shared_ptr<Node>& quantizer)
     auto dataType = mulOp->getOutput(0)->dataType();
     auto backend = mulOp->getOutput(0)->backend();

-    // compute the shift value
+    // compute the shift value (TODO : this whole part must be reworked)

     double scaling = getScalingFactor(quantizer);
     int bitShiftAmount = std::round(std::log2(scaling));
@@ -335,16 +325,49 @@
     microGraph->setDataType(dataType);
     microGraph->setBackend(backend);

-    // create the new meta-operator
+    // create the new quantizer and replace the previous one

     std::shared_ptr<Node> newQuantizer = MetaOperator("BaseQuantizer", microGraph, {}, quantizer->name());
+    copyDynamicAttributes(quantizer, newQuantizer);
+    GraphView::replace({quantizer}, {newQuantizer});
+
+    // XXX : replace the old pointer with the new one (by reference)

-    // Copy the flags
+    quantizer = newQuantizer;
+}

-    copyDynamicAttributes(quantizer, newQuantizer);
+void castQuantizerIOs(std::shared_ptr<Node>& quantizer, Aidge::DataType externalType)
+{
+    // Retrieve a clone of the microGraph
+
+    auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator());
+    auto microGraph = quantizerOp->getMicroGraph()->clone();
+
+    // Edit the micrograph (insert Cast nodes at its IOs)
+
+    auto mulNode = *(microGraph->inputNodes().begin()); // TODO : assert that mulNode is a Mul !
+    auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator());

-    // replace the previous quantizer with the new one
+    auto internalType = mulOp->getOutput(0)->dataType();
+    auto castInputNode = Cast(internalType, ""); // add a name !
+    auto castOutputNode = Cast(externalType, ""); // add a name !

+    microGraph = Sequential({castInputNode, microGraph, castOutputNode});
+
+    // Set the micrograph datatype
+
+    microGraph->setDataType(internalType);
+    castOutputNode->getOperator()->setDataType(externalType);
+
+    // Set the micrograph backend
+
+    auto backend = mulOp->getOutput(0)->backend();
+    microGraph->setBackend(backend);
+
+    // Create the new quantizer and replace the old one
+
+    std::shared_ptr<Node> newQuantizer = MetaOperator("BaseQuantizer", microGraph, {}, quantizer->name());
+    copyDynamicAttributes(quantizer, newQuantizer);
     GraphView::replace({quantizer}, {newQuantizer});

     // XXX : replace the old pointer with the new one (by reference)
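
Usage note (illustrative, not part of the patch): a minimal C++ sketch of how the updated castQuantizedNetwork signature could be called on a graph that was previously quantized with quantizeNetwork(). It assumes the usual Aidge::DataType enum values (e.g. DataType::Int32); the wrapper function name below is hypothetical.

#include <memory>

#include "aidge/graph/GraphView.hpp"
#include "aidge/quantization/PTQ/PTQ.hpp"

// Cast an already-quantized graphView to true integer precision.
// With singleShift == true the activation scalings are replaced by bit-shifts
// and the whole graph is cast to the target type; otherwise Cast nodes are
// inserted around the quantizers and the rest of the graph is cast.
void castToInt32(std::shared_ptr<Aidge::GraphView> graphView)
{
    Aidge::castQuantizedNetwork(graphView, Aidge::DataType::Int32, /*singleShift=*/true);
}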