/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include <random> // normal_distribution, uniform_real_distribution
#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"
#include "aidge/utils/Random.hpp"
template <typename T>
void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor,
                                T scaling, Aidge::VarianceNorm varianceNorm) {
    AIDGE_ASSERT(tensor->getImpl(),
                 "Tensor has no implementation, cannot fill it.");
    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");

    unsigned int fanIn = 0, fanOut = 0;
    Aidge::calculateFanInFanOut(tensor, fanIn, fanOut);

    const T n((varianceNorm == Aidge::VarianceNorm::FanIn) ? fanIn
              : (varianceNorm == Aidge::VarianceNorm::Average)
                  ? (fanIn + fanOut) / 2.0
                  : fanOut);
    AIDGE_ASSERT(n > 0,
                 "Invalid fan-in/fan-out: would lead to a division by zero or "
                 "the square root of a negative value.");
    const T scale(std::sqrt(3.0 / n));

    std::uniform_real_distribution<T> uniformDist(-scale, scale);

    std::shared_ptr<Aidge::Tensor> cpyTensor;
    // Create a copy only if the tensor is not already on CPU
    Aidge::Tensor& tensorWithValues =
        tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");

    // Set the values
    for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
        tensorWithValues.set<T>(
            idx, scaling * uniformDist(Aidge::Random::Generator::get()));
    }

    // Copy the values back to the original tensor (actual copy only if needed)
    tensor->copyCastFrom(tensorWithValues);
}
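
// Xavier/Glorot normal initialization: values are drawn from N(0, 1 / n),
// i.e. a standard deviation of sqrt(1 / n), where n is the fan-in, the
// fan-out, or their average depending on `varianceNorm`, then multiplied
// by `scaling`.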
template <typename T>
void Aidge::xavierNormalFiller(std::shared_ptr<Aidge::Tensor> tensor, T scaling,
                               Aidge::VarianceNorm varianceNorm) {
    AIDGE_ASSERT(tensor->getImpl(),
                 "Tensor has no implementation, cannot fill it.");
    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");

    unsigned int fanIn = 0, fanOut = 0;
    Aidge::calculateFanInFanOut(tensor, fanIn, fanOut);

    const T n((varianceNorm == Aidge::VarianceNorm::FanIn) ? fanIn
              : (varianceNorm == Aidge::VarianceNorm::Average)
                  ? (fanIn + fanOut) / 2.0
                  : fanOut);
    const double stdDev(std::sqrt(1.0 / n));

    std::normal_distribution<T> normalDist(0.0, stdDev);

    std::shared_ptr<Aidge::Tensor> cpyTensor;
    // Create a copy only if the tensor is not already on CPU
    Aidge::Tensor& tensorWithValues =
        tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");

    // Set the values
    for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
        tensorWithValues.set<T>(
            idx, scaling * normalDist(Aidge::Random::Generator::get()));
    }

    // Copy the values back to the original tensor (actual copy only if needed)
    tensor->copyCastFrom(tensorWithValues);
}
template void Aidge::xavierUniformFiller<float>(std::shared_ptr<Aidge::Tensor>,
                                                float, Aidge::VarianceNorm);
template void Aidge::xavierUniformFiller<double>(std::shared_ptr<Aidge::Tensor>,
                                                 double, Aidge::VarianceNorm);
template void Aidge::xavierNormalFiller<float>(std::shared_ptr<Aidge::Tensor>,
                                               float, Aidge::VarianceNorm);
template void Aidge::xavierNormalFiller<double>(std::shared_ptr<Aidge::Tensor>,
                                                double, Aidge::VarianceNorm);
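
// Usage sketch: how a caller would typically invoke these fillers. The Tensor
// setup calls below (dims constructor, setDataType, setBackend) are an
// assumption about the aidge_core API and may need adjusting; only the filler
// signatures instantiated above are taken from this file.
//
//   auto weights = std::make_shared<Aidge::Tensor>(
//       std::vector<Aidge::DimSize_t>{64, 32, 3, 3});   // assumed constructor
//   weights->setDataType(Aidge::DataType::Float32);     // assumed setter
//   weights->setBackend("cpu");                         // assumed setter
//   Aidge::xavierUniformFiller<float>(weights, /*scaling=*/1.0f,
//                                     Aidge::VarianceNorm::FanIn);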