Skip to content
Snippets Groups Projects
Commit a0ce08b3 authored by Houssem ROUIS's avatar Houssem ROUIS
Browse files

add default seed constexpr for Dropout

parent b3c72b7f
No related branches found
No related tags found
No related merge requests found
Pipeline #82008 passed
...@@ -44,9 +44,11 @@ public: ...@@ -44,9 +44,11 @@ public:
static constexpr const char* const Type = "Dropout"; static constexpr const char* const Type = "Dropout";
static constexpr const char* const InputsName[] = {"data_input", "probability", "training_mode"}; static constexpr const char* const InputsName[] = {"data_input", "probability", "training_mode"};
static constexpr const char* const OutputsName[] = {"data_output", "mask"}; static constexpr const char* const OutputsName[] = {"data_output", "mask"};
using maskDType = cpptype_t<DataType::Boolean>; using maskDType = cpptype_t<DataType::Boolean>;
static constexpr std::int64_t DEFAULT_SEED = std::numeric_limits<std::int64_t>::lowest();
Dropout_Op(float probability = 0.5f, std::int64_t seed = std::numeric_limits<std::int64_t>::lowest(), bool trainingMode = false); Dropout_Op(float probability = 0.5f, std::int64_t seed = DEFAULT_SEED, bool trainingMode = false);
Dropout_Op(const Dropout_Op& op); Dropout_Op(const Dropout_Op& op);
...@@ -80,7 +82,7 @@ private: ...@@ -80,7 +82,7 @@ private:
}; };
// Function to create a Dropout node // Function to create a Dropout node
std::shared_ptr<Node> Dropout(float probability = 0.5f, std::int64_t seed = std::numeric_limits<std::int64_t>::lowest(), bool trainingMode = false, const std::string& name = ""); std::shared_ptr<Node> Dropout(float probability = 0.5f, std::int64_t seed = Dropout_Op::DEFAULT_SEED, bool trainingMode = false, const std::string& name = "");
} // namespace Aidge } // namespace Aidge
......
...@@ -23,7 +23,7 @@ void init_Dropout(py::module& m) { ...@@ -23,7 +23,7 @@ void init_Dropout(py::module& m) {
// Binding for Dropout operator class // Binding for Dropout operator class
py::class_<Dropout_Op, std::shared_ptr<Dropout_Op>, OperatorTensor>( py::class_<Dropout_Op, std::shared_ptr<Dropout_Op>, OperatorTensor>(
m, "DropoutOp", py::multiple_inheritance()) m, "DropoutOp", py::multiple_inheritance())
.def(py::init<float, std::int64_t, bool>(), py::arg("probability") = 0.5f, py::arg("seed") = std::numeric_limits<std::int64_t>::lowest(), py::arg("training_mode") = false) .def(py::init<float, std::int64_t, bool>(), py::arg("probability") = 0.5f, py::arg("seed") = Dropout_Op::DEFAULT_SEED, py::arg("training_mode") = false)
.def_static("get_inputs_name", []() { .def_static("get_inputs_name", []() {
return std::vector<std::string>(std::begin(Dropout_Op::InputsName), std::end(Dropout_Op::InputsName)); return std::vector<std::string>(std::begin(Dropout_Op::InputsName), std::end(Dropout_Op::InputsName));
}, "Get the names of the input tensors.") }, "Get the names of the input tensors.")
...@@ -39,6 +39,6 @@ void init_Dropout(py::module& m) { ...@@ -39,6 +39,6 @@ void init_Dropout(py::module& m) {
declare_registrable<Dropout_Op>(m, "DropoutOp"); declare_registrable<Dropout_Op>(m, "DropoutOp");
// Function to create a Dropout node // Function to create a Dropout node
m.def("Dropout", &Dropout, py::arg("probability") = 0.5f, py::arg("seed") = std::numeric_limits<std::int64_t>::lowest(), py::arg("training_mode") = false, py::arg("name") = ""); m.def("Dropout", &Dropout, py::arg("probability") = 0.5f, py::arg("seed") = Dropout_Op::DEFAULT_SEED, py::arg("training_mode") = false, py::arg("name") = "");
} }
} // namespace Aidge } // namespace Aidge
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment