Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
pybind_SGD.cpp 970 B
/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <pybind11/pybind11.h>

#include "aidge/learning/optimizer/Optimizer.hpp"
#include "aidge/learning/optimizer/SGD.hpp"

namespace py = pybind11;
namespace Aidge {
// namespace learning {

/**
 * @brief Register the C++ ``SGD`` optimizer on a Python module as ``SGD``.
 *
 * The Python class declares ``Attributes`` and ``Optimizer`` as its bases.
 * Because more than one base is listed, ``py::multiple_inheritance()`` is
 * passed. Instances are held through ``std::shared_ptr<SGD>``.
 *
 * Python-visible API created here:
 *  - ``SGD(momentum=0.0, dampening=0.0, weight_decay=0.0)``: forwards the
 *    three floats, in that order, to the C++ ``SGD(float, float, float)``
 *    constructor.
 *  - ``update``: binds ``SGD::update`` directly. Its signature and semantics
 *    live in SGD.hpp (presumably one optimization step; confirm there).
 *
 * @param m Python module that receives the ``SGD`` class.
 */
void init_SGD(py::module& m) {
    // Holder type and C++ bases are spelled out so pybind11 can mirror the
    // C++ inheritance relationship on the Python side.
    using SGDBinding = py::class_<SGD, std::shared_ptr<SGD>, Attributes, Optimizer>;

    SGDBinding binding(m, "SGD", py::multiple_inheritance());

    // Every hyper-parameter is an optional keyword argument defaulting to 0.0f.
    binding.def(py::init<float, float, float>(),
                py::arg("momentum") = 0.0f,
                py::arg("dampening") = 0.0f,
                py::arg("weight_decay") = 0.0f);

    binding.def("update", &SGD::update);
}
// }  // namespace learning
}  // namespace Aidge