Commit 8b9e1b9a authored by Cyril Moineau, committed by Maxence Naud

Add heFiller.

parent 5e9239f3
2 merge requests: !105 "version 0.2.0", !99 "Adding Filler to aidge_core"
@@ -19,6 +19,25 @@
namespace Aidge {
void calculateFanInFanOut(std::shared_ptr<Tensor> tensor, unsigned int& fanIn,
                          unsigned int& fanOut) {
    AIDGE_ASSERT(
        tensor->nbDims() == 4,
        "Tensor needs to have 4 dimensions to compute FanIn and FanOut.");
    // Warning: this function assumes an NCXX data layout.
    // Aidge currently only supports NCHW, but this may change in the future.
    DimSize_t batchSize = tensor->dims()[0];
    DimSize_t channelSize = tensor->dims()[1];
    AIDGE_ASSERT(batchSize != 0,
                 "Cannot calculate FanIn if tensor batch size is 0.");
    AIDGE_ASSERT(channelSize != 0,
                 "Cannot calculate FanOut if tensor channel size is 0.");
    fanIn = static_cast<unsigned int>(tensor->size() / batchSize);
    fanOut = static_cast<unsigned int>(tensor->size() / channelSize);
}
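A quick worked example (illustrative only, not part of the commit): for a Conv weight tensor of dims {64, 32, 3, 3} in NCHW layout, size() = 64*32*3*3 = 18432, so fanIn = 18432 / 64 = 288 (= C*H*W, the number of inputs feeding each output unit) and fanOut = 18432 / 32 = 576 (= N*H*W, the number of outputs each input channel feeds).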
enum VarianceNorm { FanIn, Average, FanOut };
template <typename T>
void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) {
    AIDGE_ASSERT(tensor->getImpl(),
@@ -40,8 +59,7 @@ void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) {
}
// TODO: Keep template or use switch case depending on Tensor datatype?
template <typename T>
-void normalFiller(std::shared_ptr<Tensor> tensor,
-                  double mean = 0.0,
+void normalFiller(std::shared_ptr<Tensor> tensor, double mean = 0.0,
                   double stdDev = 1.0) {
    AIDGE_ASSERT(tensor->getImpl(),
                 "Tensor got no implementation, cannot fill it.");
@@ -66,7 +84,7 @@
};
// TODO: Keep template or use switch case depending on Tensor datatype?
-template<typename T>
+template <typename T>
void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) {
    AIDGE_ASSERT(tensor->getImpl(),
                 "Tensor got no implementation, cannot fill it.");
@@ -74,7 +92,7 @@ void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) {
    std::random_device rd;
    std::mt19937 gen(rd());  // Mersenne Twister pseudo-random number generator
-    std::uniform_distribution<T> uniformDist(min, max);
+    std::uniform_real_distribution<T> uniformDist(min, max);
    std::shared_ptr<Tensor> cpyTensor;
    // Create a copy only if the tensor is not on CPU
@@ -89,8 +107,113 @@ void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) {
    // Copy values back to the original tensor (actual copy only if needed)
    tensor->copyCastFrom(tensorWithValues);
};
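The one-line change above fixes a real bug: std::uniform_distribution is not a standard library type (the standard only provides std::uniform_int_distribution and std::uniform_real_distribution). std::uniform_real_distribution<T> is defined only for floating-point T, consistent with the Float32/Float64 cases exposed by the Python binding further down.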
-// void xavierFiller(std::shared_ptr<Tensor> tensor);
-// void heFiller(std::shared_ptr<Tensor> tensor);
template <typename T>
void xavierUniformFiller(std::shared_ptr<Tensor> tensor, T scaling = 1.0,
                         VarianceNorm varianceNorm = FanIn) {
    AIDGE_ASSERT(tensor->getImpl(),
                 "Tensor got no implementation, cannot fill it.");
    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
    unsigned int fanIn = 0, fanOut = 0;
    calculateFanInFanOut(tensor, fanIn, fanOut);
    const T n((varianceNorm == FanIn)     ? fanIn
              : (varianceNorm == Average) ? (fanIn + fanOut) / 2.0
                                          : fanOut);
    const T scale(std::sqrt(3.0 / n));

    std::random_device rd;
    std::mt19937 gen(rd());  // Mersenne Twister pseudo-random number generator
    std::uniform_real_distribution<T> uniformDist(-scale, scale);

    std::shared_ptr<Tensor> cpyTensor;
    // Create a copy only if the tensor is not on CPU
    Tensor& tensorWithValues =
        tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");
    // Set the values, applying the user-provided scaling factor
    for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
        T value = scaling * uniformDist(gen);
        tensorWithValues.set<T>(idx, value);
    }
    // Copy values back to the original tensor (actual copy only if needed)
    tensor->copyCastFrom(tensorWithValues);
};
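The sqrt(3/n) bound follows from the variance of a uniform distribution: U(-a, a) has variance a^2/3, so choosing a = sqrt(3/n) hits the Var(W) = 1/n target of Xavier (Glorot) initialization; with Average normalization this reduces to the familiar sqrt(6/(fanIn+fanOut)) bound from Glorot & Bengio (2010). The extra `scaling` factor then multiplies every sampled value.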
template <typename T>
void xavierNormalFiller(std::shared_ptr<Tensor> tensor, T scaling = 1.0,
                        VarianceNorm varianceNorm = FanIn) {
    AIDGE_ASSERT(tensor->getImpl(),
                 "Tensor got no implementation, cannot fill it.");
    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
    unsigned int fanIn = 0, fanOut = 0;
    calculateFanInFanOut(tensor, fanIn, fanOut);
    const T n((varianceNorm == FanIn)     ? fanIn
              : (varianceNorm == Average) ? (fanIn + fanOut) / 2.0
                                          : fanOut);
    const double stdDev(std::sqrt(1.0 / n));

    std::random_device rd;
    std::mt19937 gen(rd());  // Mersenne Twister pseudo-random number generator
    std::normal_distribution<T> normalDist(0.0, stdDev);

    std::shared_ptr<Tensor> cpyTensor;
    // Create a copy only if the tensor is not on CPU
    Tensor& tensorWithValues =
        tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");
    // Set the values, applying the user-provided scaling factor
    for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
        tensorWithValues.set<T>(idx, scaling * normalDist(gen));
    }
    // Copy values back to the original tensor (actual copy only if needed)
    tensor->copyCastFrom(tensorWithValues);
};
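The normal variant reaches the same Var(W) = 1/n target directly by sampling from N(0, sqrt(1/n)); for the example tensor above with n = fanIn = 288, stdDev = sqrt(1/288) ≈ 0.059.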
template <typename T>
void heFiller(std::shared_ptr<Tensor> tensor, VarianceNorm varianceNorm = FanIn,
              T meanNorm = 0.0, T scaling = 1.0) {
    AIDGE_ASSERT(tensor->getImpl(),
                 "Tensor got no implementation, cannot fill it.");
    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
    unsigned int fanIn = 0, fanOut = 0;
    calculateFanInFanOut(tensor, fanIn, fanOut);
    const T n((varianceNorm == FanIn)     ? fanIn
              : (varianceNorm == Average) ? (fanIn + fanOut) / 2.0
                                          : fanOut);
    const T stdDev(std::sqrt(2.0 / n));
    const T mean((varianceNorm == FanIn)     ? meanNorm / fanIn
                 : (varianceNorm == Average) ? meanNorm / ((fanIn + fanOut) / 2.0)
                                             : meanNorm / fanOut);

    std::random_device rd;
    std::mt19937 gen(rd());  // Mersenne Twister pseudo-random number generator
    std::normal_distribution<T> normalDist(mean, stdDev);

    std::shared_ptr<Tensor> cpyTensor;
    // Create a copy only if the tensor is not on CPU
    Tensor& tensorWithValues =
        tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");
    // Set the values, applying the user-provided scaling factor
    for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
        tensorWithValues.set<T>(idx, scaling * normalDist(gen));
    }
    // Copy values back to the original tensor (actual copy only if needed)
    tensor->copyCastFrom(tensorWithValues);
};
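heFiller implements He (Kaiming) initialization: the factor 2 in sqrt(2/n) compensates for a ReLU zeroing half of its pre-activations on average (He et al., 2015). With n = fanIn = 288 as in the running example, stdDev = sqrt(2/288) ≈ 0.083, and meanNorm is divided by the same n before being used as the mean of the distribution.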
} // namespace Aidge
@@ -19,76 +19,150 @@ namespace py = pybind11;
namespace Aidge {
void init_Filler(py::module &m) {
m.def("constant_filler",
[](std::shared_ptr<Tensor> tensor, py::object value) -> void {
switch (tensor->dataType()) {
case DataType::Float64:
constantFiller<double>(tensor, value.cast<double>());
break;
case DataType::Float32:
constantFiller<float>(tensor, value.cast<float>());
break;
case DataType::Int8:
constantFiller<int8_t>(tensor, value.cast<int8_t>());
break;
case DataType::Int16:
constantFiller<std::int16_t>(tensor,
value.cast<std::int16_t>());
break;
case DataType::Int32:
constantFiller<std::int32_t>(tensor,
value.cast<std::int32_t>());
break;
case DataType::Int64:
constantFiller<std::int64_t>(tensor,
value.cast<std::int64_t>());
break;
case DataType::UInt8:
constantFiller<std::uint8_t>(tensor,
value.cast<std::uint8_t>());
break;
case DataType::UInt16:
constantFiller<std::uint16_t>(
tensor, value.cast<std::uint16_t>());
break;
default:
AIDGE_THROW_OR_ABORT(
py::value_error,
"Data type is not supported for Constant filler.");
}
}, py::arg("tensor"), py::arg("value"))
.def("normal_filler",
[](std::shared_ptr<Tensor> tensor, double mean,
double stdDev) -> void {
switch (tensor->dataType()) {
case DataType::Float64:
normalFiller<double>(tensor, mean, stdDev);
break;
case DataType::Float32:
normalFiller<float>(tensor, mean, stdDev);
break;
default:
AIDGE_THROW_OR_ABORT(
py::value_error,
"Data type is not supported for Normal filler.");
}
}, py::arg("tensor"), py::arg("mean")=0.0, py::arg("stdDev")=1.0)
.def("uniform_filler",
[](std::shared_ptr<Tensor> tensor, double min,
double max) -> void {
switch (tensor->dataType()) {
case DataType::Float64:
uniformFiller<double>(tensor, min, max);
break;
case DataType::Float32:
uniformFiller<float>(tensor, min, max);
break;
default:
AIDGE_THROW_OR_ABORT(
py::value_error,
"Data type is not supported for Uniform filler.");
}
}, py::arg("tensor"), py::arg("min"), py::arg("max"))
;
    py::enum_<enum VarianceNorm>(m, "VarianceNorm")
        .value("FanIn", VarianceNorm::FanIn)
        .value("Average", VarianceNorm::Average)
        .value("FanOut", VarianceNorm::FanOut)
        .export_values();

    m.def(
        "constant_filler",
        [](std::shared_ptr<Tensor> tensor, py::object value) -> void {
            switch (tensor->dataType()) {
                case DataType::Float64:
                    constantFiller<double>(tensor, value.cast<double>());
                    break;
                case DataType::Float32:
                    constantFiller<float>(tensor, value.cast<float>());
                    break;
                case DataType::Int8:
                    constantFiller<int8_t>(tensor, value.cast<int8_t>());
                    break;
                case DataType::Int16:
                    constantFiller<std::int16_t>(tensor,
                                                 value.cast<std::int16_t>());
                    break;
                case DataType::Int32:
                    constantFiller<std::int32_t>(tensor,
                                                 value.cast<std::int32_t>());
                    break;
                case DataType::Int64:
                    constantFiller<std::int64_t>(tensor,
                                                 value.cast<std::int64_t>());
                    break;
                case DataType::UInt8:
                    constantFiller<std::uint8_t>(tensor,
                                                 value.cast<std::uint8_t>());
                    break;
                case DataType::UInt16:
                    constantFiller<std::uint16_t>(tensor,
                                                  value.cast<std::uint16_t>());
                    break;
                default:
                    AIDGE_THROW_OR_ABORT(
                        py::value_error,
                        "Data type is not supported for Constant filler.");
            }
        },
        py::arg("tensor"), py::arg("value"))
        .def(
            "normal_filler",
            [](std::shared_ptr<Tensor> tensor, double mean,
               double stdDev) -> void {
                switch (tensor->dataType()) {
                    case DataType::Float64:
                        normalFiller<double>(tensor, mean, stdDev);
                        break;
                    case DataType::Float32:
                        normalFiller<float>(tensor, mean, stdDev);
                        break;
                    default:
                        AIDGE_THROW_OR_ABORT(
                            py::value_error,
                            "Data type is not supported for Normal filler.");
                }
            },
            py::arg("tensor"), py::arg("mean") = 0.0, py::arg("stdDev") = 1.0)
        .def(
            "uniform_filler",
            [](std::shared_ptr<Tensor> tensor, double min, double max) -> void {
                switch (tensor->dataType()) {
                    case DataType::Float64:
                        uniformFiller<double>(tensor, min, max);
                        break;
                    case DataType::Float32:
                        uniformFiller<float>(tensor, min, max);
                        break;
                    default:
                        AIDGE_THROW_OR_ABORT(
                            py::value_error,
                            "Data type is not supported for Uniform filler.");
                }
            },
            py::arg("tensor"), py::arg("min"), py::arg("max"))
        .def(
            "xavier_uniform_filler",
            [](std::shared_ptr<Tensor> tensor, py::object scaling,
               VarianceNorm varianceNorm) -> void {
                switch (tensor->dataType()) {
                    case DataType::Float64:
                        xavierUniformFiller<double>(
                            tensor, scaling.cast<double>(), varianceNorm);
                        break;
                    case DataType::Float32:
                        xavierUniformFiller<float>(
                            tensor, scaling.cast<float>(), varianceNorm);
                        break;
                    default:
                        AIDGE_THROW_OR_ABORT(
                            py::value_error,
                            "Data type is not supported for Xavier Uniform "
                            "filler.");
                }
            },
            py::arg("tensor"), py::arg("scaling") = 1.0,
            py::arg("varianceNorm") = VarianceNorm::FanIn)
        .def(
            "xavier_normal_filler",
            [](std::shared_ptr<Tensor> tensor, py::object scaling,
               VarianceNorm varianceNorm) -> void {
                switch (tensor->dataType()) {
                    case DataType::Float64:
                        xavierNormalFiller<double>(
                            tensor, scaling.cast<double>(), varianceNorm);
                        break;
                    case DataType::Float32:
                        xavierNormalFiller<float>(
                            tensor, scaling.cast<float>(), varianceNorm);
                        break;
                    default:
                        AIDGE_THROW_OR_ABORT(
                            py::value_error,
                            "Data type is not supported for Xavier Normal "
                            "filler.");
                }
            },
            py::arg("tensor"), py::arg("scaling") = 1.0,
            py::arg("varianceNorm") = VarianceNorm::FanIn)
        .def(
            "he_filler",
            [](std::shared_ptr<Tensor> tensor, VarianceNorm varianceNorm,
               py::object meanNorm, py::object scaling) -> void {
                switch (tensor->dataType()) {
                    case DataType::Float64:
                        heFiller<double>(tensor, varianceNorm,
                                         meanNorm.cast<double>(),
                                         scaling.cast<double>());
                        break;
                    case DataType::Float32:
                        heFiller<float>(tensor, varianceNorm,
                                        meanNorm.cast<float>(),
                                        scaling.cast<float>());
                        break;
                    default:
                        AIDGE_THROW_OR_ABORT(
                            py::value_error,
                            "Data type is not supported for He filler.");
                }
            },
            py::arg("tensor"), py::arg("varianceNorm") = VarianceNorm::FanIn,
            py::arg("meanNorm") = 0.0, py::arg("scaling") = 1.0);
}
} // namespace Aidge
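For reference, a minimal usage sketch of the new Python API (a sketch only: it assumes the module is imported as aidge_core, per the merge request title, and that `w` is an already-constructed 4-D Float32 Tensor such as a Conv weight; Tensor construction is not part of this commit):

import aidge_core

# `w` is assumed to be an existing 4-D Float32 aidge_core.Tensor
# (hypothetical setup); every filler mutates the tensor in place.
aidge_core.uniform_filler(w, min=-0.1, max=0.1)
aidge_core.xavier_uniform_filler(w, scaling=1.0,
                                 varianceNorm=aidge_core.VarianceNorm.FanIn)
aidge_core.he_filler(w, varianceNorm=aidge_core.VarianceNorm.FanIn,
                     meanNorm=0.0, scaling=1.0)

Because the enum binding calls export_values(), the shorter aidge_core.FanIn spelling should also work for the varianceNorm argument.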