diff --git a/include/aidge/filler/Filler.hpp b/include/aidge/filler/Filler.hpp index e6de2f930e7d06be47244a3f9f8cd73c3aa18b51..a021e3d10969025bb349c96e163602b7edf94735 100644 --- a/include/aidge/filler/Filler.hpp +++ b/include/aidge/filler/Filler.hpp @@ -19,6 +19,25 @@ namespace Aidge { +void calculateFanInFanOut(std::shared_ptr<Tensor> tensor, unsigned int& fanIn, + unsigned int& fanOut) { + AIDGE_ASSERT( + tensor->nbDims() == 4, + "Tensor need to have 4 dimensions to compute FanIn and FanOut."); + // Warning: This function assumes an NCXX data layout. + // Aidge currently only supports NCHW but this may not be true in the + // future. + DimSize_t batchSize = tensor->dims()[0]; + DimSize_t channelSize = tensor->dims()[1]; + AIDGE_ASSERT(batchSize != 0, + "Cannot calculate FanIn if tensor batch size is 0."); + AIDGE_ASSERT(channelSize != 0, + "Cannot calculate FanOut if tensor channel size is 0."); + fanIn = static_cast<unsigned int>(tensor->size() / batchSize); + fanOut = static_cast<unsigned int>(tensor->size() / channelSize); +} +enum VarianceNorm { FanIn, Average, FanOut }; + template <typename T> void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) { AIDGE_ASSERT(tensor->getImpl(), @@ -40,8 +59,7 @@ void constantFiller(std::shared_ptr<Tensor> tensor, T constantValue) { } // TODO: Keep template or use switch case depending on Tensor datatype ? template <typename T> -void normalFiller(std::shared_ptr<Tensor> tensor, - double mean = 0.0, +void normalFiller(std::shared_ptr<Tensor> tensor, double mean = 0.0, double stdDev = 1.0) { AIDGE_ASSERT(tensor->getImpl(), "Tensor got no implementation, cannot fill it."); @@ -66,7 +84,7 @@ void normalFiller(std::shared_ptr<Tensor> tensor, }; // TODO: Keep template or use switch case depending on Tensor datatype ? 
-template<typename T> +template <typename T> void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) { AIDGE_ASSERT(tensor->getImpl(), "Tensor got no implementation, cannot fill it."); @@ -74,7 +92,7 @@ void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) { std::random_device rd; std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator - std::uniform_distribution<T> uniformDist(min, max); + std::uniform_real_distribution<T> uniformDist(min, max); std::shared_ptr<Tensor> cpyTensor; // Create cpy only if tensor not on CPU @@ -89,8 +107,113 @@ void uniformFiller(std::shared_ptr<Tensor> tensor, T min, T max) { // Copy values back to the original tensors (actual copy only if needed) tensor->copyCastFrom(tensorWithValues); }; -// void xavierFiller(std::shared_ptr<Tensor> tensor); -// void heFiller(std::shared_ptr<Tensor> tensor); + +template <typename T> +void xavierUniformFiller(std::shared_ptr<Tensor> tensor, T scaling = 1.0, + VarianceNorm varianceNorm = FanIn) { + AIDGE_ASSERT(tensor->getImpl(), + "Tensor got no implementation, cannot fill it."); + AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); + + unsigned int fanIn, fanOut = 0; + calculateFanInFanOut(tensor, fanIn, fanOut); + + const T n((varianceNorm == FanIn) ? fanIn + : (varianceNorm == Average) ? 
(fanIn + fanOut) / 2.0 + : fanOut); + const T scale(std::sqrt(3.0 / n)); + + std::random_device rd; + std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator + + std::uniform_real_distribution<T> uniformDist(-scale, scale); + + std::shared_ptr<Tensor> cpyTensor; + // Create cpy only if tensor not on CPU + Tensor& tensorWithValues = + tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu"); + // Setting values + for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { + T value = scaling * uniformDist(gen); + tensorWithValues.set<T>(idx, value); + } + + // Copy values back to the original tensors (actual copy only if needed) + tensor->copyCastFrom(tensorWithValues); +}; +template <typename T> +void xavierNormalFiller(std::shared_ptr<Tensor> tensor, T scaling = 1.0, + VarianceNorm varianceNorm = FanIn) { + AIDGE_ASSERT(tensor->getImpl(), + "Tensor got no implementation, cannot fill it."); + AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); + + unsigned int fanIn, fanOut = 0; + calculateFanInFanOut(tensor, fanIn, fanOut); + + const T n((varianceNorm == FanIn) ? fanIn + : (varianceNorm == Average) ? 
(fanIn + fanOut) / 2.0 + : fanOut); + const double stdDev(std::sqrt(1.0 / n)); + + std::random_device rd; + std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator + + std::normal_distribution<T> normalDist(0.0, stdDev); + + std::shared_ptr<Tensor> cpyTensor; + // Create cpy only if tensor not on CPU + Tensor& tensorWithValues = + tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu"); + + // Setting values + for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { + tensorWithValues.set<T>(idx, normalDist(gen)); + } + + // Copy values back to the original tensors (actual copy only if needed) + tensor->copyCastFrom(tensorWithValues); +}; + +template <typename T> +void heFiller(std::shared_ptr<Tensor> tensor, VarianceNorm varianceNorm = FanIn, + T meanNorm = 0.0, T scaling = 1.0) { + AIDGE_ASSERT(tensor->getImpl(), + "Tensor got no implementation, cannot fill it."); + AIDGE_ASSERT(NativeType<T>::type == tensor->dataType(), "Wrong data type"); + + unsigned int fanIn, fanOut = 0; + calculateFanInFanOut(tensor, fanIn, fanOut); + + const T n((varianceNorm == FanIn) ? fanIn + : (varianceNorm == Average) ? (fanIn + fanOut) / 2.0 + : fanOut); + + const T stdDev(std::sqrt(2.0 / n)); + + const T mean(varianceNorm == FanIn ? meanNorm / fanIn + : (varianceNorm == Average) + ? 
meanNorm / ((fanIn + fanOut) / 2.0) + : meanNorm / fanOut); + + std::random_device rd; + std::mt19937 gen(rd()); // Mersenne Twister pseudo-random number generator + + std::normal_distribution<T> normalDist(mean, stdDev); + + std::shared_ptr<Tensor> cpyTensor; + // Create cpy only if tensor not on CPU + Tensor& tensorWithValues = + tensor->refCastFrom(cpyTensor, tensor->dataType(), "cpu"); + + // Setting values + for (std::size_t idx = 0; idx < tensorWithValues.size(); ++idx) { + tensorWithValues.set<T>(idx, normalDist(gen)); + } + + // Copy values back to the original tensors (actual copy only if needed) + tensor->copyCastFrom(tensorWithValues); +}; } // namespace Aidge diff --git a/python_binding/filler/pybind_Filler.cpp b/python_binding/filler/pybind_Filler.cpp index 6938ef2b9c451329607fc2c7630dbee29a88f872..fea7543fa21e625da05493064d1bbf2fa630f4d5 100644 --- a/python_binding/filler/pybind_Filler.cpp +++ b/python_binding/filler/pybind_Filler.cpp @@ -19,76 +19,150 @@ namespace py = pybind11; namespace Aidge { void init_Filler(py::module &m) { - m.def("constant_filler", - [](std::shared_ptr<Tensor> tensor, py::object value) -> void { - switch (tensor->dataType()) { - case DataType::Float64: - constantFiller<double>(tensor, value.cast<double>()); - break; - case DataType::Float32: - constantFiller<float>(tensor, value.cast<float>()); - break; - case DataType::Int8: - constantFiller<int8_t>(tensor, value.cast<int8_t>()); - break; - case DataType::Int16: - constantFiller<std::int16_t>(tensor, - value.cast<std::int16_t>()); - break; - case DataType::Int32: - constantFiller<std::int32_t>(tensor, - value.cast<std::int32_t>()); - break; - case DataType::Int64: - constantFiller<std::int64_t>(tensor, - value.cast<std::int64_t>()); - break; - case DataType::UInt8: - constantFiller<std::uint8_t>(tensor, - value.cast<std::uint8_t>()); - break; - case DataType::UInt16: - constantFiller<std::uint16_t>( - tensor, value.cast<std::uint16_t>()); - break; - default: - 
AIDGE_THROW_OR_ABORT( - py::value_error, - "Data type is not supported for Constant filler."); - } - }, py::arg("tensor"), py::arg("value")) - .def("normal_filler", - [](std::shared_ptr<Tensor> tensor, double mean, - double stdDev) -> void { - switch (tensor->dataType()) { - case DataType::Float64: - normalFiller<double>(tensor, mean, stdDev); - break; - case DataType::Float32: - normalFiller<float>(tensor, mean, stdDev); - break; - default: - AIDGE_THROW_OR_ABORT( - py::value_error, - "Data type is not supported for Normal filler."); - } - }, py::arg("tensor"), py::arg("mean")=0.0, py::arg("stdDev")=1.0) - .def("uniform_filler", - [](std::shared_ptr<Tensor> tensor, double min, - double max) -> void { - switch (tensor->dataType()) { - case DataType::Float64: - uniformFiller<double>(tensor, min, max); - break; - case DataType::Float32: - uniformFiller<float>(tensor, min, max); - break; - default: - AIDGE_THROW_OR_ABORT( - py::value_error, - "Data type is not supported for Uniform filler."); - } - }, py::arg("tensor"), py::arg("min"), py::arg("max")) - ; + py::enum_<enum VarianceNorm>(m, "VarianceNorm") + .value("FanIn", VarianceNorm::FanIn) + .value("Average", VarianceNorm::Average) + .value("FanOut", VarianceNorm::FanOut) + .export_values(); + + m.def( + "constant_filler", + [](std::shared_ptr<Tensor> tensor, py::object value) -> void { + switch (tensor->dataType()) { + case DataType::Float64: + constantFiller<double>(tensor, value.cast<double>()); + break; + case DataType::Float32: + constantFiller<float>(tensor, value.cast<float>()); + break; + case DataType::Int8: + constantFiller<int8_t>(tensor, value.cast<int8_t>()); + break; + case DataType::Int16: + constantFiller<std::int16_t>(tensor, + value.cast<std::int16_t>()); + break; + case DataType::Int32: + constantFiller<std::int32_t>(tensor, + value.cast<std::int32_t>()); + break; + case DataType::Int64: + constantFiller<std::int64_t>(tensor, + value.cast<std::int64_t>()); + break; + case DataType::UInt8: + 
constantFiller<std::uint8_t>(tensor, + value.cast<std::uint8_t>()); + break; + case DataType::UInt16: + constantFiller<std::uint16_t>(tensor, + value.cast<std::uint16_t>()); + break; + default: + AIDGE_THROW_OR_ABORT( + py::value_error, + "Data type is not supported for Constant filler."); + } + }, + py::arg("tensor"), py::arg("value")) + .def( + "normal_filler", + [](std::shared_ptr<Tensor> tensor, double mean, + double stdDev) -> void { + switch (tensor->dataType()) { + case DataType::Float64: + normalFiller<double>(tensor, mean, stdDev); + break; + case DataType::Float32: + normalFiller<float>(tensor, mean, stdDev); + break; + default: + AIDGE_THROW_OR_ABORT( + py::value_error, + "Data type is not supported for Normal filler."); + } + }, + py::arg("tensor"), py::arg("mean") = 0.0, py::arg("stdDev") = 1.0) + .def( + "uniform_filler", + [](std::shared_ptr<Tensor> tensor, double min, double max) -> void { + switch (tensor->dataType()) { + case DataType::Float64: + uniformFiller<double>(tensor, min, max); + break; + case DataType::Float32: + uniformFiller<float>(tensor, min, max); + break; + default: + AIDGE_THROW_OR_ABORT( + py::value_error, + "Data type is not supported for Uniform filler."); + } + }, + py::arg("tensor"), py::arg("min"), py::arg("max")) + .def( + "xavier_uniform_filler", + [](std::shared_ptr<Tensor> tensor, py::object scaling, + VarianceNorm varianceNorm) -> void { + switch (tensor->dataType()) { + case DataType::Float64: + xavierUniformFiller<double>( + tensor, scaling.cast<double>(), varianceNorm); + break; + case DataType::Float32: + xavierUniformFiller<float>( + tensor, scaling.cast<float>(), varianceNorm); + break; + default: + AIDGE_THROW_OR_ABORT( + py::value_error, + "Data type is not supported for Uniform filler."); + } + }, + py::arg("tensor"), py::arg("scaling") = 1.0, + py::arg("varianceNorm") = VarianceNorm::FanIn) + .def( + "xavier_normal_filler", + [](std::shared_ptr<Tensor> tensor, py::object scaling, + VarianceNorm varianceNorm) 
-> void { + switch (tensor->dataType()) { + case DataType::Float64: + xavierNormalFiller<double>( + tensor, scaling.cast<double>(), varianceNorm); + break; + case DataType::Float32: + xavierNormalFiller<float>(tensor, scaling.cast<float>(), + varianceNorm); + break; + default: + AIDGE_THROW_OR_ABORT( + py::value_error, + "Data type is not supported for XavierNormal filler."); + } + }, + py::arg("tensor"), py::arg("scaling") = 1.0, + py::arg("varianceNorm") = VarianceNorm::FanIn) + .def( + "he_filler", + [](std::shared_ptr<Tensor> tensor, VarianceNorm varianceNorm, + py::object meanNorm, py::object scaling) -> void { + switch (tensor->dataType()) { + case DataType::Float64: + heFiller<double>(tensor, varianceNorm, + meanNorm.cast<double>(), + scaling.cast<double>()); + break; + case DataType::Float32: + heFiller<float>(tensor, varianceNorm, + meanNorm.cast<float>(), + scaling.cast<float>()); + break; + default: + AIDGE_THROW_OR_ABORT( + py::value_error, + "Data type is not supported for He filler."); + } + }, + py::arg("tensor"), py::arg("varianceNorm") = VarianceNorm::FanIn, py::arg("meanNorm") = 0.0, py::arg("scaling") = 1.0); } } // namespace Aidge