diff --git a/include/aidge/operator/PTQMetaOps.hpp b/include/aidge/operator/PTQMetaOps.hpp index 96182202f38be20afa539eb41a8d32b989afcf9f..8768b715b26b946cb3c21dee1152031401be24ba 100644 --- a/include/aidge/operator/PTQMetaOps.hpp +++ b/include/aidge/operator/PTQMetaOps.hpp @@ -66,6 +66,13 @@ namespace Aidge { */ void castQuantizerIOs(std::shared_ptr<Node>& quantizer, Aidge::DataType externalType); + /** + * @brief Given a Quantizer, set the coefficient of it's Mul node. + * @param quantizer The quantizer containing the multiplicative node. + * @param value The new value of the multiplicative coefficient. + */ + void setScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double value); + /** * @brief Given a Quantizer, retreive the coefficient of it's Mul node. * @param quantizer The quantizer containing the multiplicative coefficient. diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 9e2a62c8b975bbca4f63c90d500324a75e492e64..bab86d05c70da7f46bf7a151f19ae0282eec4a2d 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -1226,18 +1226,18 @@ void performSingleShiftApproximation(std::shared_ptr<GraphView> graphView) std::shared_ptr<Node> linearNode = node->getParent(0); double base = getScalingFactor(node); - double approx = std::pow(2, std::ceil(std::log2(base))); - double ratio = approx / base; + double approx = std::pow(2, static_cast<int>(std::ceil(std::log2(base)))); // set the scaling factor value to the approximation ... - multiplyScalingFactor(node, ratio); + setScalingFactor(node, approx); // compensate the ratio using the previous node scaling factors ... - multiplyScalingFactor(linearNode->getParent(1), 1.0 / ratio); + double ratio = base / approx; + multiplyScalingFactor(linearNode->getParent(1), ratio); if (nodeHasBias(linearNode)) - multiplyScalingFactor(linearNode->getParent(2), 1.0 / ratio); + multiplyScalingFactor(linearNode->getParent(2), ratio); } } } diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp index 82b2e501dc275b361899a0ae8284f8a5409d32dc..c83e60019388409ee78d871e8216c038c1e4213a 100644 --- a/src/operator/PTQMetaOps.cpp +++ b/src/operator/PTQMetaOps.cpp @@ -82,7 +82,35 @@ std::shared_ptr<Node> Quantizer(double scalingFactor, const std::string& name) return quantizer; } -void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff) + +double getScalingFactor(std::shared_ptr<Node> quantizer) +{ + // Retreive the previous microGraph + + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); + auto microGraph = quantizerOp->getMicroGraph(); + + // Get the Mul node from the microGraph + + std::shared_ptr<Node> mulNode = nullptr; + for (auto node : microGraph->getNodes()) + if (node->type() == "Mul") + mulNode = node; + + auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator()); + + // Retreive the scaling factor + + auto scalingFactorTensor = mulOp->getInput(1); + + std::shared_ptr<Tensor> fallback; + const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu"); + double scalingFactor = localTensor.get<double>(0); + + return scalingFactor; +} + +void setScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double value) { auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); @@ -104,7 +132,7 @@ void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff) // Create the new scaling factor tensor - std::shared_ptr<Tensor> newScalingFactorTensor = std::make_shared<Tensor>(prevScalingFactor * coeff); + std::shared_ptr<Tensor> newScalingFactorTensor = std::make_shared<Tensor>(value); newScalingFactorTensor->setBackend(scalingFactorTensor->backend()); newScalingFactorTensor->setDataType(scalingFactorTensor->dataType()); @@ -114,6 +142,12 @@ void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff) producer->getOperator()->setOutput(0, newScalingFactorTensor); } +void multiplyScalingFactor(std::shared_ptr<Aidge::Node> quantizer, double coeff) +{ + double prevScalingFactor = getScalingFactor(quantizer); + setScalingFactor(quantizer, coeff * prevScalingFactor); +} + void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double clipMax) { // Retreive a clone of the microGraph @@ -131,7 +165,7 @@ void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double cl // append round - auto roundNode = Round(quantizer->name() + "_RoundQuant"); + auto roundNode = Round(Round_Op::HalfwayRounding::NextInteger, quantizer->name() + "_RoundQuant"); outputNode->addChild(roundNode, 0, 0); microGraph->add(roundNode); @@ -168,32 +202,7 @@ void appendRoundClip(std::shared_ptr<Node>& quantizer, double clipMin, double cl quantizer = newQuantizer; } -double getScalingFactor(std::shared_ptr<Node> quantizer) -{ - // Retreive the previous microGraph - auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); - auto microGraph = quantizerOp->getMicroGraph(); - - // Get the Mul node from the microGraph - - std::shared_ptr<Node> mulNode = nullptr; - for (auto node : microGraph->getNodes()) - if (node->type() == "Mul") - mulNode = node; - - auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator()); - - // Retreive the scaling factor - - auto scalingFactorTensor = mulOp->getInput(1); - - std::shared_ptr<Tensor> fallback; - const Tensor& localTensor = scalingFactorTensor->refCastFrom(fallback, DataType::Float64, "cpu"); - double scalingFactor = localTensor.get<double>(0); - - return scalingFactor; -} void setClipRange(std::shared_ptr<Node> quantizer, double min, double max) {