diff --git a/src/operator/PTQMetaOps.cpp b/src/operator/PTQMetaOps.cpp index 971546b5d1416eb35cea48f5ee0ca4b6e4ee4903..4134d16cba524e8f243b7578b4e62e0fc1bb7376 100644 --- a/src/operator/PTQMetaOps.cpp +++ b/src/operator/PTQMetaOps.cpp @@ -18,6 +18,7 @@ #include "aidge/operator/Clip.hpp" #include "aidge/operator/Mul.hpp" #include "aidge/operator/Round.hpp" +#include "aidge/operator/BitShift.hpp" #include "aidge/graph/Node.hpp" #include "aidge/graph/OpArgs.hpp" @@ -33,6 +34,16 @@ namespace Aidge { +static bool hasAttr(std::shared_ptr<Aidge::Node> node, std::string attr) +{ + return node->attributes()->hasAttr("quantization.ptq." + attr); +} + +static void addAttr(std::shared_ptr<Aidge::Node> node, std::string attr, double value = 0.0) +{ + node->attributes()->addAttr("quantization.ptq." + attr, value); +} + std::shared_ptr<Node> BaseQuantizer(double scalingFactor, const std::string& name) { std::shared_ptr<Node> mulNode = Mul(name + "_MulQuant"); @@ -56,19 +67,19 @@ std::shared_ptr<Node> BaseQuantizer(double scalingFactor, const std::string& nam // alternative : capture the Producer ... // std::shared_ptr<GraphView> connectedGraphView = getConnectedGraphView(mulNode); - std::shared_ptr<Node> metaopNode = MetaOperator("BaseQuantizer", graphView, {}, name); // XXX alternative prototype -> + std::shared_ptr<Node> quantizerNode = MetaOperator("BaseQuantizer", graphView, {}, name); // XXX alternative prototype -> - return metaopNode; + return quantizerNode; } void multiplyScalingFactor(std::shared_ptr<Aidge::Node> scalingNode, double coeff) { - auto metaOperatorOp = std::static_pointer_cast<MetaOperator_Op> (scalingNode->getOperator()); + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (scalingNode->getOperator()); // Get the Mul node from the microGraph std::shared_ptr<Node> mulNode = nullptr; - auto microGraph = metaOperatorOp->getMicroGraph(); + auto microGraph = quantizerOp->getMicroGraph(); for (auto node : microGraph->getNodes()) if (node->type() == "Mul") mulNode = node; @@ -94,81 +105,75 @@ void multiplyScalingFactor(std::shared_ptr<Aidge::Node> scalingNode, double coef // XXX prev way : mulNode->input(1).first->getOperator()->setOutput(0, resultTensor); } -void appendRoundClip(std::shared_ptr<Node> metaOpNode, double clipMin, double clipMax) +std::shared_ptr<Node> appendRoundClip(std::shared_ptr<Node> quantizer, double clipMin, double clipMax) { - // Retreive the previous microGraph + // Retreive a clone of the microGraph - auto metaOperatorOp = std::static_pointer_cast<MetaOperator_Op> (metaOpNode->getOperator()); - auto microGraph = metaOperatorOp->getMicroGraph(); + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); + auto microGraph = quantizerOp->getMicroGraph()->clone(); - // Get the Mul node from the microGraph + // Save the datatype / backend - std::shared_ptr<Node> mulNode = nullptr; - for (auto node : microGraph->getNodes()) - if (node->type() == "Mul") - mulNode = node; + auto outputNode = *(microGraph->outputNodes().begin()); + auto outputOp = std::static_pointer_cast<OperatorTensor> (outputNode->getOperator()); - auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator()); + auto dataType = outputOp->getOutput(0)->dataType(); + auto backend = outputOp->getOutput(0)->backend(); - // save the backend and datatype + // append round - auto backend = mulOp->getInput(0)->backend(); - auto dataType = mulOp->getInput(0)->dataType(); + auto roundNode = Round(); + outputNode->addChild(roundNode, 0, 0); + microGraph->add(roundNode); - // create the new microGraph nodes + // append clip - auto newMulNode = Mul(); - auto roundNode = Round(); - auto clipNode = Clip(""); //, clipMin, clipMax); - auto newCoeffNode = mulNode->getParent(1)->clone(); // XXX Producer(coeffTensor); + auto clipNode = Clip(); - // create the new micrograph + auto minTensor = std::make_shared<Tensor>(clipMin); + auto minNode = Producer(minTensor); + minNode->addChild(clipNode, 0, 1); + microGraph->add(minNode); - std::shared_ptr<GraphView> newMicroGraph = Sequential({newMulNode, roundNode, clipNode}); - newCoeffNode->addChild(newMulNode, 0, 1); // 1 was not specified !!! - newMicroGraph->add(newCoeffNode); + auto maxTensor = std::make_shared<Tensor>(clipMax); + auto maxNode = Producer(maxTensor); + maxNode->addChild(clipNode, 0, 2); + microGraph->add(maxNode); - // manually connect the IOs !!! + roundNode->addChild(clipNode, 0, 0); + microGraph->add(clipNode); - auto newMulOp = std::static_pointer_cast<OperatorTensor> (newMulNode->getOperator()); - newMulOp->associateInput(0, mulOp->getInput(0)); // MANDATORY (because we need an input tensor) - auto clipOp = std::static_pointer_cast<Clip_Op> (clipNode->getOperator()); - clipOp->associateOutput(0, mulOp->getOutput(0)); // MANDATORY ? YES !!! + // set the datatype / backend - // Log::notice( " old mul ref count : {}", mulOp->getOutput(0)->getImpl().use_count()); - // Log::notice( " new mul ref count : {}", newMulOp->getOutput(0)->getImpl().use_count()); + microGraph->setDataType(dataType); + microGraph->setBackend(backend); - // Connect the clip min and max tensors + // create the new meta-operator - auto minTensor = std::make_shared<Tensor>(clipMin); - auto maxTensor = std::make_shared<Tensor>(clipMax); - auto minNode = Producer(minTensor); - auto maxNode = Producer(maxTensor); - minNode->addChild(clipNode, 0, 1); - maxNode->addChild(clipNode, 0, 2); - newMicroGraph->add(minNode); - newMicroGraph->add(maxNode); + std::shared_ptr<Node> newQuantizer = MetaOperator("BaseQuantizer", microGraph, {}, quantizer->name()); - // set the backend/datatype + // Copy the flags - newMicroGraph->setBackend(backend); - newMicroGraph->setDataType(dataType); - - // reset the scheduling + if (hasAttr(quantizer, "isProducerScaling")) + addAttr(newQuantizer, "isProducerScaling"); - SequentialScheduler scheduler(newMicroGraph); - scheduler.resetScheduling(); - //scheduler.generateScheduling(); - - // set the micrograph + if (hasAttr(quantizer, "isActivationScaling")) + addAttr(newQuantizer, "isActivationScaling"); - *microGraph = *newMicroGraph; -} + // replace the previous quantizer with the new one + + GraphView::replace({quantizer}, {newQuantizer}); + // TODO : replace the old pointer with the new one (by reference) -void updateScalingFactor(std::shared_ptr<Node> metaOpNode, double scalingFactor) + // quantizer = newQuantizer; + + return newQuantizer; +} + +void updateScalingFactor(std::shared_ptr<Node> quantizerNode, double scalingFactor) { - // TODO : implement or remove the function ... + // XXX TODO : implement or remove the function ... Log::error(" updateScalingFactor() : not yet implemented ... "); } @@ -200,18 +205,6 @@ double getScalingFactor(std::shared_ptr<Node> quantizerNode) return scalingFactor; } -/* -static std::shared_ptr<Node> getSubNode(std::shared_ptr<GraphView> graphView, std::string nodeType) -{ - std::shared_ptr<Node> mulNode = nullptr; - for(std::shared_ptr<Node> node : graphView->getNodes()) - if (node->type() == nodeType) - mulNode = node; - - return mulNode; -} -*/ - void setClipRange(std::shared_ptr<Node> quantizerNode, double min, double max) { auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizerNode->getOperator()); @@ -255,4 +248,72 @@ void setClipRange(std::shared_ptr<Node> quantizerNode, double min, double max) maxProducer->getOperator()->setOutput(0, newMaxTensor); } +std::shared_ptr<Node> replaceScalingWithBitShift(std::shared_ptr<Node> quantizer) +{ + // Retreive a clone of the microGraph + + auto quantizerOp = std::static_pointer_cast<MetaOperator_Op> (quantizer->getOperator()); + auto microGraph = quantizerOp->getMicroGraph()->clone(); + + // retreive the multiplicative (scaling) node + + std::shared_ptr<Node> mulNode = nullptr; + for (auto node : microGraph->getNodes()) + if (node->type() == "Mul") + mulNode = node; + + auto mulOp = std::static_pointer_cast<OperatorTensor> (mulNode->getOperator()); + + // Save the datatype / backend + + auto dataType = mulOp->getOutput(0)->dataType(); + auto backend = mulOp->getOutput(0)->backend(); + + // compute the shift value + + double scaling = getScalingFactor(quantizer); + int bitShiftAmount = std::round(std::log2(scaling)); + auto bitShiftDirection = BitShift_Op::BitShiftDirection::right; + bool bitShiftRounding = false; // XXX XXX XXX use an argument !!! + + // create the replacement bit-shift nodes + + auto bitShiftNode = BitShift(bitShiftDirection, bitShiftRounding, ""); // XXX add a name !!! + auto bitShiftTensor = std::make_shared<Tensor>(Array1D<int, 1> {bitShiftAmount}); + + auto bitShiftProducer = Producer(bitShiftTensor, ""); // XXX add a name !!! + bitShiftProducer->addChild(bitShiftNode, 0, 1); + + // edit the micrograph + + microGraph->replace({mulNode, mulNode->getParent(1)}, {bitShiftNode, bitShiftNode->getParent(1)}); + + // set the datatype / backend + + microGraph->setDataType(dataType); + microGraph->setBackend(backend); + + // create the new meta-operator + + std::shared_ptr<Node> newQuantizer = MetaOperator("BaseQuantizer", microGraph, {}, quantizer->name()); + + // Copy the flags + + if (hasAttr(quantizer, "isProducerScaling")) + addAttr(newQuantizer, "isProducerScaling"); + + if (hasAttr(quantizer, "isActivationScaling")) + addAttr(newQuantizer, "isActivationScaling"); + + // replace the previous quantizer with the new one + + GraphView::replace({quantizer}, {newQuantizer}); + + // TODO : replace the old pointer with the new one (by reference) + + // quantizer = newQuantizer; + + return newQuantizer; +} + } \ No newline at end of file