Skip to content
Snippets Groups Projects
Commit 15a0329f authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Update export cpp to the new ExportLib object.

parent 73fa43b6
No related branches found
No related tags found
3 merge requests!27v0.2.0,!22v0.4.0,!15Export refactor
...@@ -2,6 +2,7 @@ r""" ...@@ -2,6 +2,7 @@ r"""
Aidge Export for CPP standalone projects Aidge Export for CPP standalone projects
""" """
from .export_registry import ExportLibCpp
from .operators import * from .operators import *
from collections import defaultdict from collections import defaultdict
......
...@@ -13,6 +13,7 @@ from aidge_core.export_utils.code_generation import * ...@@ -13,6 +13,7 @@ from aidge_core.export_utils.code_generation import *
from aidge_export_cpp.utils import (ROOT, OPERATORS_REGISTRY, supported_operators) from aidge_export_cpp.utils import (ROOT, OPERATORS_REGISTRY, supported_operators)
from aidge_export_cpp.utils.converter import numpy_dtype2ctype from aidge_export_cpp.utils.converter import numpy_dtype2ctype
import aidge_export_cpp.operators import aidge_export_cpp.operators
from aidge_export_cpp import ExportLibCpp
from aidge_export_cpp.utils.generation import * from aidge_export_cpp.utils.generation import *
from aidge_export_cpp.memory import * from aidge_export_cpp.memory import *
...@@ -50,9 +51,8 @@ def export(export_folder_name, graphview, scheduler, mem_wrapping=False): ...@@ -50,9 +51,8 @@ def export(export_folder_name, graphview, scheduler, mem_wrapping=False):
list_forward_nodes = scheduler.get_static_scheduling() list_forward_nodes = scheduler.get_static_scheduling()
for node in list_forward_nodes: for node in list_forward_nodes:
if node.type() in supported_operators(): if ExportLibCpp.exportable(node):
op = OPERATORS_REGISTRY[node.type()](node) op = ExportLibCpp.get_export_node(node)
# For configuration files # For configuration files
list_configs = op.export(dnn_folder, list_configs) list_configs = op.export(dnn_folder, list_configs)
......
...@@ -3,39 +3,39 @@ import shutil ...@@ -3,39 +3,39 @@ import shutil
import numpy as np import numpy as np
from pathlib import Path from pathlib import Path
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from aidge_core.export_utils import ExportNode, ExportNodeCpp, operator_register, generate_str, generate_file
from aidge_core import ExportNode, ExportNodeCpp from aidge_export_cpp.utils import ROOT
from aidge_core.export_utils.code_generation import *
from aidge_export_cpp.utils import ROOT, operator_register
from aidge_export_cpp.utils.converter import numpy_dtype2ctype from aidge_export_cpp.utils.converter import numpy_dtype2ctype
from aidge_export_cpp.utils.generation import * from aidge_export_cpp.utils.generation import *
from aidge_export_cpp import ExportLibCpp
############################################## ##############################################
################### Utils #################### ################### Utils ####################
############################################## ##############################################
def get_node_parents(node):
parents = []
for parent in node.get_parents():
if parent.type() != "Producer":
parents.append(parent)
return parents
def get_producer_parents(node): # def get_node_parents(node):
parents = [] # parents = []
for parent in node.get_parents(): # for parent in node.get_parents():
if parent.type() == "Producer": # if parent.type() != "Producer":
parents.append(parent) # parents.append(parent)
return parents # return parents
# def get_producer_parents(node):
# parents = []
# for parent in node.get_parents():
# if parent.type() == "Producer":
# parents.append(parent)
# return parents
############################################## ##############################################
############## Export functions ############## ############## Export functions ##############
############################################## ##############################################
def export_params(name:str, def export_params(name: str,
array: np.ndarray, array: np.ndarray,
filepath:str): filepath: str):
# Get directory name of the file # Get directory name of the file
dirname = os.path.dirname(filepath) dirname = os.path.dirname(filepath)
...@@ -47,9 +47,9 @@ def export_params(name:str, ...@@ -47,9 +47,9 @@ def export_params(name:str,
generate_file( generate_file(
filepath, filepath,
str(ROOT / "templates" / "data" / "parameters.jinja"), str(ROOT / "templates" / "data" / "parameters.jinja"),
name = name, name=name,
data_t = numpy_dtype2ctype(array.dtype), data_t=numpy_dtype2ctype(array.dtype),
values = array.tolist() values=array.tolist()
) )
...@@ -58,17 +58,17 @@ def export_params(name:str, ...@@ -58,17 +58,17 @@ def export_params(name:str,
############################################## ##############################################
@operator_register("Producer") @operator_register(ExportLibCpp, "Producer")
class ProducerCPP(ExportNode): class ProducerCPP(ExportNode):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.values = np.array(self.operator.get_output(0)) self.values = np.array(self.operator.get_output(0))
if len(self.values.shape) == 4: # Note: export in HWC if len(self.values.shape) == 4: # Note: export in HWC
self.values = np.transpose(self.values, (0, 2, 3, 1)) self.values = np.transpose(self.values, (0, 2, 3, 1))
def export(self, export_folder:Path, list_configs:list): def export(self, export_folder: Path, list_configs: list):
list_configs.append(f"parameters/{self.attributes['name']}.h") list_configs.append(f"parameters/{self.attributes['name']}.h")
export_params( export_params(
...@@ -78,26 +78,36 @@ class ProducerCPP(ExportNode): ...@@ -78,26 +78,36 @@ class ProducerCPP(ExportNode):
return list_configs return list_configs
def forward(self, list_actions:list): def forward(self, list_actions: list):
# A Producer does nothing during forward # A Producer does nothing during forward
return list_actions return list_actions
@classmethod
def exportable(cls, node):
return True # TODO add check i/o NCHW
@operator_register(ExportLibCpp, "ReLU")
@operator_register("ReLU")
class ReLUCPP(ExportNodeCpp): class ReLUCPP(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.attributes["activation"] = "Rectifier" self.attributes["activation"] = "Rectifier"
self.attributes["rescaling"] = "NoScaling" self.attributes["rescaling"] = "NoScaling"
self.config_template = str(ROOT / "templates" / "configuration" / "activation_config.jinja") self.config_template = str(
self.forward_template = str(ROOT / "templates" / "kernel_forward" / "activation_forward.jinja") ROOT / "templates" / "configuration" / "activation_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "activation_forward.jinja")
self.include_list = [] self.include_list = []
self.kernels_to_copy = [ self.kernels_to_copy = [
str(ROOT / "kernels" / "activation.hpp"), str(ROOT / "kernels" / "activation.hpp"),
] ]
@staticmethod
def compatible(cls, node):
@operator_register("Conv") for idx, parent_node_in_id in node.inputs():
parent_node, _ = parent_node_in_id
@classmethod
def exportable(cls, node):
return True # TODO add check i/o NCHW
@operator_register(ExportLibCpp, "Conv")
class ConvCPP(ExportNodeCpp): class ConvCPP(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
...@@ -106,84 +116,116 @@ class ConvCPP(ExportNodeCpp): ...@@ -106,84 +116,116 @@ class ConvCPP(ExportNodeCpp):
self.attributes["padding"] = [0, 0] self.attributes["padding"] = [0, 0]
self.attributes["activation"] = "Linear" self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling" self.attributes["rescaling"] = "NoScaling"
self.config_template = str(ROOT / "templates" / "configuration" / "convolution_config.jinja") self.config_template = str(
self.forward_template = str(ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja") ROOT / "templates" / "configuration" / "convolution_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja")
self.include_list = [] self.include_list = []
self.kernels_to_copy = [ self.kernels_to_copy = [
str(ROOT / "kernels" / "convolution.hpp"), str(ROOT / "kernels" / "convolution.hpp"),
str(ROOT / "kernels" / "macs.hpp"), str(ROOT / "kernels" / "macs.hpp"),
str(ROOT / "kernels" / "activation.hpp"), str(ROOT / "kernels" / "activation.hpp"),
] ]
@classmethod
def exportable(cls, node):
return True # TODO add check i/o NCHW
@operator_register("PaddedConv") @operator_register(ExportLibCpp, "PaddedConv")
class PaddedConvCPP(ExportNodeCpp): class PaddedConvCPP(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
# TODO find a way to retrive attr for meta op # TODO find a way to retrive attr for meta op
for n in self.operator.get_micro_graph().get_nodes(): for n in self.operator.get_micro_graph().get_nodes():
if n.type() == "Pad": if n.type() == "Pad":
self.attributes["padding"] = n.get_operator().attr.begin_end_borders self.attributes["padding"] = n.get_operator(
).attr.begin_end_borders
if n.type() == "Conv": if n.type() == "Conv":
self.attributes["kernel_dims"] = n.get_operator().attr.kernel_dims self.attributes["kernel_dims"] = n.get_operator(
self.attributes["stride_dims"] = n.get_operator().attr.stride_dims ).attr.kernel_dims
self.attributes["dilation_dims"] = n.get_operator().attr.dilation_dims self.attributes["stride_dims"] = n.get_operator(
).attr.stride_dims
self.attributes["dilation_dims"] = n.get_operator(
).attr.dilation_dims
self.attributes["activation"] = "Linear" self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling" self.attributes["rescaling"] = "NoScaling"
self.config_template = str(ROOT / "templates" / "configuration" / "convolution_config.jinja") self.config_template = str(
self.forward_template = str(ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja") ROOT / "templates" / "configuration" / "convolution_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "convolution_forward.jinja")
self.include_list = [] self.include_list = []
self.kernels_to_copy = [ self.kernels_to_copy = [
str(ROOT / "kernels" / "convolution.hpp"), str(ROOT / "kernels" / "convolution.hpp"),
str(ROOT / "kernels" / "macs.hpp"), str(ROOT / "kernels" / "macs.hpp"),
str(ROOT / "kernels" / "activation.hpp"), str(ROOT / "kernels" / "activation.hpp"),
] ]
@operator_register("Add") @classmethod
def exportable(cls, node):
return True # TODO add check i/o NCHW
@operator_register(ExportLibCpp, "Add")
class AddCPP(ExportNodeCpp): class AddCPP(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.attributes["elemwise_op"] = "Add" self.attributes["elemwise_op"] = "Add"
self.attributes["activation"] = "Linear" self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling" self.attributes["rescaling"] = "NoScaling"
self.config_template = str(ROOT / "templates" / "configuration" / "elemwise_config.jinja") self.config_template = str(
self.forward_template = str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja") ROOT / "templates" / "configuration" / "elemwise_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja")
self.include_list = [] self.include_list = []
self.kernels_to_copy = [ self.kernels_to_copy = [
str(ROOT / "kernels" / "elemwise.hpp"), str(ROOT / "kernels" / "elemwise.hpp"),
str(ROOT / "kernels" / "activation.hpp"), str(ROOT / "kernels" / "activation.hpp"),
] ]
@classmethod
def exportable(cls, node):
return True # TODO add check i/o NCHW
@operator_register("Sub") @operator_register(ExportLibCpp, "Sub")
class SubCPP(ExportNodeCpp): class SubCPP(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.attributes["elemwise_op"] = "Sub" self.attributes["elemwise_op"] = "Sub"
self.attributes["activation"] = "Linear" self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling" self.attributes["rescaling"] = "NoScaling"
self.config_template = str(ROOT / "templates" / "configuration" / "elemwise_config.jinja") self.config_template = str(
self.forward_template = str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja") ROOT / "templates" / "configuration" / "elemwise_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja")
self.include_list = [] self.include_list = []
self.kernels_to_copy = [ self.kernels_to_copy = [
str(ROOT / "kernels" / "elemwise.hpp"), str(ROOT / "kernels" / "elemwise.hpp"),
str(ROOT / "kernels" / "activation.hpp"), str(ROOT / "kernels" / "activation.hpp"),
] ]
@classmethod
def exportable(cls, node):
return True # TODO add check i/o NCHW
@operator_register("Mul")
@operator_register(ExportLibCpp, "Mul")
class MulCPP(ExportNodeCpp): class MulCPP(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.attributes["elemwise_op"] = "Mul" self.attributes["elemwise_op"] = "Mul"
self.attributes["activation"] = "Linear" self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling" self.attributes["rescaling"] = "NoScaling"
self.config_template = str(ROOT / "templates" / "configuration" / "elemwise_config.jinja") self.config_template = str(
self.forward_template = str(ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja") ROOT / "templates" / "configuration" / "elemwise_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "elemwise_forward.jinja")
self.include_list = [] self.include_list = []
self.kernels_to_copy = [ self.kernels_to_copy = [
str(ROOT / "kernels" / "elemwise.hpp"), str(ROOT / "kernels" / "elemwise.hpp"),
str(ROOT / "kernels" / "activation.hpp"), str(ROOT / "kernels" / "activation.hpp"),
] ]
@classmethod
def exportable(cls, node):
return True # TODO add check i/o NCHW
@operator_register("MaxPooling") @operator_register(ExportLibCpp, "MaxPooling")
class MaxPoolCPP(ExportNodeCpp): class MaxPoolCPP(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
...@@ -194,37 +236,50 @@ class MaxPoolCPP(ExportNodeCpp): ...@@ -194,37 +236,50 @@ class MaxPoolCPP(ExportNodeCpp):
self.attributes["pool_type"] = "Max" self.attributes["pool_type"] = "Max"
self.attributes["activation"] = "Linear" self.attributes["activation"] = "Linear"
self.config_template = str(ROOT / "templates" / "configuration" / "pooling_config.jinja") self.config_template = str(
self.forward_template = str(ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja") ROOT / "templates" / "configuration" / "pooling_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja")
self.include_list = [] self.include_list = []
self.kernels_to_copy = [ self.kernels_to_copy = [
str(ROOT / "kernels" / "pooling.hpp"), str(ROOT / "kernels" / "pooling.hpp"),
str(ROOT / "kernels" / "activation.hpp"), str(ROOT / "kernels" / "activation.hpp"),
] ]
@classmethod
def exportable(cls, node):
return True # TODO add check i/o NCHW
@operator_register("PaddedMaxPooling") @operator_register(ExportLibCpp, "PaddedMaxPooling")
class PaddedMaxPoolCPP(ExportNodeCpp): class PaddedMaxPoolCPP(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
for n in self.operator.get_micro_graph().get_nodes(): for n in self.operator.get_micro_graph().get_nodes():
if n.type() == "Pad": if n.type() == "Pad":
self.attributes["padding"] = n.get_operator().attr.begin_end_borders self.attributes["padding"] = n.get_operator(
).attr.begin_end_borders
if n.type() == "MaxPooling": if n.type() == "MaxPooling":
self.attributes["kernel_dims"] = n.get_operator().attr.kernel_dims self.attributes["kernel_dims"] = n.get_operator(
self.attributes["stride_dims"] = n.get_operator().attr.stride_dims ).attr.kernel_dims
self.attributes["stride_dims"] = n.get_operator(
).attr.stride_dims
self.attributes["pool_type"] = "Max" self.attributes["pool_type"] = "Max"
self.attributes["activation"] = "Linear" self.attributes["activation"] = "Linear"
self.config_template = str(ROOT / "templates" / "configuration" / "pooling_config.jinja") self.config_template = str(
self.forward_template = str(ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja") ROOT / "templates" / "configuration" / "pooling_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja")
self.include_list = [] self.include_list = []
self.kernels_to_copy = [ self.kernels_to_copy = [
str(ROOT / "kernels" / "pooling.hpp"), str(ROOT / "kernels" / "pooling.hpp"),
str(ROOT / "kernels" / "activation.hpp"), str(ROOT / "kernels" / "activation.hpp"),
] ]
@classmethod
def exportable(cls, node):
return True # TODO add check i/o NCHW
@operator_register("GlobalAveragePooling") @operator_register(ExportLibCpp, "GlobalAveragePooling")
class GlobalAveragePoolCPP(ExportNodeCpp): class GlobalAveragePoolCPP(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
...@@ -240,26 +295,35 @@ class GlobalAveragePoolCPP(ExportNodeCpp): ...@@ -240,26 +295,35 @@ class GlobalAveragePoolCPP(ExportNodeCpp):
self.attributes["pool_type"] = "Average" self.attributes["pool_type"] = "Average"
self.attributes["activation"] = "Linear" self.attributes["activation"] = "Linear"
self.config_template = str(ROOT / "templates" / "configuration" / "pooling_config.jinja") self.config_template = str(
self.forward_template = str(ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja") ROOT / "templates" / "configuration" / "pooling_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "pooling_forward.jinja")
self.include_list = [] self.include_list = []
self.kernels_to_copy = [ self.kernels_to_copy = [
str(ROOT / "kernels" / "pooling.hpp"), str(ROOT / "kernels" / "pooling.hpp"),
str(ROOT / "kernels" / "activation.hpp"), str(ROOT / "kernels" / "activation.hpp"),
] ]
@classmethod
def exportable(cls, node):
return True # TODO add check i/o NCHW
@operator_register(ExportLibCpp, "FC")
@operator_register("FC")
class FcCPP(ExportNodeCpp): class FcCPP(ExportNodeCpp):
def __init__(self, node): def __init__(self, node):
super().__init__(node) super().__init__(node)
self.attributes["activation"] = "Linear" self.attributes["activation"] = "Linear"
self.attributes["rescaling"] = "NoScaling" self.attributes["rescaling"] = "NoScaling"
self.config_template = str(ROOT / "templates" / "configuration" / "fullyconnected_config.jinja") self.config_template = str(
self.forward_template = str(ROOT / "templates" / "kernel_forward" / "fullyconnected_forward.jinja") ROOT / "templates" / "configuration" / "fullyconnected_config.jinja")
self.forward_template = str(
ROOT / "templates" / "kernel_forward" / "fullyconnected_forward.jinja")
self.include_list = [] self.include_list = []
self.kernels_to_copy = [ self.kernels_to_copy = [
str(ROOT / "kernels" / "fullyconnected.hpp"), str(ROOT / "kernels" / "fullyconnected.hpp"),
str(ROOT / "kernels" / "macs.hpp"), str(ROOT / "kernels" / "macs.hpp"),
str(ROOT / "kernels" / "activation.hpp"), str(ROOT / "kernels" / "activation.hpp"),
] ]
@classmethod
def exportable(cls, node):
return True # TODO add check i/o NCHW
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment