Skip to content
Snippets Groups Projects
Commit 50b66d9a authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Merge branch 'GridDepth' into 'dev'

Add support for GridSample + Minor changes

See merge request !235
parents 07c12bbb 081db324
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!235Add support for GridSample + Minor changes
Pipeline #57958 passed
......@@ -11,4 +11,6 @@ from aidge_core.aidge_core import * # import so generated by PyBind
import aidge_core.export_utils
import aidge_core.utils
from aidge_core.aidge_export_aidge import serialize_to_cpp
from aidge_core.show_graphview import gview_to_json
from aidge_core.mem_info import *
from ._version import *
......@@ -19,6 +19,8 @@ def generate_file(file_path: Union[Path, str], template_path: Union[Path, str],
file_path = Path(file_path)
if isinstance(template_path, str):
template_path = Path(template_path)
if not template_path.exists():
raise ValueError(f"Path to template {template_path} is not valid !")
# Make dir
file_path.parent.mkdir(parents=True, exist_ok=True)
......
......@@ -299,11 +299,15 @@ class ExportNodeCpp(ExportNode):
if self.config_template != "":
path_to_definition = f"{self.config_path}/{self.attributes['name']}.{self.config_extension}"
code_generation.generate_file(
str(export_folder / path_to_definition),
self.config_template,
**self.attributes
)
try:
code_generation.generate_file(
str(export_folder / path_to_definition),
self.config_template,
**self.attributes
)
except Exception as e:
raise RuntimeError(f"Error when creating config file for {self.node.name()}[{self.node.type()}].") from e
kernel_include_list.append(path_to_definition)
return self.include_list + kernel_include_list
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <pybind11/pybind11.h>
#include "aidge/backend/OperatorImpl.hpp"
#include "aidge/data/Tensor.hpp"
#include "aidge/operator/DepthToSpace.hpp"
#include "aidge/operator/OperatorTensor.hpp"
#include "aidge/utils/Attributes.hpp"
#include "aidge/utils/Types.h"
static typename Aidge::DepthToSpace_Op::Mode stringToMode(const std::string& mode) {
static std::unordered_map<std::string, typename Aidge::DepthToSpace_Op::Mode> map = {
{"DCR", Aidge::DepthToSpace_Op::Mode::DCR},
{"CRD", Aidge::DepthToSpace_Op::Mode::CRD}
};
return map[mode];
}
namespace py = pybind11;
namespace Aidge {
void declare_DepthToSpace(py::module &m) {
py::class_<DepthToSpace_Op, std::shared_ptr<DepthToSpace_Op>, OperatorTensor> (m, "DepthToSpaceOp", py::multiple_inheritance())
.def(py::init([](const std::uint32_t blockSize, const std::string& mode) {
return new DepthToSpace_Op(blockSize, stringToMode(mode));
}), py::arg("block_size"), py::arg("mode") = "CRD")
.def_static("get_inputs_name", &DepthToSpace_Op::getInputsName)
.def_static("get_outputs_name", &DepthToSpace_Op::getOutputsName)
.def_readonly_static("Type", &DepthToSpace_Op::Type)
.def("__repr__", [](DepthToSpace_Op& b) {
return fmt::format("Operator(type='{}')", b.Type);
});
declare_registrable<DepthToSpace_Op>(m, "DepthToSpaceOp");
m.def("DepthToSpace", [](
const std::uint32_t blockSize,
const std::string& mode,
const std::string& name) {
return DepthToSpace(blockSize, stringToMode(mode), name);
}, py::arg("block_size"), py::arg("mode") = "CRD", py::arg("name") = "");
}
void init_DepthToSpace(py::module &m) {
declare_DepthToSpace(m);
}
} // namespace Aidge
......@@ -55,7 +55,7 @@ void declare_GridSampleOp(py::module &m) {
return new GridSample_Op(stringToInterpolationMode(mode), stringToPaddingMode(padding_mode), align_corners);
}), py::arg("mode") = "linear",
py::arg("padding_mode") = "zeros",
py::arg("alogn_corners") = false)
py::arg("align_corners") = false)
.def_static("get_inputs_name", &GridSample_Op::getInputsName)
.def_static("get_outputs_name", &GridSample_Op::getOutputsName)
.def_readonly_static("Type", &GridSample_Op::Type)
......
......@@ -40,6 +40,7 @@ void init_Concat(py::module&);
void init_ConstantOfShape(py::module&);
void init_Conv(py::module&);
void init_ConvDepthWise(py::module&);
void init_DepthToSpace(py::module&);
void init_Div(py::module&);
void init_Erf(py::module&);
void init_FC(py::module&);
......@@ -126,6 +127,7 @@ void init_Aidge(py::module& m) {
init_Conv(m);
init_ConvDepthWise(m);
init_ConstantOfShape(m);
init_DepthToSpace(m);
init_Div(m);
init_Erf(m);
init_FC(m);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment