diff --git a/src/PTQ/CLE.cpp b/src/PTQ/CLE.cpp index 74bbc90c3937db70444d4ad6e8f1b3a51bd80529..30ddc9e4d0c44f0be66d4cdb880dad414a52927f 100644 --- a/src/PTQ/CLE.cpp +++ b/src/PTQ/CLE.cpp @@ -20,24 +20,17 @@ #include "aidge/quantization/PTQ/PTQ.hpp" // retrieveNodeVector #include "aidge/graph/GraphView.hpp" - #include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/scheduler/Scheduler.hpp" #include "aidge/utils/Log.hpp" #include "aidge/operator/OperatorTensor.hpp" -#include "aidge/utils/Log.hpp" - -#include "aidge/operator/Mul.hpp" -#include "aidge/operator/ArgMax.hpp" -#include "aidge/operator/Abs.hpp" -#include "aidge/operator/Reshape.hpp" -#include "aidge/operator/Round.hpp" #include "aidge/operator/Mul.hpp" #include "aidge/operator/ArgMax.hpp" #include "aidge/operator/Abs.hpp" #include "aidge/operator/Reshape.hpp" #include "aidge/operator/Round.hpp" +#include "aidge/operator/MetaOperator.hpp" namespace Aidge { @@ -62,23 +55,34 @@ static bool nodeHasBias(std::shared_ptr<Node> node) return false; } -// What is this thing ??? -// Function used to extract the local tensor (from a ProducerScalingNode) -std::shared_ptr<Aidge::Tensor> getLocalTensor(std::shared_ptr<Node> node) +std::shared_ptr<Aidge::Tensor> getScaledWeightTensor(std::shared_ptr<Node> node) { - if (node->getParent(1)->attributes()->hasAttr("quantization.ptq.isProducerScaling")) + if (node->getParent(1)->attributes()->hasAttr("quantization.ptq.isProducerScaling")) { - std::shared_ptr<Aidge::OperatorTensor> operatorTensor = std::static_pointer_cast<OperatorTensor>(node->getParent(1)->getOperator()); - operatorTensor->forward(); // We need the forward pass to compute the scaled value of the Tensor - return operatorTensor->getOutput(0); - } else { - return getWeightTensor(node); + auto quantizer = node->getParent(1); + + // perform an inference on the branch + + auto graphView = Sequential({quantizer}); + graphView->add(quantizer->getParent(0)); + SequentialScheduler scheduler(graphView); + scheduler.forward(true, {}); + + // gather and return the result + + auto op = std::static_pointer_cast<MetaOperator_Op>(quantizer->getOperator()); + auto result = op->getOutput(0); + return result; + } + else + { + auto result = getWeightTensor(node); + return result; } } void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetDelta) { -/* std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView); // Check if the CLE can be applied ... @@ -116,29 +120,32 @@ void crossLayerEqualization(std::shared_ptr<GraphView> graphView, double targetD std::shared_ptr<Node> n1 = affineNodeVector[i]; std::shared_ptr<Node> n2 = affineNodeVector[i+1]; - std::shared_ptr<Aidge::Tensor> n1localTensor = getLocalTensor(n1); - std::shared_ptr<Aidge::Tensor> n2localTensor = getLocalTensor(n2); + std::shared_ptr<Aidge::Tensor> w1 = getScaledWeightTensor(n1); + std::shared_ptr<Aidge::Tensor> w2 = getScaledWeightTensor(n2); - double r1 = getTensorAbsoluteMax(n1localTensor); - double r2 = getTensorAbsoluteMax(n2localTensor); + //Log::notice(" TENSOR : \n {}", *w1); + + double r1 = getTensorAbsoluteMax(w1); + double r2 = getTensorAbsoluteMax(w2); double s1 = std::sqrt(r1 * r2) / r1; double s2 = std::sqrt(r1 * r2) / r2; - insertScalingBelowProducer(n1->getParent(1), s1, graphView); + multiplyScalingFactor(n1->getParent(1), s1); if (nodeHasBias(n1)) - insertScalingBelowProducer(n1->getParent(2), s1, graphView); + multiplyScalingFactor(n1->getParent(2), s1); - insertScalingBelowProducer(n2->getParent(1), s2, graphView); + multiplyScalingFactor(n2->getParent(1), s2); double rangeDelta = std::abs(r1 - r2); if (rangeDelta > maxRangeDelta) maxRangeDelta = rangeDelta; } + + // Log::notice(" CLE delta = {} ", maxRangeDelta); } while (maxRangeDelta > targetDelta); -*/ } } \ No newline at end of file