Commit c617a212 authored by Benjamin Halimi

change the training flag type from int to bool

parent b71bd9c2
Merge request !263: Add the BatchNorm train/test flag support
Pipeline #60284 passed
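
The change is mechanical but touches every signature that carries the flag: the attribute tuple, the constructor, the accessor, the BatchNorm factory with its explicit instantiations, and the Python binding. A minimal before/after sketch of a C++ call site (illustrative, not taken from the diff):

    // Before this commit the flag was an int, so the call site read as a magic number:
    //   auto bn = BatchNorm<2>(4, 1.0e-5F, 0.1F, 1, "bn1");
    // After this commit the intent is explicit:
    auto bn = BatchNorm<2>(4, 1.0e-5F, 0.1F, true, "bn1");

Since int converts implicitly to bool in C++, existing call sites that passed 0 or 1 keep compiling unchanged; the visible difference is on the Python side, where the training_mode keyword default moves from 0 to False.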
@@ -33,7 +33,7 @@ public:
     static const std::string Type;

 private:
-    using Attributes_ = StaticAttributes<BatchNormAttr, float, float, int>;
+    using Attributes_ = StaticAttributes<BatchNormAttr, float, float, bool>;
     template <BatchNormAttr e>
     using attr = typename Attributes_::template attr<e>;
     const std::shared_ptr<Attributes_> mAttributes;
@@ -42,7 +42,7 @@ public:
     BatchNorm_Op() = delete;

-    constexpr BatchNorm_Op(float epsilon, float momentum, int trainingMode)
+    constexpr BatchNorm_Op(float epsilon, float momentum, bool trainingMode)
         : OperatorTensor(Type,
                          {InputCategory::Data,
                           InputCategory::Param,
@@ -86,7 +86,7 @@ public:
     inline std::shared_ptr<Attributes> attributes() const override { return mAttributes; }
     inline float& epsilon() const { return mAttributes->template getAttr<BatchNormAttr::Epsilon>(); }
     inline float& momentum() const { return mAttributes->template getAttr<BatchNormAttr::Momentum>(); }
-    inline int& trainingMode() const { return mAttributes->template getAttr<BatchNormAttr::TrainingMode>(); }
+    inline bool& trainingMode() const { return mAttributes->template getAttr<BatchNormAttr::TrainingMode>(); }

     static const std::vector<std::string> getInputsName() {
         return {"data_input", "scale", "shift", "mean", "variance"};
@@ -104,13 +104,13 @@ template <DimSize_t DIM>
 std::shared_ptr<Node> BatchNorm(const DimSize_t nbFeatures,
                                 const float epsilon = 1.0e-5F,
                                 const float momentum = 0.1F,
-                                const int trainingMode = 0,
+                                const bool trainingMode = false,
                                 const std::string& name = "");
 } // namespace Aidge

-extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const int, const std::string&);
-extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const int, const std::string&);
-extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const int, const std::string&);
+extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const bool, const std::string&);
+extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const bool, const std::string&);
+extern template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const bool, const std::string&);

 namespace {
 template <>
...
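
Because trainingMode() now returns a bool& (the @@ -86 hunk above), the flag can also be flipped on an existing operator, for example to switch a graph from training to inference without rebuilding the node. A minimal sketch, assuming a node bn built by the BatchNorm<2> factory above and Aidge's Node::getOperator() accessor:

    // Hypothetical toggle; the cast target is the operator type declared in the header.
    auto op = std::static_pointer_cast<BatchNorm_Op<2>>(bn->getOperator());
    op->trainingMode() = false;  // switch this BatchNorm to inference mode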
@@ -26,7 +26,7 @@ void declare_BatchNormOp(py::module& m) {
     const std::string pyClassName("BatchNorm" + std::to_string(DIM) + "DOp");
     py::class_<BatchNorm_Op<DIM>, std::shared_ptr<BatchNorm_Op<DIM>>, OperatorTensor>(
         m, pyClassName.c_str(), py::multiple_inheritance())
-    .def(py::init<float, float, int>(),
+    .def(py::init<float, float, bool>(),
         py::arg("epsilon"),
         py::arg("momentum"),
         py::arg("training_mode"))
@@ -36,7 +36,7 @@ void declare_BatchNormOp(py::module& m) {
     declare_registrable<BatchNorm_Op<DIM>>(m, pyClassName);

-    m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nb_features"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("training_mode") = 0, py::arg("name") = "");
+    m.def(("BatchNorm" + std::to_string(DIM) + "D").c_str(), &BatchNorm<DIM>, py::arg("nb_features"), py::arg("epsilon") = 1.0e-5F, py::arg("momentum") = 0.1F, py::arg("training_mode") = false, py::arg("name") = "");
 }

 void init_BatchNorm(py::module &m) {
...
@@ -108,7 +108,7 @@ template <Aidge::DimSize_t DIM>
 inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const Aidge::DimSize_t nbFeatures,
                                                      const float epsilon,
                                                      const float momentum,
-                                                     const int trainingMode,
+                                                     const bool trainingMode,
                                                      const std::string& name) {
     static_assert(DIM<=MaxDim,"Too many kernel dimensions required by BatchNorm, not supported");
     auto batchNorm = std::make_shared<Node>(std::make_shared<BatchNorm_Op<static_cast<DimIdx_t>(DIM)>>(epsilon, momentum, trainingMode), name);
@@ -119,6 +119,6 @@ inline std::shared_ptr<Aidge::Node> Aidge::BatchNorm(const Aidge::DimSize_t nbFeatures,
     return batchNorm;
 }

-template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const int, const std::string&);
-template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const int, const std::string&);
-template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const int, const std::string&);
+template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<2>(const DimSize_t, const float, const float, const bool, const std::string&);
+template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<3>(const DimSize_t, const float, const float, const bool, const std::string&);
+template std::shared_ptr<Aidge::Node> Aidge::BatchNorm<4>(const DimSize_t, const float, const float, const bool, const std::string&);
@@ -352,11 +352,11 @@ TEST_CASE("[core/graph] Matching") {
     auto g2 = Sequential({
         Producer({16, 3, 512, 512}, "dataProvider"),
         Conv(3, 4, {5, 5}, "conv1"),
-        BatchNorm<2>(4, 1.0e-5, 0.1, 0, "bn1"),
+        BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn1"),
         Conv(4, 4, {5, 5}, "conv2"),
         ReLU("relu2"),
         Conv(4, 4, {5, 5}, "conv3"),
-        BatchNorm<2>(4, 1.0e-5, 0.1, 0, "bn3"),
+        BatchNorm<2>(4, 1.0e-5, 0.1, false, "bn3"),
         FC(4, 4, false, "fc1"),
         FC(4, 4, false, "fc2"),
         FC(4, 4, false, "fc3"),
...