diff --git a/include/aidge/operator/Memorize.hpp b/include/aidge/operator/Memorize.hpp index a1d90f06f098eb7fa2fc199b595991702daf488a..0a47d0df1976d4c53698e5119b7eb362811c0703 100644 --- a/include/aidge/operator/Memorize.hpp +++ b/include/aidge/operator/Memorize.hpp @@ -37,7 +37,7 @@ public: enum class MemorizeAttr { ScheduleStep, ForwardStep, EndStep }; class Memorize_Op : public OperatorTensor, - public Registrable<Memorize_Op, std::string, std::unique_ptr<OperatorImpl>(const Memorize_Op&)> { + public Registrable<Memorize_Op, std::string, std::shared_ptr<OperatorImpl>(const Memorize_Op&)> { public: static const std::string Type; diff --git a/python_binding/operator/pybind_Memorize.cpp b/python_binding/operator/pybind_Memorize.cpp new file mode 100644 index 0000000000000000000000000000000000000000..3ac1122111aae1a9b7eb353399e46562ae51b0b1 --- /dev/null +++ b/python_binding/operator/pybind_Memorize.cpp @@ -0,0 +1,33 @@ +/******************************************************************************** + * Copyright (c) 2023 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> +#include <string> +#include <vector> + +#include "aidge/operator/Memorize.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; +namespace Aidge { + +void init_Memorize(py::module& m) { + py::class_<Memorize_Op, std::shared_ptr<Memorize_Op>, OperatorTensor>(m, "MemorizeOp", py::multiple_inheritance()) + .def(py::init<const std::uint32_t>(), py::arg("end_step")) + .def_static("get_inputs_name", &Memorize_Op::getInputsName) + .def_static("get_outputs_name", &Memorize_Op::getOutputsName); + + declare_registrable<Memorize_Op>(m, "MemorizeOp"); + + m.def("Memorize", &Memorize, py::arg("end_step"), py::arg("name") = ""); +} + +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 616a8424b1f9df7e52a8af485b1bb82235f66a2f..fdedb7bd2d42077944b7ed48ba21c2e1ae45ff0e 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -47,6 +47,7 @@ void init_Identity(py::module&); void init_LeakyReLU(py::module&); void init_MatMul(py::module&); void init_MaxPooling(py::module&); +void init_Memorize(py::module&); void init_MetaOperatorDefs(py::module&); void init_Mul(py::module&); void init_Pad(py::module&); @@ -126,6 +127,7 @@ void init_Aidge(py::module& m) { init_LeakyReLU(m); init_MatMul(m); init_MaxPooling(m); + init_Memorize(m); init_MetaOperatorDefs(m); init_Mul(m); init_Pad(m);