-
Benjamin Halimi authoredBenjamin Halimi authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
CLE.cpp 3.57 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include "aidge/quantization/PTQ/CLE.hpp"

#include <algorithm>  // std::max
#include <cmath>      // std::fabs, std::sqrt, std::abs
#include <cstddef>    // std::size_t
#include <memory>
#include <vector>

#include "aidge/quantization/PTQ/Clip.hpp"
#include "aidge/quantization/PTQ/PTQ.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/utils/Log.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace Aidge
{
// Fetch the weight tensor of a node, i.e. its input #1.
static std::shared_ptr<Tensor> getWeightTensor(std::shared_ptr<Node> node)
{
    auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
    return op->getInput(1);
}
// Fetch the bias tensor of a node, i.e. its input #2.
static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
{
    auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
    return op->getInput(2);
}
// Multiply every element of the tensor by the given scaling factor, in place.
// Assumes the tensor implementation holds float data - TODO confirm upstream.
static void rescaleTensor(std::shared_ptr<Tensor> tensor, float scaling)
{
    float * data = static_cast<float *>(tensor->getImpl()->rawPtr());
    const std::size_t nbElements = tensor->size();
    for (std::size_t i = 0; i < nbElements; ++i)
        data[i] *= scaling;
}
// Return the maximum absolute value contained in the tensor.
// The buffer is only read, never modified.
// Assumes the tensor implementation holds float data - TODO confirm upstream.
static float getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
{
    // Read-only view of the raw float buffer
    const float * castedTensor = static_cast<float*>(tensor->getImpl()->rawPtr());

    float maxValue = 0.0f;
    for (std::size_t i = 0; i < tensor->size(); ++i)
        maxValue = std::max(maxValue, std::fabs(castedTensor[i]));  // single fabs per element

    return maxValue;
}
void crossLayerEqualization(std::shared_ptr<GraphView> graphView, float targetDelta)
{
std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);
// Check if the CLE can be applied ...
for (std::shared_ptr<Node> node : nodeVector)
if (node->getChildren().size() > 1)
{
Log::info(" Network have multiple branches, skipping the CLE ... ");
return;
}
Log::info(" Applying the Cross-Layer Equalization ... ");
// Get the vector of affine nodes
std::vector<std::shared_ptr<Node>> affineNodeVector;
for (std::shared_ptr<Node> node : nodeVector)
if (isAffine(node))
affineNodeVector.push_back(node);
float maxRangeDelta;
do
{
maxRangeDelta = 0.0;
/*
std::cout << " ----- " << std::endl;
for (std::shared_ptr<Node> node : affineNodeVector)
std::cout << getTensorAbsoluteMax(getWeightTensor(node)) << std::endl;
*/
for (size_t i = 0; i < (affineNodeVector.size() - 1); i++)
{
std::shared_ptr<Node> n1 = affineNodeVector[i];
std::shared_ptr<Node> n2 = affineNodeVector[i+1];
float r1 = getTensorAbsoluteMax(getWeightTensor(n1));
float r2 = getTensorAbsoluteMax(getWeightTensor(n2));
float s1 = std::sqrt(r1 * r2) / r1;
float s2 = std::sqrt(r1 * r2) / r2;
rescaleTensor(getWeightTensor(n1), s1);
rescaleTensor(getWeightTensor(n2), s2);
rescaleTensor(getBiasTensor(n1), s1);
float rangeDelta = std::abs(r1 - r2);
if (rangeDelta > maxRangeDelta)
maxRangeDelta = rangeDelta;
}
}
while (maxRangeDelta > targetDelta);
}
}