Skip to content
Snippets Groups Projects
Commit 8b9e1b9a authored by Cyril Moineau's avatar Cyril Moineau Committed by Maxence Naud
Browse files

Add heFiller.

parent 5e9239f3
No related branches found
No related tags found
No related merge requests found
...@@ -19,6 +19,25 @@ ...@@ -19,6 +19,25 @@
namespace Aidge { namespace Aidge {
void calculateFanInFanOut(std::shared_ptr<Tensor> tensor, unsigned int& fanIn,
unsigned int& fanOut) {
AIDGE_ASSERT(
tensor->nbDims() == 4,
"Tensor need to have 4 dimensions to compute FanIn and FanOut.");
// Warning: This function suppose NCXX data layout.
// Aidge currently only support NCHW but this maybe not be true in the
// future.
DimSize_t batchSize = tensor->dims()[0];
DimSize_t channelSize = tensor->dims()[1];
AIDGE_ASSERT(batchSize != 0,
"Cannot calculate FanIn if tensor batch size is 0.");
AIDGE_ASSERT(channelSize != 0,
"Cannot calculate FanOut if tensor channel size is 0.");
fanIn = static_cast<unsigned int>(tensor->size() / batchSize);
fanOut = static_cast<unsigned int>(tensor->size() / channelSize);
}
enum VarianceNorm { FanIn, Average, FanOut };
template <typename T> template <typename T>
void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) { void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) {
AIDGE_ASSERT(tensor->getImpl(), AIDGE_ASSERT(tensor->getImpl(),
...@@ -40,8 +59,7 @@ void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) { ...@@ -40,8 +59,7 @@ void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) {
} }
// TODO: Keep template or use switch case depending on Tensor datatype ? // TODO: Keep template or use switch case depending on Tensor datatype ?
template <typename T> template <typename T>
void normalFiller(std::shared_ptr<Tensor> tensor, void normalFiller(std::shared_ptr<Tensor> tensor, double mean = 0.0,
double mean = 0.0,
double stdDev = 1.0) { double stdDev = 1.0) {
AIDGE_ASSERT(tensor->getImpl(), AIDGE_ASSERT(tensor->getImpl(),
"Tensor got no implementation, cannot fill it."); "Tensor got no implementation, cannot fill it.");
...@@ -66,7 +84,7 @@ void normalFiller(std::shared_ptr<Tensor> tensor, ...@@ -66,7 +84,7 @@ void normalFiller(std::shared_ptr<Tensor> tensor,
}; };
// TODO: Keep template or use switch case depending on Tensor datatype ? // TODO: Keep template or use switch case depending on Tensor datatype ?
template<typename T> template <typename T>
void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) { void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) {
AIDGE_ASSERT(tensor->getImpl(), AIDGE_ASSERT(tensor->getImpl(),
"Tensor got no implementation, cannot fill it."); "Tensor got no implementation, cannot fill it.");
...@@ -74,7 +92,7 @@ void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) { ...@@ -74,7 +92,7 @@ void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) {
std::random_device rd; std::random_device rd;
std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator
std::uniform_distribution<T> uniformDist(min, max); std::uniform_real_distribution<T> uniformDist(min, max);
std::shared_ptr<Tensor> cpyTensor; std::shared_ptr<Tensor> cpyTensor;
// Create cpy only if tensor not on CPU // Create cpy only if tensor not on CPU
...@@ -89,8 +107,113 @@ void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) { ...@@ -89,8 +107,113 @@ void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) {
// Copy values back to the original tensors (actual copy only if needed) // Copy values back to the original tensors (actual copy only if needed)
tensor->copyCastFrom(tensorWithValues); tensor->copyCastFrom(tensorWithValues);
}; };
// void xavierFiller(std::shared_ptr<Tensor> tensor);
// void heFiller(std::shared_ptr<Tensor> tensor); template <typename T>
void xavierUniformFiller(std::shared_ptr<Tensor> tensor, T scaling = 1.0,
VarianceNorm varianceNorm = FanIn) {
AIDGE_ASSERT(tensor->getImpl(),
"Tensor got no implementation, cannot fill it.");
AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");
unsigned int fanIn, fanOut = 0;
calculateFanInFanOut(tensor, fanIn, fanOut);
const T n((varianceNorm == FanIn) ? fanIn
: (varianceNorm == Average) ? (fanIn + fanOut) / 2.0
: fanOut);
const T scale(std::sqrt(3.0 / n));
std::random_device rd;
std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator
std::uniform_real_distribution<T> uniformDist(-scale, scale);
std::shared_ptr<Tensor> cpyTensor;
// Create cpy only if tensor not on CPU
Tensor& tensorWithValues =
tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");
// Setting values
for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
T value = scaling * uniformDist(gen);
tensorWithValues.set<T>(idx, value);
}
// Copy values back to the original tensors (actual copy only if needed)
tensor->copyCastFrom(tensorWithValues);
};
// Fill `tensor` with the Xavier/Glorot normal initialization: values are
// drawn from N(0, stdDev^2) with stdDev = sqrt(1 / n), where n is the fan
// measure selected by `varianceNorm`, then multiplied by `scaling`.
// Fix: the original body accepted `scaling` but never applied it (unlike
// xavierUniformFiller, which multiplies each drawn value by it).
template <typename T>
void xavierNormalFiller(std::shared_ptr<Tensor> tensor, T scaling = 1.0,
                        VarianceNorm varianceNorm = FanIn) {
    AIDGE_ASSERT(tensor->getImpl(),
                 "Tensor got no implementation, cannot fill it.");
    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");

    // Both values are written by calculateFanInFanOut; initialize anyway so
    // the declaration is not half-initialized (`fanIn` was left uninitialized).
    unsigned int fanIn = 0, fanOut = 0;
    calculateFanInFanOut(tensor, fanIn, fanOut);
    const T n((varianceNorm == FanIn)     ? fanIn
              : (varianceNorm == Average) ? (fanIn + fanOut) / 2.0
                                          : fanOut);
    // Use T (not double) for consistency with the sibling fillers.
    const T stdDev(std::sqrt(1.0 / n));

    std::random_device rd;
    std::mt19937 gen(rd());  // Mersenne Twister pseudo-random number generator
    std::normal_distribution<T> normalDist(0.0, stdDev);

    std::shared_ptr<Tensor> cpyTensor;
    // Create cpy only if tensor not on CPU
    Tensor& tensorWithValues =
        tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");
    // Setting values
    for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
        tensorWithValues.set<T>(idx, scaling * normalDist(gen));
    }
    // Copy values back to the original tensors (actual copy only if needed)
    tensor->copyCastFrom(tensorWithValues);
}
// Fill `tensor` with the He (Kaiming) normal initialization: values are
// drawn from N(mean, stdDev^2) with stdDev = sqrt(2 / n), where n is the fan
// measure selected by `varianceNorm`, and mean = meanNorm / n.
// Fix: the original body accepted `scaling` but never applied it; each drawn
// value is now multiplied by `scaling`, matching xavierUniformFiller.
template <typename T>
void heFiller(std::shared_ptr<Tensor> tensor, VarianceNorm varianceNorm = FanIn,
              T meanNorm = 0.0, T scaling = 1.0) {
    AIDGE_ASSERT(tensor->getImpl(),
                 "Tensor got no implementation, cannot fill it.");
    AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type");

    // Both values are written by calculateFanInFanOut; initialize anyway so
    // the declaration is not half-initialized (`fanIn` was left uninitialized).
    unsigned int fanIn = 0, fanOut = 0;
    calculateFanInFanOut(tensor, fanIn, fanOut);
    const T n((varianceNorm == FanIn)     ? fanIn
              : (varianceNorm == Average) ? (fanIn + fanOut) / 2.0
                                          : fanOut);
    const T stdDev(std::sqrt(2.0 / n));
    // meanNorm is normalized by the same fan measure as the variance.
    const T mean(varianceNorm == FanIn ? meanNorm / fanIn
                 : (varianceNorm == Average)
                     ? meanNorm / ((fanIn + fanOut) / 2.0)
                     : meanNorm / fanOut);

    std::random_device rd;
    std::mt19937 gen(rd());  // Mersenne Twister pseudo-random number generator
    std::normal_distribution<T> normalDist(mean, stdDev);

    std::shared_ptr<Tensor> cpyTensor;
    // Create cpy only if tensor not on CPU
    Tensor& tensorWithValues =
        tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu");
    // Setting values
    for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) {
        tensorWithValues.set<T>(idx, scaling * normalDist(gen));
    }
    // Copy values back to the original tensors (actual copy only if needed)
    tensor->copyCastFrom(tensorWithValues);
}
} // namespace Aidge } // namespace Aidge
......
...@@ -19,76 +19,150 @@ namespace py = pybind11; ...@@ -19,76 +19,150 @@ namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Filler(py::module &m) { void init_Filler(py::module &m) {
m.def("constant_filler", py::enum_<enum VarianceNorm>(m, "VarianceNorm")
[](std::shared_ptr<Tensor> tensor, py::object value) -> void { .value("FanIn", VarianceNorm::FanIn)
switch (tensor->dataType()) { .value("Average", VarianceNorm::Average)
case DataType::Float64: .value("FanOut", VarianceNorm::FanOut)
constantFiller<double>(tensor, value.cast<double>()); .export_values();
break;
case DataType::Float32: m.def(
constantFiller<float>(tensor, value.cast<float>()); "constant_filler",
break; [](std::shared_ptr<Tensor> tensor, py::object value) -> void {
case DataType::Int8: switch (tensor->dataType()) {
constantFiller<int8_t>(tensor, value.cast<int8_t>()); case DataType::Float64:
break; constantFiller<double>(tensor, value.cast<double>());
case DataType::Int16: break;
constantFiller<std::int16_t>(tensor, case DataType::Float32:
value.cast<std::int16_t>()); constantFiller<float>(tensor, value.cast<float>());
break; break;
case DataType::Int32: case DataType::Int8:
constantFiller<std::int32_t>(tensor, constantFiller<int8_t>(tensor, value.cast<int8_t>());
value.cast<std::int32_t>()); break;
break; case DataType::Int16:
case DataType::Int64: constantFiller<std::int16_t>(tensor,
constantFiller<std::int64_t>(tensor, value.cast<std::int16_t>());
value.cast<std::int64_t>()); break;
break; case DataType::Int32:
case DataType::UInt8: constantFiller<std::int32_t>(tensor,
constantFiller<std::uint8_t>(tensor, value.cast<std::int32_t>());
value.cast<std::uint8_t>()); break;
break; case DataType::Int64:
case DataType::UInt16: constantFiller<std::int64_t>(tensor,
constantFiller<std::uint16_t>( value.cast<std::int64_t>());
tensor, value.cast<std::uint16_t>()); break;
break; case DataType::UInt8:
default: constantFiller<std::uint8_t>(tensor,
AIDGE_THROW_OR_ABORT( value.cast<std::uint8_t>());
py::value_error, break;
"Data type is not supported for Constant filler."); case DataType::UInt16:
} constantFiller<std::uint16_t>(tensor,
}, py::arg("tensor"), py::arg("value")) value.cast<std::uint16_t>());
.def("normal_filler", break;
[](std::shared_ptr<Tensor> tensor, double mean, default:
double stdDev) -> void { AIDGE_THROW_OR_ABORT(
switch (tensor->dataType()) { py::value_error,
case DataType::Float64: "Data type is not supported for Constant filler.");
normalFiller<double>(tensor, mean, stdDev); }
break; },
case DataType::Float32: py::arg("tensor"), py::arg("value"))
normalFiller<float>(tensor, mean, stdDev); .def(
break; "normal_filler",
default: [](std::shared_ptr<Tensor> tensor, double mean,
AIDGE_THROW_OR_ABORT( double stdDev) -> void {
py::value_error, switch (tensor->dataType()) {
"Data type is not supported for Normal filler."); case DataType::Float64:
} normalFiller<double>(tensor, mean, stdDev);
}, py::arg("tensor"), py::arg("mean")=0.0, py::arg("stdDev")=1.0) break;
.def("uniform_filler", case DataType::Float32:
[](std::shared_ptr<Tensor> tensor, double min, normalFiller<float>(tensor, mean, stdDev);
double max) -> void { break;
switch (tensor->dataType()) { default:
case DataType::Float64: AIDGE_THROW_OR_ABORT(
uniformFiller<double>(tensor, min, max); py::value_error,
break; "Data type is not supported for Normal filler.");
case DataType::Float32: }
uniformFiller<float>(tensor, min, max); },
break; py::arg("tensor"), py::arg("mean") = 0.0, py::arg("stdDev") = 1.0)
default: .def(
AIDGE_THROW_OR_ABORT( "uniform_filler",
py::value_error, [](std::shared_ptr<Tensor> tensor, double min, double max) -> void {
"Data type is not supported for Uniform filler."); switch (tensor->dataType()) {
} case DataType::Float64:
}, py::arg("tensor"), py::arg("min"), py::arg("max")) uniformFiller<double>(tensor, min, max);
; break;
case DataType::Float32:
uniformFiller<float>(tensor, min, max);
break;
default:
AIDGE_THROW_OR_ABORT(
py::value_error,
"Data type is not supported for Uniform filler.");
}
},
py::arg("tensor"), py::arg("min"), py::arg("max"))
.def(
    "xavier_uniform_filler",
    // Dispatch on the tensor's runtime dtype; Xavier initialization is
    // only meaningful for floating-point tensors.
    [](std::shared_ptr<Tensor> tensor, py::object scaling,
       VarianceNorm varianceNorm) -> void {
        switch (tensor->dataType()) {
        case DataType::Float64:
            xavierUniformFiller<double>(tensor, scaling.cast<double>(),
                                        varianceNorm);
            break;
        case DataType::Float32:
            xavierUniformFiller<float>(tensor, scaling.cast<float>(),
                                       varianceNorm);
            break;
        default:
            // Fixed message: name the actual filler being bound.
            AIDGE_THROW_OR_ABORT(
                py::value_error,
                "Data type is not supported for Xavier Uniform filler.");
        }
    },
    py::arg("tensor"), py::arg("scaling") = 1.0,
    py::arg("varianceNorm") = VarianceNorm::FanIn)
.def(
    "xavier_normal_filler",
    // Dispatch on the tensor's runtime dtype; Xavier initialization is
    // only meaningful for floating-point tensors.
    [](std::shared_ptr<Tensor> tensor, py::object scaling,
       VarianceNorm varianceNorm) -> void {
        switch (tensor->dataType()) {
        case DataType::Float64:
            xavierNormalFiller<double>(tensor, scaling.cast<double>(),
                                       varianceNorm);
            break;
        case DataType::Float32:
            xavierNormalFiller<float>(tensor, scaling.cast<float>(),
                                      varianceNorm);
            break;
        default:
            // Fixed copy-paste bug: message previously said "Uniform".
            AIDGE_THROW_OR_ABORT(
                py::value_error,
                "Data type is not supported for Xavier Normal filler.");
        }
    },
    py::arg("tensor"), py::arg("scaling") = 1.0,
    py::arg("varianceNorm") = VarianceNorm::FanIn)
.def(
    "he_filler",
    // Dispatch on the tensor's runtime dtype; He initialization is only
    // meaningful for floating-point tensors.
    [](std::shared_ptr<Tensor> tensor, VarianceNorm varianceNorm,
       py::object meanNorm, py::object scaling) -> void {
        switch (tensor->dataType()) {
        case DataType::Float64:
            heFiller<double>(tensor, varianceNorm, meanNorm.cast<double>(),
                             scaling.cast<double>());
            break;
        case DataType::Float32:
            heFiller<float>(tensor, varianceNorm, meanNorm.cast<float>(),
                            scaling.cast<float>());
            break;
        default:
            // Fixed copy-paste bug: message previously said "Uniform".
            AIDGE_THROW_OR_ABORT(
                py::value_error,
                "Data type is not supported for He filler.");
        }
    },
    py::arg("tensor"), py::arg("varianceNorm") = VarianceNorm::FanIn,
    py::arg("meanNorm") = 0.0, py::arg("scaling") = 1.0);
} }
} // namespace Aidge } // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment