diff --git a/include/aidge/quantization/PTQ/PTQ.hpp b/include/aidge/quantization/PTQ/PTQ.hpp index f55894c71beded9e6b7a20c53f9f22bbea671a01..1c911801c543cac8cb464acaab80e6061703e6e7 100644 --- a/include/aidge/quantization/PTQ/PTQ.hpp +++ b/include/aidge/quantization/PTQ/PTQ.hpp @@ -74,6 +74,12 @@ namespace Aidge { */ bool isNotQuantized(std::shared_ptr<Node> node); + /** + * @brief Compute the absolute max of a tensor (the largest absolute value among its elements), returned as a double + * @param tensor The Tensor to process + */ + double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor); + /** * @brief Retrieve the scheduled vector of node of a graphView, without the Producer nodes. * @param graphView The graphView containing the nodes diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp index f70793cd297bdba3806bce68b57704fcdc7c4d3d..57787a8951a513cd0dc8660c6ef3a99b63e74729 100644 --- a/src/PTQ/CLE.cpp +++ b/src/PTQ/CLE.cpp @@ -52,62 +52,14 @@ static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node) return std::static_pointer_cast<OperatorTensor>(node->getOperator())->getInput(2); } -static void rescaleTensor(std::shared_ptr<Tensor> tensor, double scaling) +static bool nodeHasBias(std::shared_ptr<Node> node) /* true only when the node has exactly 3 parents and its bias tensor is non-null */ { - auto mulOp = Mul_Op(); - mulOp.setDataType(tensor->dataType()); - mulOp.setBackend(tensor->backend()); - - std::shared_ptr<Aidge::Tensor> scalingTensor = std::make_shared<Aidge::Tensor>(Aidge::Array1D<double, 1> {scaling}); - scalingTensor->setDataType(tensor->dataType()); - scalingTensor->setBackend(tensor->backend()); - - mulOp.associateInput(0, tensor); - mulOp.associateInput(1, scalingTensor); - - mulOp.forward(); - - auto outTensor = mulOp.getOutput(0); - *tensor = *outTensor; - //tensor->copyCast(*outTensor); -} - -// TODO : make the retreival of argmax values backend independant (refCastFrom) -static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) -{ - // get the abs tensor - std::shared_ptr<Tensor> fallback; //Fallback tensor for refCastFR - std::shared_ptr<Tensor> absTensor = 
std::make_shared<Tensor>(tensor->abs()); - - // flatten the abs tensor - - std::int64_t nbElement = tensor->size(); - - auto reshapeOp = Reshape_Op({nbElement}); - reshapeOp.setDataType(tensor->dataType()); - reshapeOp.setBackend(tensor->backend()); - - reshapeOp.associateInput(0, absTensor); - reshapeOp.forward(); - std::shared_ptr<Tensor> flatTensor = reshapeOp.getOutput(0); - const Tensor& localFlatTensor = flatTensor->refCastFrom(fallback, DataType::Float64, "cpu"); - - // Get the argmax - - auto argmaxOp = ArgMax_Op(0, true, false); - argmaxOp.setDataType(tensor->dataType()); - argmaxOp.setBackend(tensor->backend()); - - argmaxOp.associateInput(0, flatTensor); - argmaxOp.forward(); - - const Tensor& argMaxTensor = argmaxOp.getOutput(0)->refCastFrom(fallback, DataType::Float64, "cpu"); - - // Return the max - - int maxIndex = std::round(argMaxTensor.get<double>(0)); - - return localFlatTensor.get<double>(maxIndex); + if (node->getParents().size() == 3) { + std::shared_ptr<Tensor> biasTensor = getBiasTensor(node); /* getBiasTensor reads operator input #2 */ + if (biasTensor) + return true; + } + return false; } // What is this thing ??? @@ -174,9 +126,8 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD insertScalingBelowProducer(n1->getParent(1), s1, graphView); - if (n1->type() != "MatMul") // TODO : exclude every node that we can't call getParent(2) on ! 
- if (n1->getParent(2)) - insertScalingBelowProducer(n1->getParent(2), s1, graphView); + if (nodeHasBias(n1)) /* NOTE(review): nodeHasBias checks the bias tensor, whereas the removed code checked n1->getParent(2) itself - confirm getParent(2) cannot be null here */ + insertScalingBelowProducer(n1->getParent(2), s1, graphView); insertScalingBelowProducer(n2->getParent(1), s2, graphView); diff --git a/src/PTQ/PTQ.cpp b/src/PTQ/PTQ.cpp index 91a003d55c0234d9edb9b173d655e1872361d5b6..0eecc450d7567b8eb0421cd95251ba8ace447a7e 100644 --- a/src/PTQ/PTQ.cpp +++ b/src/PTQ/PTQ.cpp @@ -266,7 +266,7 @@ bool insertRoundBelowProducer(std::shared_ptr<Node> node, std::shared_ptr<GraphV return false; } -static double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) +double getTensorAbsoluteMax(std::shared_ptr<Tensor> tensor) { // get the abs tensor std::shared_ptr<Tensor> fallback; //Fallback tensor for refCastFR @@ -571,8 +571,6 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) // Residual nodes should enter in this category but their ratio is 1 ... if (isAffine(node)) { - Log::warn(" affine : {} ", node->name()); - // Rescale the weight tensor std::shared_ptr<Tensor> weightTensor = getWeightTensor(node); @@ -623,8 +621,6 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) { if (node->type() == "MatMul") { - Log::warn(" matmul : {} ", node->name()); - // Multiply the input scaling factors ! double leftRatio = accumulatedRatios[node->getParent(0)]; @@ -636,8 +632,6 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) { // Use a maximum arbitration ! - Log::warn(" merging : {} ", node->name()); - std::vector<std::shared_ptr<Node>> mergingNodes = node->getParents(); // Compute the max ratio ...