From 74a582a1d47b4a2ad94c6192d5daaed3523893e2 Mon Sep 17 00:00:00 2001 From: NAUD Maxence <maxence.naud@cea.fr> Date: Thu, 7 Dec 2023 10:53:34 +0000 Subject: [PATCH] [Upd] BatchNorm factory to build its parameters at creation --- include/aidge/operator/BatchNorm.hpp | 15 ++++++++------- python_binding/operator/pybind_BatchNorm.cpp | 2 +- 2 files changed, 9 insertions(+), 8 deletions(-) diff --git a/include/aidge/operator/BatchNorm.hpp b/include/aidge/operator/BatchNorm.hpp index 6dc0455bd..31dbdd4df 100644 --- a/include/aidge/operator/BatchNorm.hpp +++ b/include/aidge/operator/BatchNorm.hpp @@ -82,9 +82,9 @@ public: associated &= !(getInput(i)->empty()); } if (associated) { - const DimSize_t nbChannels = getInput(0)->dims()[1]; + const DimSize_t nbFeatures = getInput(0)->dims()[1]; for (std::size_t i = nbData(); i < nbInputs(); ++i) { - if(getInput(i)->size() != nbChannels) { + if(getInput(i)->size() != nbFeatures) { // /!\ Input size should be handled BEFORE calling this function // This should raise an error getInput(i)->resize(std::array<DimSize_t, 1>({getInput(0)->dims()[1]})); @@ -117,15 +117,16 @@ template <DimIdx_t DIM> const std::string BatchNorm_Op<DIM>::Type = "BatchNorm"; template <DimSize_t DIM> -inline std::shared_ptr<Node> BatchNorm(const float epsilon = 1.0e-5F, +inline std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures, + const float epsilon = 1.0e-5F, const float momentum = 0.1F, const std::string& name = "") { static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported"); auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum), name); - addProducer(batchNorm, 1, std::array<DimSize_t,0>({}), "scale"); - addProducer(batchNorm, 2, std::array<DimSize_t,0>({}), "shift"); - addProducer(batchNorm, 3, std::array<DimSize_t,0>({}), "batch_mean"); - addProducer(batchNorm, 4, std::array<DimSize_t,0>({}), "batch_variance"); + addProducer(batchNorm, 1, std::array<DimSize_t,1>({nbFeatures}), "scale"); + addProducer(batchNorm, 2, std::array<DimSize_t,1>({nbFeatures}), "shift"); + addProducer(batchNorm, 3, std::array<DimSize_t,1>({nbFeatures}), "batch_mean"); + addProducer(batchNorm, 4, std::array<DimSize_t,1>({nbFeatures}), "batch_variance"); return batchNorm; } } // namespace Aidge diff --git a/python_binding/operator/pybind_BatchNorm.cpp b/python_binding/operator/pybind_BatchNorm.cpp index ff0b9e0df..411a2e1b6 100644 --- a/python_binding/operator/pybind_BatchNorm.cpp +++ b/python_binding/operator/pybind_BatchNorm.cpp @@ -25,7 +25,7 @@ void declare_BatchNormOp(py::module& m) { .def("get_inputs_name", &BatchNorm_Op<DIM>::getInputsName) .def("get_outputs_name", &BatchNorm_Op<DIM>::getOutputsName); - m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); + m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nbFeatures"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("name") = ""); } void init_BatchNorm(py::module &m) { -- GitLab