/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/
#include <cmath>    // std::sqrt
#include <cstddef>  // std::size_t
#include <memory>
#include <random>  // normal_distribution, uniform_real_distribution

#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"
#include "aidge/utils/Random.hpp"

/**
 * @brief Fill a tensor with scaled samples of a Xavier uniform distribution.
 *
 * Each element is set to `scaling * u`, where `u` is drawn from
 * U(-scale, scale) and `scale = sqrt(3 / n)`. The normalisation term `n`
 * depends on @p varianceNorm:
 *  - VarianceNorm::FanIn   -> fanIn
 *  - VarianceNorm::Average -> (fanIn + fanOut) / 2
 *  - any other value       -> fanOut
 *
 * @tparam T Element type; must match the data type of @p tensor.
 * @param tensor Tensor to fill. Must be non-null and have an implementation.
 * @param scaling Multiplicative factor applied to every sampled value.
 * @param varianceNorm Selects which fan value normalises the distribution.
 */
template <typename T>
void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor,
                                T scaling, Aidge::VarianceNorm varianceNorm) {
    // Reject a null handle before it gets dereferenced below.
    AIDGE_ASSERT(tensor != nullptr, "Cannot fill a null tensor.");
    AIDGE_ASSERT(tensor->getImpl(),
                 "Tensor got no implementation, cannot fill it.");
    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");

    // Both fans are explicitly zero-initialised so that neither is ever read
    // with an indeterminate value.
    unsigned int fanIn = 0;
    unsigned int fanOut = 0;
    Aidge::calculateFanInFanOut(tensor, fanIn, fanOut);

    // The normalisation term is computed in double and then converted to T.
    double fan = fanOut;
    if (varianceNorm == Aidge::VarianceNorm::FanIn) {
        fan = fanIn;
    } else if (varianceNorm == Aidge::VarianceNorm::Average) {
        fan = (fanIn + fanOut) / 2.0;
    }
    const T n(fan);
    AIDGE_ASSERT(n > 0,
                 "Something went wrong division by zero or square root of "
                 "negative value.");
    const T scale(std::sqrt(3.0 / n));

    std::uniform_real_distribution<T> uniformDist(-scale, scale);

    // Work on a CPU-side view of the tensor; per the original note, a copy
    // (held by cpuCopy) is created only if the tensor is not already on CPU.
    std::shared_ptr<Aidge::Tensor> cpuCopy;
    Aidge::Tensor& cpuTensor =
        tensor->refCastFrom(cpuCopy, tensor->dataType(), "cpu");

    for (std::size_t idx = 0; idx < cpuTensor.size(); ++idx) {
        cpuTensor.set<T>(
            idx, scaling * uniformDist(Aidge::Random::Generator::get()));
    }

    // Copy values back to the original tensor (actual copy only if needed).
    tensor->copyCastFrom(cpuTensor);
}
/**
 * @brief Fill a tensor with scaled samples of a Xavier normal distribution.
 *
 * Each element is set to `scaling * x`, where `x` is drawn from a normal
 * distribution with mean 0 and standard deviation `sqrt(1 / n)`. The
 * normalisation term `n` depends on @p varianceNorm:
 *  - VarianceNorm::FanIn   -> fanIn
 *  - VarianceNorm::Average -> (fanIn + fanOut) / 2
 *  - any other value       -> fanOut
 *
 * @tparam T Element type; must match the data type of @p tensor.
 * @param tensor Tensor to fill. Must be non-null and have an implementation.
 * @param scaling Multiplicative factor applied to every sampled value.
 * @param varianceNorm Selects which fan value normalises the distribution.
 */
template <typename T>
void Aidge::xavierNormalFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling,
                               Aidge::VarianceNorm varianceNorm) {
    // Reject a null handle before it gets dereferenced below.
    AIDGE_ASSERT(tensor != nullptr, "Cannot fill a null tensor.");
    AIDGE_ASSERT(tensor->getImpl(),
                 "Tensor got no implementation, cannot fill it.");
    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");

    // Both fans are explicitly zero-initialised so that neither is ever read
    // with an indeterminate value.
    unsigned int fanIn = 0;
    unsigned int fanOut = 0;
    Aidge::calculateFanInFanOut(tensor, fanIn, fanOut);

    // The normalisation term is computed in double and then converted to T.
    double fan = fanOut;
    if (varianceNorm == Aidge::VarianceNorm::FanIn) {
        fan = fanIn;
    } else if (varianceNorm == Aidge::VarianceNorm::Average) {
        fan = (fanIn + fanOut) / 2.0;
    }
    const T n(fan);
    // Same guard as xavierUniformFiller: a zero fan would yield an infinite
    // standard deviation, which std::normal_distribution cannot handle.
    AIDGE_ASSERT(n > 0,
                 "Something went wrong division by zero or square root of "
                 "negative value.");
    const double stdDev(std::sqrt(1.0 / n));

    std::normal_distribution<T> normalDist(0.0, stdDev);

    // Work on a CPU-side view of the tensor; per the original note, a copy
    // (held by cpuCopy) is created only if the tensor is not already on CPU.
    std::shared_ptr<Aidge::Tensor> cpuCopy;
    Aidge::Tensor& cpuTensor =
        tensor->refCastFrom(cpuCopy, tensor->dataType(), "cpu");

    for (std::size_t idx = 0; idx < cpuTensor.size(); ++idx) {
        cpuTensor.set<T>(
            idx, scaling * normalDist(Aidge::Random::Generator::get()));
    }

    // Copy values back to the original tensor (actual copy only if needed).
    tensor->copyCastFrom(cpuTensor);
}

// Explicit instantiations of both Xavier fillers, grouped by element type.

// --- float ---
template void Aidge::xavierUniformFiller<float>(
    std::shared_ptr<Aidge::Tensor>, float, Aidge::VarianceNorm);
template void Aidge::xavierNormalFiller<float>(
    std::shared_ptr<Aidge::Tensor>, float, Aidge::VarianceNorm);

// --- double ---
template void Aidge::xavierUniformFiller<double>(
    std::shared_ptr<Aidge::Tensor>, double, Aidge::VarianceNorm);
template void Aidge::xavierNormalFiller<double>(
    std::shared_ptr<Aidge::Tensor>, double, Aidge::VarianceNorm);