Skip to content
Snippets Groups Projects
Commit 98fc16c8 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Add assert in fillers to avoid div by 0 or sqrt of neg.

parent 91eff9cb
No related branches found
No related tags found
2 merge requests!1190.2.1,!118Update how loss function work
...@@ -29,7 +29,9 @@ void Aidge::heFiller(std::shared_ptr<Aidge::Tensor> tensor, ...@@ -29,7 +29,9 @@ void Aidge::heFiller(std::shared_ptr<Aidge::Tensor> tensor,
: (varianceNorm == Aidge::VarianceNorm::Average) : (varianceNorm == Aidge::VarianceNorm::Average)
? (fanIn + fanOut) / 2.0 ? (fanIn + fanOut) / 2.0
: fanOut); : fanOut);
AIDGE_ASSERT(n > 0,
"Something went wrong division by zero or square root of "
"negative value.");
const T stdDev(std::sqrt(2.0 / n)); const T stdDev(std::sqrt(2.0 / n));
const T mean(varianceNorm == Aidge::VarianceNorm::FanIn ? meanNorm / fanIn const T mean(varianceNorm == Aidge::VarianceNorm::FanIn ? meanNorm / fanIn
......
...@@ -29,6 +29,9 @@ void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor, ...@@ -29,6 +29,9 @@ void Aidge::xavierUniformFiller(std::shared_ptr<Aidge::Tensor> tensor,
: (varianceNorm == Aidge::VarianceNorm::Average) : (varianceNorm == Aidge::VarianceNorm::Average)
? (fanIn + fanOut) / 2.0 ? (fanIn + fanOut) / 2.0
: fanOut); : fanOut);
AIDGE_ASSERT(n > 0,
"Something went wrong division by zero or square root of "
"negative value.");
const T scale(std::sqrt(3.0 / n)); const T scale(std::sqrt(3.0 / n));
std::uniform_real_distribution<T> uniformDist(-scale, scale); std::uniform_real_distribution<T> uniformDist(-scale, scale);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment