diff --git a/include/aidge/operator/Scaling.hpp b/include/aidge/operator/Scaling.hpp index 8f54ab217631ac69a4e16555f8e58f550ab0156c..c864bd045d8a5a1fc5f4ee591d1d81fcaf241bac 100644 --- a/include/aidge/operator/Scaling.hpp +++ b/include/aidge/operator/Scaling.hpp @@ -27,9 +27,10 @@ enum class ScalingAttr { scalingFactor, quantizedNbBits, isOutputUnsigned }; -class Scaling_Op : public OperatorTensor, - public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>, - public StaticAttributes<ScalingAttr, float, size_t, bool> { +class Scaling_Op + : public OperatorTensor, + public Registrable<Scaling_Op, std::string, std::shared_ptr<OperatorImpl>(const Scaling_Op&)>, + public StaticAttributes<ScalingAttr, float, size_t, bool> { public: static const std::string Type; @@ -84,7 +85,11 @@ inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::stri return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name); } */ -inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, std::size_t quantizedNbBits=8, bool isOutputUnsigned=true, const std::string& name = "") { +inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, + std::size_t quantizedNbBits=8, + bool isOutputUnsigned=true, + const std::string& name = "") +{ return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor,quantizedNbBits, isOutputUnsigned), name); } } // namespace Aidge diff --git a/python_binding/operator/pybind_Scaling.cpp b/python_binding/operator/pybind_Scaling.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f091ea70f9b5e9927e535bd527cd84cf081d9823 --- /dev/null +++ b/python_binding/operator/pybind_Scaling.cpp @@ -0,0 +1,32 @@ +/******************************************************************************** + * Copyright (c) 2024 CEA-List + * + * This program and the accompanying materials are made available under the + * terms of the Eclipse Public License 2.0 which is available at + * http://www.eclipse.org/legal/epl-2.0. + * + * SPDX-License-Identifier: EPL-2.0 + * + ********************************************************************************/ + +#include <pybind11/pybind11.h> + +#include "aidge/data/Tensor.hpp" +#include "aidge/operator/Scaling.hpp" +#include "aidge/operator/OperatorTensor.hpp" + +namespace py = pybind11; + +namespace Aidge { + +void init_Scaling(py::module& m) +{ + py::class_<Scaling_Op, std::shared_ptr<Scaling_Op>, Attributes, OperatorTensor>(m, "ScalingOp", py::multiple_inheritance()) + .def("get_inputs_name", &Scaling_Op::getInputsName) + .def("get_outputs_name", &Scaling_Op::getOutputsName) + .def("attributes_name", &Scaling_Op::staticGetAttrsName); + declare_registrable<Scaling_Op>(m, "ScalingOp"); + m.def("Scaling", &Scaling, py::arg("scaling_factor") = 1.0f, py::arg("nb_bits") = 8, py::arg("is_output_unsigned") = true, py::arg("name") = ""); +} + +} // namespace Aidge diff --git a/python_binding/pybind_core.cpp b/python_binding/pybind_core.cpp index 63e5100ac65b5582c7236c2b3467a7d1debcaa36..f12ab25bf60fb32fb3b91a59997007fd2e266e5d 100644 --- a/python_binding/pybind_core.cpp +++ b/python_binding/pybind_core.cpp @@ -51,6 +51,7 @@ void init_Pow(py::module&); void init_ReduceMean(py::module&); void init_ReLU(py::module&); void init_Reshape(py::module&); +void init_Scaling(py::module&); void init_Sigmoid(py::module&); void init_Slice(py::module&); void init_Softmax(py::module&); @@ -117,6 +118,7 @@ void init_Aidge(py::module& m) { init_ReduceMean(m); init_ReLU(m); init_Reshape(m); + init_Scaling(m); init_Sigmoid(m); init_Slice(m); init_Softmax(m); diff --git a/src/operator/Scaling.cpp b/src/operator/Scaling.cpp index 8b0d6f9db698e36d232dec38fd8cdd0fad5f8c59..6aee038d7985b812f748e6bdcf0d48b4cf20eba3 100644 --- a/src/operator/Scaling.cpp +++ b/src/operator/Scaling.cpp @@ -20,7 +20,7 @@ const std::string Aidge::Scaling_Op::Type = "Scaling"; -void Aidge::Scaling_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { - mImpl = Registrar<Scaling_Op>::create(name)(*this); +void Aidge::Scaling_Op::setBackend(const std::string& name, DeviceIdx_t device) { + SET_IMPL_MACRO(Scaling_Op, *this, name); mOutputs[0]->setBackend(name, device); } \ No newline at end of file