/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include "aidge/quantization/PTQ/CLE.hpp"
#include "aidge/quantization/PTQ/Clip.hpp"
#include "aidge/quantization/PTQ/PTQ.hpp"

#include <algorithm>  // std::max
#include <cmath>      // std::fabs, std::sqrt, std::abs
#include <cstddef>    // std::size_t
#include <memory>     // std::shared_ptr, std::static_pointer_cast
#include <vector>     // std::vector

#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/scheduler/Scheduler.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/utils/Log.hpp"

namespace Aidge
{

// Return the weight tensor of a node, i.e. input #1 of its operator.
static std::shared_ptr<Tensor> getWeightTensor(std::shared_ptr<Node> node)
{
    auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
    return op->getInput(1);
}

// Return the bias tensor of a node, i.e. input #2 of its operator.
static std::shared_ptr<Tensor> getBiasTensor(std::shared_ptr<Node> node)
{
    auto op = std::static_pointer_cast<OperatorTensor>(node->getOperator());
    return op->getInput(2);
}

// Multiply every element of the tensor by 'scaling', in place.
static void rescaleTensor(std::shared_ptr<Tensor> tensor, float scaling)
{
    // Raw float view of the tensor's backing storage
    float * data = static_cast<float *>(tensor->getImpl()->rawPtr());

    const std::size_t nbElements = tensor->size();
    for (std::size_t i = 0; i < nbElements; ++i)
        data[i] *= scaling;
}

// Return the maximum absolute value found in the tensor (0 for an empty tensor).
static float getTensorAbsoluteMax(std::shared_ptr <Tensor> tensor)
{
    // Read-only view of the tensor's raw float buffer; nothing is written back.
    const float * data = static_cast<const float *>(tensor->getImpl()->rawPtr());

    float maxValue = 0.0f;
    for (std::size_t i = 0; i < tensor->size(); ++i)
        maxValue = std::max(maxValue, std::fabs(data[i]));

    return maxValue;
}

void crossLayerEqualization(std::shared_ptr<GraphView> graphView, float targetDelta)
{
    std::vector<std::shared_ptr<Node>> nodeVector = retrieveNodeVector(graphView);

    // Check if the CLE can be applied ...

    for (std::shared_ptr<Node> node : nodeVector)
        if (node->getChildren().size() > 1)
        {
            Log::info(" Network have multiple branches, skipping the CLE ... ");
            return;
        }    

    Log::info(" Applying the Cross-Layer Equalization ... ");

    // Get the vector of affine nodes

    std::vector<std::shared_ptr<Node>> affineNodeVector;
    for (std::shared_ptr<Node> node : nodeVector)
        if (isAffine(node))
            affineNodeVector.push_back(node);

    float maxRangeDelta;

    do 
    {
        maxRangeDelta = 0.0;
        /*
        std::cout << " ----- " << std::endl;
        for (std::shared_ptr<Node> node : affineNodeVector)
            std::cout << getTensorAbsoluteMax(getWeightTensor(node)) << std::endl;
        */
        for (size_t i = 0; i < (affineNodeVector.size() - 1); i++)
        {
            std::shared_ptr<Node> n1 = affineNodeVector[i];
            std::shared_ptr<Node> n2 = affineNodeVector[i+1];

            float r1 = getTensorAbsoluteMax(getWeightTensor(n1));
            float r2 = getTensorAbsoluteMax(getWeightTensor(n2));

            float s1 = std::sqrt(r1 * r2) / r1;
            float s2 = std::sqrt(r1 * r2) / r2;

            rescaleTensor(getWeightTensor(n1), s1);
            rescaleTensor(getWeightTensor(n2), s2);

            rescaleTensor(getBiasTensor(n1), s1);

            float rangeDelta = std::abs(r1 - r2);
            if (rangeDelta > maxRangeDelta)
                maxRangeDelta = rangeDelta;
        }
    }
    while (maxRangeDelta > targetDelta);
}

}