Skip to content
Snippets Groups Projects
Commit 58736e3a authored by Maxence Naud's avatar Maxence Naud
Browse files

Merge remote-tracking branch 'origin/dev' into fix_gather_and_slice

parents aa10278f 1ffaf28c
No related branches found
No related tags found
No related merge requests found
......@@ -8,4 +8,5 @@ http://www.eclipse.org/legal/epl-2.0.
SPDX-License-Identifier: EPL-2.0
"""
from aidge_core.aidge_core import * # import so generated by PyBind
from aidge_core.export import ExportNode
from aidge_core.export import ExportNode, generate_file, generate_str
import aidge_core.utils
from .node_export import *
from .code_generation import *
import os
from jinja2 import Environment, FileSystemLoader
def generate_file(file_path: str, template_path: str, **kwargs) -> None:
"""Generate a file at `file_path` using the jinja template located at `file_path`.
kwargs are used to fill the template.
:param file_path: path where to generate the file
:type file_path: str
:param template_path: Path to the template to use for code generation
:type template_path: str
"""
# Get directory name of the file
dirname = os.path.dirname(file_path)
# If directory doesn't exist, create it
if not os.path.exists(dirname):
os.makedirs(dirname)
# Get directory name and name of the template
template_dir = os.path.dirname(template_path)
template_name = os.path.basename(template_path)
# Select template
template = Environment(loader=FileSystemLoader(
template_dir)).get_template(template_name)
# Generate file
content = template.render(kwargs)
with open(file_path, mode="w", encoding="utf-8") as message:
message.write(content)
def generate_str(template_path:str, **kwargs) -> str:
"""Generate a string using the jinja template located at `file_path`.
kwargs are used to fill the template.
:param template_path: Path to the template to use for code generation
:type template_path: str
:return: A string of the interpreted template
:rtype: str
"""
dirname = os.path.dirname(template_path)
filename = os.path.basename(template_path)
template = Environment(loader=FileSystemLoader(dirname)).get_template(filename)
return template.render(kwargs)
def template_docstring(template_keyword, text_to_replace):
"""Method to template docstring
:param template: Template keyword to replace, in the documentation you template word must be between `{` `}`
:type template: str
:param text_to_replace: Text to replace your template with.
:type text_to_replace: str
"""
def dec(func):
if "{"+template_keyword+"}" not in func.__doc__:
raise RuntimeError(
f"The function {function.__name__} docstring does not contain the template keyword: {template_keyword}.")
func.__doc__ = func.__doc__.replace(
"{"+template_keyword+"}", text_to_replace)
return func
return dec
......@@ -27,9 +27,10 @@ enum class ScalingAttr {
scalingFactor, quantizedNbBits, isOutputUnsigned
};
class Scaling_Op : public OperatorTensor,
public Registrable<Scaling_Op, std::string, std::unique_ptr<OperatorImpl>(const Scaling_Op&)>,
public StaticAttributes<ScalingAttr, float, size_t, bool> {
class Scaling_Op
: public OperatorTensor,
public Registrable<Scaling_Op, std::string, std::shared_ptr<OperatorImpl>(const Scaling_Op&)>,
public StaticAttributes<ScalingAttr, float, size_t, bool> {
public:
static const std::string Type;
......@@ -84,7 +85,11 @@ inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, const std::stri
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor), name);
}
*/
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f, std::size_t quantizedNbBits=8, bool isOutputUnsigned=true, const std::string& name = "") {
inline std::shared_ptr<Node> Scaling(float scalingFactor = 1.0f,
std::size_t quantizedNbBits=8,
bool isOutputUnsigned=true,
const std::string& name = "")
{
return std::make_shared<Node>(std::make_shared<Scaling_Op>(scalingFactor,quantizedNbBits, isOutputUnsigned), name);
}
} // namespace Aidge
......
/********************************************************************************
* Copyright (c) 2024 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/Scaling.hpp"
#include "aidge/operator/OperatorTensor.hpp"
namespace py = pybind11;
namespace Aidge {
void init_Scaling(py::module& m)
{
py::class_<Scaling_Op, std::shared_ptr<Scaling_Op>, Attributes, OperatorTensor>(m, "ScalingOp", py::multiple_inheritance())
.def("get_inputs_name", &Scaling_Op::getInputsName)
.def("get_outputs_name", &Scaling_Op::getOutputsName)
.def("attributes_name", &Scaling_Op::staticGetAttrsName);
declare_registrable<Scaling_Op>(m, "ScalingOp");
m.def("Scaling", &Scaling, py::arg("scaling_factor") = 1.0f, py::arg("nb_bits") = 8, py::arg("is_output_unsigned") = true, py::arg("name") = "");
}
} // namespace Aidge
......@@ -51,6 +51,7 @@ void init_Pow(py::module&);
void init_ReduceMean(py::module&);
void init_ReLU(py::module&);
void init_Reshape(py::module&);
void init_Scaling(py::module&);
void init_Sigmoid(py::module&);
void init_Slice(py::module&);
void init_Softmax(py::module&);
......@@ -117,6 +118,7 @@ void init_Aidge(py::module& m) {
init_ReduceMean(m);
init_ReLU(m);
init_Reshape(m);
init_Scaling(m);
init_Sigmoid(m);
init_Slice(m);
init_Softmax(m);
......
This diff is collapsed.
......@@ -21,6 +21,6 @@
const std::string Aidge::Scaling_Op::Type = "Scaling";
void Aidge::Scaling_Op::setBackend(const std::string& name, Aidge::DeviceIdx_t device) {
mImpl = Registrar<Scaling_Op>::create(name)(*this);
SET_IMPL_MACRO(Scaling_Op, *this, name);
mOutputs[0]->setBackend(name, device);
}
\ No newline at end of file
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment