Skip to content
Snippets Groups Projects
Commit dddfdc17 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Merge branch 'pybind_scaling' into 'dev'

Add Scaling python binding

See merge request !110
parents 73269f1c 68805eb5
No related branches found
No related tags found
2 merge requests!1190.2.1,!110Add Scaling python binding
Pipeline #44371 passed
...@@ -27,9 +27,10 @@ enum class ScalingAttr { ...@@ -27,9 +27,10 @@ enum class ScalingAttr {
scalingFactor, quantizedNbBits, isOutputUnsigned scalingFactor, quantizedNbBits, isOutputUnsigned
}; };
class Scaling_Op : public OperatorTensor, class Scaling_Op
public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>, : public OperatorTensor,
public StaticAttributes<ScalingAttr, float, size_t, bool> { public Registrable<Scaling_Op, std::string, std::shared_ptr<OperatorImpl>(const Scaling_Op&)>,
public StaticAttributes<ScalingAttr, float, size_t, bool> {
public: public:
static const std::string Type; static const std::string Type;
...@@ -84,7 +85,11 @@ inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::stri ...@@ -84,7 +85,11 @@ inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::stri
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name); return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name);
} }
*/ */
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, std::size_t quantizedNbBits=8, bool isOutputUnsigned=true, const std::string& name = "") { inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f,
std::size_t quantizedNbBits=8,
bool isOutputUnsigned=true,
const std::string& name = "")
{
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor,quantizedNbBits, isOutputUnsigned), name); return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor,quantizedNbBits, isOutputUnsigned), name);
} }
} // namespace Aidge } // namespace Aidge
......
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Scaling.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Scaling(py::module& m)
{
py::class_<Scaling_Op, std::shared_ptr<Scaling_Op>, Attributes, OperatorTensor>(m, "ScalingOp", py::multiple_inheritance())
.def("get_inputs_name", &Scaling_Op::getInputsName)
.def("get_outputs_name", &Scaling_Op::getOutputsName)
.def("attributes_name", &Scaling_Op::staticGetAttrsName);
declare_registrable<Scaling_Op>(m, "ScalingOp");
m.def("Scaling", &Scaling, py::arg("scaling_factor") = 1.0f, py::arg("nb_bits") = 8, py::arg("is_output_unsigned") = true, py::arg("name") = "");
}
} // namespace Aidge
...@@ -51,6 +51,7 @@ void init_Pow(py::module&); ...@@ -51,6 +51,7 @@ void init_Pow(py::module&);
void init_ReduceMean(py::module&); void init_ReduceMean(py::module&);
void init_ReLU(py::module&); void init_ReLU(py::module&);
void init_Reshape(py::module&); void init_Reshape(py::module&);
void init_Scaling(py::module&);
void init_Sigmoid(py::module&); void init_Sigmoid(py::module&);
void init_Slice(py::module&); void init_Slice(py::module&);
void init_Softmax(py::module&); void init_Softmax(py::module&);
...@@ -117,6 +118,7 @@ void init_Aidge(py::module& m) { ...@@ -117,6 +118,7 @@ void init_Aidge(py::module& m) {
init_ReduceMean(m); init_ReduceMean(m);
init_ReLU(m); init_ReLU(m);
init_Reshape(m); init_Reshape(m);
init_Scaling(m);
init_Sigmoid(m); init_Sigmoid(m);
init_Slice(m); init_Slice(m);
init_Softmax(m); init_Softmax(m);
......
...@@ -21,6 +21,6 @@ ...@@ -21,6 +21,6 @@
const std::string Aidge::Scaling_Op::Type = "Scaling"; const std::string Aidge::Scaling_Op::Type = "Scaling";
void Aidge::Scaling_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) { void Aidge::Scaling_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
mImpl = Registrar<Scaling_Op>::create(name)(*this); SET_IMPL_MACRO(Scaling_Op, *this, name);
mOutputs[0]->setBackend(name, device); mOutputs[0]->setBackend(name, device);
} }
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment