// NormalFiller.cpp
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <memory>
#include <random> // normal_distribution, uniform_real_distribution
#include "aidge/data/Tensor.hpp"
#include "aidge/filler/Filler.hpp"
#include "aidge/utils/Random.hpp"
/**
 * @brief Fill a tensor with samples drawn from a normal distribution.
 *
 * Values are written through a CPU view of the tensor obtained with
 * refCastFrom(..., "cpu"), then transferred back to `tensor` with
 * copyCastFrom.
 *
 * @tparam T Element type; must match the tensor's data type (checked
 *           against NativeType<T>::type).
 * @param tensor Tensor to fill. Must be non-null and have an implementation.
 * @param mean Mean of the distribution.
 * @param stdDev Standard deviation of the distribution. Note that
 *               std::normal_distribution requires a strictly positive
 *               value; this function does not check it.
 */
template <typename T>
void Aidge::normalFiller(std::shared_ptr<Aidge::Tensor> tensor, double mean,
                         double stdDev) {
    // Guard against a null pointer before any dereference below.
    AIDGE_ASSERT(tensor != nullptr, "Tensor is null, cannot fill it.");
    AIDGE_ASSERT(tensor->getImpl(),
                 "Tensor got no implementation, cannot fill it.");
    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");

    // The parameters arrive as double; convert explicitly to the element type.
    std::normal_distribution<T> normalDist(static_cast<T>(mean),
                                           static_cast<T>(stdDev));

    // Fallback storage handed to refCastFrom; per the original note, a copy
    // is only created when the tensor is not already on CPU.
    std::shared_ptr<Tensor> cpyTensor;
    Tensor& tensorWithValues =
        tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");

    // Fetch the shared generator and the element count once, not per element.
    auto& generator = Aidge::Random::Generator::get();
    const std::size_t nbElements = tensorWithValues.size();
    for (std::size_t idx = 0; idx < nbElements; ++idx) {
        tensorWithValues.set<T>(idx, normalDist(generator));
    }

    // Copy values back to the original tensor (actual copy only if needed).
    tensor->copyCastFrom(tensorWithValues);
}
// Explicit instantiations: the normalFiller template is defined in this
// source file, so the element types it is instantiated for here are float
// and double.
template void Aidge::normalFiller<float>(std::shared_ptr<Aidge::Tensor>, double,
double);
template void Aidge::normalFiller<double>(std::shared_ptr<Aidge::Tensor>,
double, double);