Commit 74a582a1 authored by Maxence Naud

[Upd] BatchNorm factory to build its parameters at creation

parent 631de16f

@@ -82,9 +82,9 @@ public:
             associated &= !(getInput(i)->empty());
         }
         if (associated) {
-            const DimSize_t nbChannels = getInput(0)->dims()[1];
+            const DimSize_t nbFeatures = getInput(0)->dims()[1];
             for (std::size_t i = nbData(); i < nbInputs(); ++i) {
-                if(getInput(i)->size() != nbChannels) {
+                if(getInput(i)->size() != nbFeatures) {
                     // /!\ Input size should be handled BEFORE calling this function
                     // This should raise an error
                     getInput(i)->resize(std::array<DimSize_t, 1>({getInput(0)->dims()[1]}));
@@ -117,15 +117,16 @@ template <DimIdx_t DIM>
 const std::string BatchNorm_Op<DIM>::Type = "BatchNorm";

 template <DimSize_t DIM>
-inline std::shared_ptr<Node> BatchNorm(const float epsilon = 1.0e-5F,
+inline std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures,
+                                       const float epsilon = 1.0e-5F,
                                        const float momentum = 0.1F,
                                        const std::string& name = "") {
     static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported");
     auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name);
-    addProducer(batchNorm, 1, std::array<DimSize_t,0>({}), "scale");
-    addProducer(batchNorm, 2, std::array<DimSize_t,0>({}), "shift");
-    addProducer(batchNorm, 3, std::array<DimSize_t,0>({}), "batch_mean");
-    addProducer(batchNorm, 4, std::array<DimSize_t,0>({}), "batch_variance");
+    addProducer(batchNorm, 1, std::array<DimSize_t,1>({nbFeatures}), "scale");
+    addProducer(batchNorm, 2, std::array<DimSize_t,1>({nbFeatures}), "shift");
+    addProducer(batchNorm, 3, std::array<DimSize_t,1>({nbFeatures}), "batch_mean");
+    addProducer(batchNorm, 4, std::array<DimSize_t,1>({nbFeatures}), "batch_variance");
     return batchNorm;
 }
 } // namespace Aidge
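
A minimal usage sketch of the factory after this change; the dimension, feature count and node name below are illustrative, not taken from the repository:

    // Sketch only: builds a 2-D BatchNorm node for 64 feature maps.
    // With this commit the feature count is required at creation, and the
    // "scale", "shift", "batch_mean" and "batch_variance" producers are
    // allocated with shape {64} instead of empty 0-D tensors.
    auto bn = Aidge::BatchNorm<2>(/*nbFeatures=*/64,
                                  /*epsilon=*/1.0e-5F,
                                  /*momentum=*/0.1F,
                                  "bn_example");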
@@ -25,7 +25,7 @@ void declare_BatchNormOp(py::module& m) {
     .def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName)
     .def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName);

-    m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = "");
+    m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nbFeatures"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = "");
 }

 void init_BatchNorm(py::module &m) {
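
The Python binding mirrors the new C++ signature: nbFeatures becomes a required first argument of the generated BatchNorm{DIM}D factories, while epsilon, momentum and name keep their defaults, so existing Python calls must now pass the feature count explicitly.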