Skip to content
Snippets Groups Projects
Commit 74a582a1 authored by Maxence Naud's avatar Maxence Naud
Browse files

[Upd] BatchNorm factory to build its parameters at creation

parent 631de16f
No related branches found
No related tags found
No related merge requests found
......@@ -82,9 +82,9 @@ public:
associated &= !(getInput(i)->empty());
}
if (associated) {
const DimSize_t nbChannels = getInput(0)->dims()[1];
const DimSize_t nbFeatures = getInput(0)->dims()[1];
for (std::size_t i = nbData(); i < nbInputs(); ++i) {
if(getInput(i)->size() != nbChannels) {
if(getInput(i)->size() != nbFeatures) {
// /!\ Input size should be handled BEFORE calling this function
// This should raise an error
getInput(i)->resize(std::array<DimSize_t, 1>({getInput(0)->dims()[1]}));
......@@ -117,15 +117,16 @@ template <DimIdx_t DIM>
const std::string BatchNorm_Op<DIM>::Type = "BatchNorm";
template <DimSize_t DIM>
inline std::shared_ptr<Node> BatchNorm(const float epsilon = 1.0e-5F,
inline std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures,
const float epsilon = 1.0e-5F,
const float momentum = 0.1F,
const std::string& name = "") {
static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported");
auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name);
addProducer(batchNorm, 1, std::array<DimSize_t,0>({}), "scale");
addProducer(batchNorm, 2, std::array<DimSize_t,0>({}), "shift");
addProducer(batchNorm, 3, std::array<DimSize_t,0>({}), "batch_mean");
addProducer(batchNorm, 4, std::array<DimSize_t,0>({}), "batch_variance");
addProducer(batchNorm, 1, std::array<DimSize_t,1>({nbFeatures}), "scale");
addProducer(batchNorm, 2, std::array<DimSize_t,1>({nbFeatures}), "shift");
addProducer(batchNorm, 3, std::array<DimSize_t,1>({nbFeatures}), "batch_mean");
addProducer(batchNorm, 4, std::array<DimSize_t,1>({nbFeatures}), "batch_variance");
return batchNorm;
}
} // namespace Aidge
......
......@@ -25,7 +25,7 @@ void declare_BatchNormOp(py::module& m) {
.def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName)
.def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName);
m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = "");
m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nbFeatures"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = "");
}
void init_BatchNorm(py::module &m) {
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment