Skip to content
Snippets Groups Projects

[Add] Batchnorm 'training_mode' flag

Merged Benjamin Halimi requested to merge bn_flag into dev
1 file
+ 1
1
Compare changes
  • Side-by-side
  • Inline
@@ -62,7 +62,7 @@ void insertBatchNormNodes(std::shared_ptr<GraphView> graphView)
std::cout << " NB CHANNELS = " << nb_channels << std::endl; // TODO : remove this ...
std::string batchnormNodeName = makeUniqueName(parentNode->name() + "_BN", graphView);
std::shared_ptr<Node> batchnormNode = BatchNorm<2>(nb_channels, 1e-5, 0.1, batchnormNodeName);
std::shared_ptr<Node> batchnormNode = BatchNorm<2>(nb_channels, 1e-5, 0.1, false, batchnormNodeName);
batchnormNode->getOperator()->setDataType(DataType::Float32);
batchnormNode->getOperator()->setBackend("cpu");
Loading