Skip to content
Snippets Groups Projects
Commit a449cf97 authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Update aidge_export to new export_node + aidge_export -> serialize_to_cpp

parent ddd14df0
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!163Export refactor
Showing
with 253 additions and 196 deletions
...@@ -12,4 +12,4 @@ from aidge_core.aidge_core import * # import so generated by PyBind ...@@ -12,4 +12,4 @@ from aidge_core.aidge_core import * # import so generated by PyBind
import aidge_core.export_utils import aidge_core.export_utils
import aidge_core.utils import aidge_core.utils
import aidge_core.aidge_export_aidge from aidge_core.aidge_export_aidge import serialize_to_cpp
...@@ -5,4 +5,4 @@ FILE = Path(__file__).resolve() ...@@ -5,4 +5,4 @@ FILE = Path(__file__).resolve()
ROOT_EXPORT = FILE.parents[0] ROOT_EXPORT = FILE.parents[0]
from .operator_export import * from .operator_export import *
from .export import export from .export import serialize_to_cpp
...@@ -2,15 +2,14 @@ import aidge_core ...@@ -2,15 +2,14 @@ import aidge_core
import shutil import shutil
import os import os
from pathlib import Path from pathlib import Path
from .utils import supported_operators, OPERATORS_REGISTRY
from . import ROOT_EXPORT
import aidge_core.export_utils
from . import ROOT_EXPORT
from aidge_core.aidge_export_aidge.registry import ExportSerialize
from aidge_core.export_utils import ExportNode, generate_file from aidge_core.export_utils import ExportNode, generate_file
def serialize_to_cpp(export_folder: str,
def export(export_folder: str,
graph_view: aidge_core.GraphView, graph_view: aidge_core.GraphView,
enable_python_binding: bool = True, enable_python_binding: bool = True,
): ):
...@@ -59,7 +58,6 @@ def export(export_folder: str, ...@@ -59,7 +58,6 @@ def export(export_folder: str,
open_nodes = list(graph_view.get_input_nodes()) open_nodes = list(graph_view.get_input_nodes())
# List of Aidge nodes already explored # List of Aidge nodes already explored
closed_nodes = [] closed_nodes = []
while open_nodes: while open_nodes:
node = open_nodes.pop(0) node = open_nodes.pop(0)
if node in closed_nodes: if node in closed_nodes:
...@@ -81,21 +79,22 @@ def export(export_folder: str, ...@@ -81,21 +79,22 @@ def export(export_folder: str,
# Next nodes to treat are children of current node # Next nodes to treat are children of current node
open_nodes += list(node.get_children()) open_nodes += list(node.get_children())
if node.type() in supported_operators(): if not ExportSerialize.exportable(node):
set_operator.add(node.type()) #raise RuntimeError
op = OPERATORS_REGISTRY[node.type()](node) print(f"Node {node.name()} (of type [{node.type()}]) is not exportable !")
op = ExportSerialize.get_export_node(node)(node)
set_operator.add(node.type())
# TODO: list_configs and list_actions don't need to be passed by argument # TODO: list_configs and list_actions don't need to be passed by argument
# Export the configuration # Export the configuration
list_configs = op.export(export_folder_path, list_configs) list_configs += op.export(export_folder_path)
# Add forward kernel # Add forward kernel
list_actions = op.forward(list_actions) list_actions += op.forward()
else:
raise RuntimeError(f"Operator: {node.type()} is not supported")
closed_nodes.append(node) closed_nodes.append(node)
# Generate full dnn.cpp # Generate full dnn.cpp
aidge_core.generate_file( aidge_core.export_utils.generate_file(
export_folder_path / "src/dnn.cpp", export_folder_path / "src/dnn.cpp",
ROOT_EXPORT / "templates/dnn.jinja", ROOT_EXPORT / "templates/dnn.jinja",
headers=list_configs, headers=list_configs,
......
from aidge_core.aidge_export_aidge.utils import operator_register, parse_node_input from aidge_core.aidge_export_aidge.registry import ExportSerialize
from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core.aidge_export_aidge import ROOT_EXPORT
from aidge_core.export_utils import ExportNode, generate_file, generate_str from aidge_core.export_utils import ExportNodeCpp, operator_register
from pathlib import Path
@operator_register("Conv") @operator_register(ExportSerialize, "Conv")
class Conv(ExportNode): class Conv(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.config_template = str(
ROOT_EXPORT / "templates/attributes/conv.jinja")
self.forward_template = str(
ROOT_EXPORT /"templates/graph_ctor/conv.jinja")
self.include_list = []
self.kernels_to_copy = []
self.config_path = "include/parameters"
self.config_extension = "hpp"
@classmethod
def exportable(cls, node):
return True
def export(self, export_folder:Path, list_configs:list): # def export(self, export_folder:Path, list_configs:list):
include_path = f"attributes/{self.name}.hpp" # include_path = f"attributes/{self.name}.hpp"
filepath = export_folder / f"include/{include_path}" # filepath = export_folder / f"include/{include_path}"
generate_file( # generate_file(
filepath, # filepath,
ROOT_EXPORT / "templates/attributes/conv.jinja", # ROOT_EXPORT / "templates/attributes/conv.jinja",
name=self.name, # name=self.name,
**self.attributes # **self.attributes
) # )
list_configs.append(include_path) # list_configs.append(include_path)
return list_configs # return list_configs
def forward(self, list_actions:list): # def forward(self, list_actions:list):
list_actions.append(generate_str( # list_actions.append(generate_str(
ROOT_EXPORT /"templates/graph_ctor/conv.jinja", # ROOT_EXPORT /"templates/graph_ctor/conv.jinja",
name=self.name, # name=self.name,
inputs=parse_node_input(self.node.inputs()), # inputs=parse_node_input(self.node.inputs()),
**self.attributes # **self.attributes
)) # ))
return list_actions # return list_actions
from aidge_core.aidge_export_aidge.utils import operator_register,parse_node_input
from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core.aidge_export_aidge import ROOT_EXPORT
from aidge_core.export_utils import ExportNode, generate_file, generate_str from aidge_core.aidge_export_aidge.registry import ExportSerialize
from aidge_core.export_utils import ExportNodeCpp, operator_register
from pathlib import Path from pathlib import Path
@operator_register("FC") @operator_register(ExportSerialize, "FC")
class FC(ExportNode): class FC(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.config_template = str(
ROOT_EXPORT / "templates/attributes/fc.jinja")
def export(self, export_folder:Path, list_configs:list): self.forward_template = str(
ROOT_EXPORT / "templates/graph_ctor/fc.jinja")
self.include_list = []
include_path = f"attributes/{self.name}.hpp" self.kernels_to_copy = []
filepath = export_folder / f"include/{include_path}" self.config_path = "include/parameters"
self.config_extension = "hpp"
@classmethod
generate_file( def exportable(cls, node):
filepath, return True
ROOT_EXPORT / "templates/attributes/fc.jinja",
name=self.name, # def export(self, export_folder:Path, list_configs:list):
InChannels=self.inputs_dims[1][1],
OutChannels=self.operator.out_channels(),
**self.attributes # include_path = f"attributes/{self.name}.hpp"
) # filepath = export_folder / f"include/{include_path}"
list_configs.append(include_path)
return list_configs
# generate_file(
def forward(self, list_actions:list): # filepath,
list_actions.append(generate_str( # ROOT_EXPORT / "templates/attributes/fc.jinja",
ROOT_EXPORT / "templates/graph_ctor/fc.jinja", # name=self.name,
name=self.name, # InChannels=self.inputs_dims[1][1],
inputs=parse_node_input(self.node.inputs()), # OutChannels=self.operator.out_channels(),
**self.attributes # **self.attributes
)) # )
return list_actions # list_configs.append(include_path)
# return list_configs
# def forward(self, list_actions:list):
# list_actions.append(generate_str(
# ROOT_EXPORT / "templates/graph_ctor/fc.jinja",
# name=self.name,
# inputs=parse_node_input(self.node.inputs()),
# **self.attributes
# ))
# return list_actions
from aidge_core.aidge_export_aidge.utils import operator_register, parse_node_input from aidge_core.aidge_export_aidge.registry import ExportSerialize
from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core.aidge_export_aidge import ROOT_EXPORT
from aidge_core.export_utils import ExportNode, generate_file, generate_str from aidge_core.export_utils import ExportNodeCpp, operator_register
from pathlib import Path
@operator_register("MaxPooling") @operator_register(ExportSerialize,"MaxPooling")
class MaxPooling(ExportNode): class MaxPooling(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.config_template = str(
ROOT_EXPORT / "templates/attributes/maxpooling.jinja")
self.forward_template = str(
ROOT_EXPORT / "templates/graph_ctor/maxpooling.jinja")
self.include_list = []
self.kernels_to_copy = []
self.config_path = "include/parameters"
self.config_extension = "hpp"
@classmethod
def exportable(cls, node):
return True
# def export(self, export_folder:Path, list_configs:list):
# include_path = f"attributes/{self.name}.hpp"
# filepath = export_folder / f"include/{include_path}"
def export(self, export_folder:Path, list_configs:list): # generate_file(
include_path = f"attributes/{self.name}.hpp" # filepath,
filepath = export_folder / f"include/{include_path}" # ROOT_EXPORT / "templates/attributes/maxpooling.jinja",
# name=self.name,
# **self.attributes
# )
# list_configs.append(include_path)
# return list_configs
generate_file( # def forward(self, list_actions:list):
filepath, # list_actions.append(generate_str(
ROOT_EXPORT / "templates/attributes/maxpooling.jinja", # ROOT_EXPORT / "templates/graph_ctor/maxpooling.jinja",
name=self.name, # name=self.name,
**self.attributes # inputs=parse_node_input(self.node.inputs()),
) # **self.attributes
list_configs.append(include_path) # ))
return list_configs # return list_actions
def forward(self, list_actions:list):
list_actions.append(generate_str(
ROOT_EXPORT / "templates/graph_ctor/maxpooling.jinja",
name=self.name,
inputs=parse_node_input(self.node.inputs()),
**self.attributes
))
return list_actions
from aidge_core.aidge_export_aidge.utils import operator_register
from aidge_core.export_utils.data_conversion import aidge2c
from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core.aidge_export_aidge import ROOT_EXPORT
from aidge_core.export_utils import ExportNode, generate_file, generate_str from aidge_core.aidge_export_aidge.registry import ExportSerialize
from aidge_core.export_utils import ExportNodeCpp, operator_register
import numpy as np import numpy as np
from pathlib import Path
@operator_register("Producer") @operator_register(ExportSerialize, "Producer")
class Producer(ExportNode): class Producer(ExportNodeCpp):
""" """
If there is a standardization of the export operators If there is a standardization of the export operators
then this class should be just a inheritance of ProducerCPP then this class should be just a inheritance of ProducerCPP
...@@ -16,33 +13,47 @@ class Producer(ExportNode): ...@@ -16,33 +13,47 @@ class Producer(ExportNode):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
child, in_idx = self.node.output(0)[0] child, in_idx = self.node.output(0)[0]
self.tensor_name = f"{child.name()}_{in_idx}"
self.values = np.array(self.operator.get_output(0)) self.values = np.array(self.operator.get_output(0))
def export(self, export_folder:Path, list_configs:list): self.config_template = str(
assert(len(self.node.output(0)) == 1) ROOT_EXPORT / "templates/parameter.jinja")
self.forward_template = str(
include_path = f"parameters/{self.tensor_name}.hpp" ROOT_EXPORT / "templates/graph_ctor/producer.jinja")
filepath = export_folder / f"include/{include_path}" self.attributes["tensor_name"] = f"{child.name()}_{in_idx}"
self.attributes["values"] = str(self.operator.get_output(0))
aidge_tensor = self.operator.get_output(0) self.include_list = []
datatype = aidge2c(aidge_tensor.dtype()) self.kernels_to_copy = []
generate_file( self.config_path = "include/parameters"
filepath, self.config_extension = "hpp"
ROOT_EXPORT / "templates/parameter.jinja", @classmethod
dims = aidge_tensor.dims(), def exportable(cls, node):
data_t = datatype, return True
name = self.tensor_name,
values = str(aidge_tensor) # def export(self, export_folder:Path, list_configs:list):
) # assert(len(self.node.output(0)) == 1)
list_configs.append(include_path)
return list_configs # include_path = f"parameters/{self.tensor_name}.hpp"
# filepath = export_folder / f"include/{include_path}"
def forward(self, list_actions:list):
list_actions.append(generate_str( # aidge_tensor = self.operator.get_output(0)
ROOT_EXPORT / "templates/graph_ctor/producer.jinja", # datatype = aidge2c(aidge_tensor.dtype())
name=self.name, # generate_file(
tensor_name=self.tensor_name, # filepath,
**self.attributes # ROOT_EXPORT / "templates/parameter.jinja",
)) # dims = aidge_tensor.dims(),
return list_actions # data_t = datatype,
# name = self.tensor_name,
# values = str(aidge_tensor)
# )
# list_configs.append(include_path)
# return list_configs
# def forward(self, list_actions:list):
# list_actions.append(generate_str(
# ROOT_EXPORT / "templates/graph_ctor/producer.jinja",
# name=self.name,
# tensor_name=self.tensor_name,
# **self.attributes
# ))
# return list_actions
from aidge_core.aidge_export_aidge.utils import operator_register, parse_node_input from aidge_core.aidge_export_aidge.registry import ExportSerialize
from aidge_core.export_utils import ExportNode, generate_str
from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core.aidge_export_aidge import ROOT_EXPORT
from pathlib import Path from aidge_core.export_utils import ExportNodeCpp, operator_register
@operator_register("ReLU") @operator_register(ExportSerialize, "ReLU")
class ReLU(ExportNode): class ReLU(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.config_template = ""
self.forward_template = str(
ROOT_EXPORT / "templates/graph_ctor/relu.jinja")
self.include_list = []
self.kernels_to_copy = []
@classmethod
def exportable(cls, node):
return True
# def export(self, export_folder:Path, list_configs:list):
# return list_configs
def export(self, export_folder:Path, list_configs:list): # def forward(self, list_actions:list):
return list_configs # list_actions.append(generate_str(
# ROOT_EXPORT / "templates/graph_ctor/relu.jinja",
def forward(self, list_actions:list): # name=self.name,
list_actions.append(generate_str( # inputs=parse_node_input(self.node.inputs()),
ROOT_EXPORT / "templates/graph_ctor/relu.jinja", # **self.attributes
name=self.name, # ))
inputs=parse_node_input(self.node.inputs()), # return list_actions
**self.attributes
))
return list_actions
from aidge_core.aidge_export_aidge.utils import operator_register, parse_node_input from aidge_core.aidge_export_aidge.registry import ExportSerialize
from aidge_core.export_utils import ExportNode, generate_str
from aidge_core.aidge_export_aidge import ROOT_EXPORT from aidge_core.aidge_export_aidge import ROOT_EXPORT
from pathlib import Path from aidge_core.export_utils import ExportNodeCpp, operator_register
@operator_register("Sub") @operator_register(ExportSerialize, "Sub")
class Sub(ExportNode): class Sub(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.config_template = ""
self.forward_template = str(
ROOT_EXPORT / "templates/graph_ctor/sub.jinja")
self.include_list = []
self.kernels_to_copy = []
def export(self, export_folder:Path, list_configs:list): @classmethod
return list_configs def exportable(cls, node):
return True
def forward(self, list_actions:list): # def export(self, export_folder:Path, list_configs:list):
list_actions.append(generate_str( # return list_configs
ROOT_EXPORT / "templates/graph_ctor/sub.jinja",
name=self.name, # def forward(self, list_actions:list):
inputs=parse_node_input(self.node.inputs()), # list_actions.append(generate_str(
**self.attributes # ROOT_EXPORT / "templates/graph_ctor/sub.jinja",
)) # name=self.name,
return list_actions # inputs=parse_node_input(self.node.inputs()),
# **self.attributes
# ))
# return list_actions
from aidge_core.export_utils import ExportLib
from . import ROOT_EXPORT
class ExportSerialize(ExportLib):
name="export_serialize"
#ifndef EXPORT_ATTRIBUTES_{{name|upper}}_H #ifndef EXPORT_ATTRIBUTES_{{name|upper}}_H
#define EXPORT_ATTRIBUTES_{{name|upper}}_H #define EXPORT_ATTRIBUTES_{{name|upper}}_H
#define _{{name|upper}}_IN_CHANNELS {{InChannels}} #define _{{name|upper}}_IN_CHANNELS {{in_chan[0]}}
#define _{{name|upper}}_OUT_CHANNELS {{OutChannels}} #define _{{name|upper}}_OUT_CHANNELS {{out_chan[0]}}
{% for i in range(KernelDims|length) %} {% for i in range(kernel_dims|length) %}
#define _{{name|upper}}_KERNEL_{{i}} {{KernelDims[i]}} #define _{{name|upper}}_KERNEL_{{i}} {{kernel_dims[i]}}
{%- endfor %} {%- endfor %}
{% for i in range(StrideDims|length) %} {% for i in range(stride_dims|length) %}
#define _{{name|upper}}_STRIDE_{{i}} {{StrideDims[i]}} #define _{{name|upper}}_STRIDE_{{i}} {{stride_dims[i]}}
{%- endfor %} {%- endfor %}
{% for i in range(DilationDims|length) %} {% for i in range(dilation_dims|length) %}
#define _{{name|upper}}_DILATION_{{i}} {{DilationDims[i]}} #define _{{name|upper}}_DILATION_{{i}} {{dilation_dims[i]}}
{%- endfor %} {%- endfor %}
#endif /* EXPORT_ATTRIBUTES_{{name|upper}}_H */ #endif /* EXPORT_ATTRIBUTES_{{name|upper}}_H */
#ifndef EXPORT_ATTRIBUTES_{{name|upper}}_H #ifndef EXPORT_ATTRIBUTES_{{name|upper}}_H
#define EXPORT_ATTRIBUTES_{{name|upper}}_H #define EXPORT_ATTRIBUTES_{{name|upper}}_H
#define _{{name|upper}}_IN_CHANNELS {{InChannels}} #define _{{name|upper}}_IN_CHANNELS {{in_chan[0]}}
#define _{{name|upper}}_OUT_CHANNELS {{OutChannels}} #define _{{name|upper}}_OUT_CHANNELS {{out_chan[0]}}
#endif /* EXPORT_ATTRIBUTES_{{name|upper}}_H */ #endif /* EXPORT_ATTRIBUTES_{{name|upper}}_H */
#ifndef EXPORT_ATTRIBUTES_{{name|upper}}_H #ifndef EXPORT_ATTRIBUTES_{{name|upper}}_H
#define EXPORT_ATTRIBUTES_{{name|upper}}_H #define EXPORT_ATTRIBUTES_{{name|upper}}_H
{% for i in range(KernelDims|length) %} {% for i in range(kernel_dims|length) %}
#define _{{name|upper}}_KERNEL_{{i}} {{KernelDims[i]}} #define _{{name|upper}}_KERNEL_{{i}} {{kernel_dims[i]}}
{%- endfor %} {%- endfor %}
{% for i in range(StrideDims|length) %} {% for i in range(stride_dims|length) %}
#define _{{name|upper}}_STRIDE_{{i}} {{StrideDims[i]}} #define _{{name|upper}}_STRIDE_{{i}} {{stride_dims[i]}}
{%- endfor %} {%- endfor %}
#define _{{name|upper}}_CEIL_MODE {{CeilMode|int}} #define _{{name|upper}}_CEIL_MODE {{ceil_mode|int}}
#endif /* EXPORT_ATTRIBUTES_{{name|upper}}_H */ #endif /* EXPORT_ATTRIBUTES_{{name|upper}}_H */
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
/*** OPERATOR ATTRIBUTES & PARAMETERS ***/ /*** OPERATOR ATTRIBUTES & PARAMETERS ***/
{%- for header in headers %} {%- for header in headers %}
#include "{{ header }}" #include "{{ header | replace('include/', '') }}"
{%- endfor %} {%- endfor %}
/*** HEADER ***/ /*** HEADER ***/
......
{# NOTE: Trying a shorter notation like {%- for input in inputs if input[0] %} {# NOTE: Trying a shorter notation like {%- for input in inputs if input[0] %}
will mess up loop.index as the input set up at None will not increment ! #} will mess up loop.index as the input set up at None will not increment ! #}
{%- for input in inputs %} {%- for input_node, out_id in node.inputs() %}
{%- if input[0] %} {%- if input_node %}
{{input[0]}}->addChild({{name}}, {{input[1]}}, {{loop.index - 1}}); {# NOTE: loop.index begin at 1 #} {{input_node.name()}}->addChild({{name}}, {{out_id}}, {{loop.index - 1}}); {# NOTE: loop.index begin at 1 #}
{%- endif %} {%- endif %}
{%- endfor %} {%- endfor %}
...@@ -5,18 +5,18 @@ std::shared_ptr<Aidge::Node> {{name}} = ...@@ -5,18 +5,18 @@ std::shared_ptr<Aidge::Node> {{name}} =
_{{name|upper}}_IN_CHANNELS, _{{name|upper}}_IN_CHANNELS,
_{{name|upper}}_OUT_CHANNELS, _{{name|upper}}_OUT_CHANNELS,
{ {
{%- for i in range(KernelDims|length) -%} {%- for i in range(kernel_dims|length) -%}
_{{name|upper}}_KERNEL_{{i}}{%- if not loop.last %}, {% endif -%} _{{name|upper}}_KERNEL_{{i}}{%- if not loop.last %}, {% endif -%}
{%- endfor -%} {%- endfor -%}
}, },
"{{name}}", "{{name}}",
{ {
{%- for i in range(StrideDims|length) -%} {%- for i in range(stride_dims|length) -%}
_{{name|upper}}_STRIDE_{{i}} {%- if not loop.last %}, {% endif -%} _{{name|upper}}_STRIDE_{{i}} {%- if not loop.last %}, {% endif -%}
{%- endfor -%} {%- endfor -%}
}, },
{ {
{%- for i in range(DilationDims|length) -%} {%- for i in range(dilation_dims|length) -%}
_{{name|upper}}_DILATION_{{i}} {%- if not loop.last %}, {% endif -%} _{{name|upper}}_DILATION_{{i}} {%- if not loop.last %}, {% endif -%}
{%- endfor -%} {%- endfor -%}
} }
......
...@@ -3,13 +3,13 @@ ...@@ -3,13 +3,13 @@
std::shared_ptr<Aidge::Node> {{name}} = std::shared_ptr<Aidge::Node> {{name}} =
Aidge::MaxPooling( Aidge::MaxPooling(
{ {
{%- for i in range(KernelDims|length) -%} {%- for i in range(kernel_dims|length) -%}
_{{name|upper}}_KERNEL_{{i}}{%- if not loop.last %}, {% endif -%} _{{name|upper}}_KERNEL_{{i}}{%- if not loop.last %}, {% endif -%}
{%- endfor -%} {%- endfor -%}
}, },
"{{name}}", "{{name}}",
{ {
{%- for i in range(StrideDims|length) -%} {%- for i in range(stride_dims|length) -%}
_{{name|upper}}_STRIDE_{{i}} {%- if not loop.last %}, {% endif -%} _{{name|upper}}_STRIDE_{{i}} {%- if not loop.last %}, {% endif -%}
{%- endfor -%} {%- endfor -%}
}, },
......
...@@ -4,7 +4,7 @@ ...@@ -4,7 +4,7 @@
#include <aidge/data/Tensor.hpp> #include <aidge/data/Tensor.hpp>
#include <memory> #include <memory>
std::shared_ptr<Aidge::Tensor> {{name}} = std::make_shared<Aidge::Tensor>(Aidge::Array{{dims|length}}D<{{data_t}}, {{ dims|join(", ") }}> { std::shared_ptr<Aidge::Tensor> {{tensor_name}} = std::make_shared<Aidge::Tensor>(Aidge::Array{{out_dims[0]|length}}D<{{out_cdtype[0]}}, {{ out_dims[0]|join(", ") }}> {
{{ values }} {{ values }}
}); });
......
from .node_export import ExportNode, ExportNodeCpp from .node_export import ExportNode, ExportNodeCpp
from .code_generation import generate_file, generate_str, copy_file from .code_generation import generate_file, generate_str, copy_file
from .export_registry import ExportLib, operator_register from .export_registry import ExportLib, operator_register
from .scheduler_export import ExportScheduler from .scheduler_export import scheduler_export
...@@ -26,9 +26,6 @@ class ExportLib(): # Should be abstract ? ...@@ -26,9 +26,6 @@ class ExportLib(): # Should be abstract ?
_language: LANGUAGE = None _language: LANGUAGE = None
_compilo:str = None _compilo:str = None
def __init__(self) -> None: def __init__(self) -> None:
raise RuntimeError("ExportLib should not be instanciated") raise RuntimeError("ExportLib should not be instanciated")
@classmethod @classmethod
...@@ -63,7 +60,7 @@ class ExportLib(): # Should be abstract ? ...@@ -63,7 +60,7 @@ class ExportLib(): # Should be abstract ?
:rtype: ExportNode :rtype: ExportNode
""" """
if not cls.exportable(node): if not cls.exportable(node):
raise ValueError(f"Node {node.type()} is not exportable by ExportLib {cls._name} !") raise ValueError(f"Node {node.type()} is not exportable by ExportLib {cls.name} !")
if len(cls._export_node_registry[node.type()]) != 1: if len(cls._export_node_registry[node.type()]) != 1:
raise RuntimeError("ExportLib registry doesn't support when multiple export node are available yet ...") raise RuntimeError("ExportLib registry doesn't support when multiple export node are available yet ...")
else: else:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment