diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index 6dc0455bd78d8f196d28bb03b26630f46eabd95b..31dbdd4df2340953e408d0ff5744cb4ff8ce3e9d 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -82,9 +82,9 @@ public: associated &= !(getInput(i)->empty()); } if (associated) { - const DimSize_t nbChannels = getInput(0)->dims()[1]; + const DimSize_t nbFeatures = getInput(0)->dims()[1]; for (std::size_t i = nbData(); i < nbInputs(); ++i) { - if(getInput(i)->size() != nbChannels) { + if(getInput(i)->size() != nbFeatures) { // /!\ Input size should be handled BEFORE calling this function // This should raise an error getInput(i)->resize(std::array<DimSize_t, 1>({getInput(0)->dims()[1]})); @@ -117,15 +117,16 @@ template <DimIdx_t DIM> const std::string BatchNorm_Op<DIM>::Type = "BatchNorm"; template <DimSize_t DIM> -inline std::shared_ptr<Node> BatchNorm(const float epsilon = 1.0e-5F, +inline std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures, + const float epsilon = 1.0e-5F, const float momentum = 0.1F, const std::string& name = "") { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported"); auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name); - addProducer(batchNorm, 1, std::array<DimSize_t,0>({}), "scale"); - addProducer(batchNorm, 2, std::array<DimSize_t,0>({}), "shift"); - addProducer(batchNorm, 3, std::array<DimSize_t,0>({}), "batch_mean"); - addProducer(batchNorm, 4, std::array<DimSize_t,0>({}), "batch_variance"); + addProducer(batchNorm, 1, std::array<DimSize_t,1>({nbFeatures}), "scale"); + addProducer(batchNorm, 2, std::array<DimSize_t,1>({nbFeatures}), "shift"); + addProducer(batchNorm, 3, std::array<DimSize_t,1>({nbFeatures}), "batch_mean"); + addProducer(batchNorm, 4, std::array<DimSize_t,1>({nbFeatures}), "batch_variance"); return batchNorm; } } // namespace Aidge diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp index ff0b9e0dfcb0d1c5e5567a938b1ca74faf242bed..411a2e1b6ae78065a79b92f25c23dac13e341997 100644 --- a/python_binding/operator/pybind_BatchNorm.cpp +++ b/python_binding/operator/pybind_BatchNorm.cpp @@ -25,7 +25,7 @@ void declare_BatchNormOp(py::module& m) { .def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName) .def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName); - m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); + m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nbFeatures"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); } void init_BatchNorm(py::module &m) {