diff --git a/aidge_export_arm_cortexm/operators.py b/aidge_export_arm_cortexm/operators.py index b0f4f60b68d7ca414c50b821551e7664d05ca27a..eb4603252c17f676b21e96b0d4fe05f9824e683d 100644 --- a/aidge_export_arm_cortexm/operators.py +++ b/aidge_export_arm_cortexm/operators.py @@ -9,8 +9,10 @@ from typing import Tuple, List, Union, Dict import aidge_core from aidge_core import ExportNode from aidge_core.export_utils.code_generation import * +from aidge_core.export_utils.data_conversion import aidge2c + from aidge_export_arm_cortexm.utils import ROOT, operator_register -from aidge_export_arm_cortexm.utils.converter import numpy_dtype2ctype, aidge_datatype2dataformat, aidge_datatype2ctype +from aidge_export_arm_cortexm.utils.converter import numpy_dtype2ctype, aidge_datatype2dataformat, aidge2c from aidge_export_arm_cortexm.utils.generation import * ############################################## @@ -210,7 +212,7 @@ class ReLU_ARMCortexM(ExportNode): self.board = board self.library = library self.dataformat = aidge_datatype2dataformat(node.get_operator().get_output(0).dtype()) - self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype()) + self.datatype = aidge2c(node.get_operator().get_output(0).dtype()) def export(self, export_folder:Path, list_configs:list): @@ -263,7 +265,7 @@ class Conv_ARMCortexM(ExportNode): self.board = board self.library = library self.dataformat = aidge_datatype2dataformat(node.get_operator().get_output(0).dtype()) - self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype()) + self.datatype = aidge2c(node.get_operator().get_output(0).dtype()) self.scaling = Scaling()("no_scaling") self.activation = "Linear" @@ -386,7 +388,7 @@ class PaddedConv_ARMCortexM(Conv_ARMCortexM): self.board = board self.library = library self.dataformat = aidge_datatype2dataformat(node.get_operator().get_output(0).dtype()) - self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype()) + self.datatype = aidge2c(node.get_operator().get_output(0).dtype()) self.scaling = Scaling()("no_scaling") self.activation = "Linear" @@ -420,13 +422,13 @@ class ConvReluScaling_ARMCortexM(Conv_ARMCortexM): self.activation = "Rectifier" # Should do this line but there is a bug while changing the datatype of generic operator... - # self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype()) + # self.datatype = aidge2c(node.get_operator().get_output(0).dtype()) # Do this instead if self.operator.get_attr("quantizedNbBits") == 8: if self.operator.get_attr("isOutputUnsigned"): - self.datatype = aidge_datatype2ctype(aidge_core.DataType.UInt8) + self.datatype = aidge2c(aidge_core.DataType.UInt8) else: - self.datatype = aidge_datatype2ctype(aidge_core.DataType.Int8) + self.datatype = aidge2c(aidge_core.DataType.Int8) # Impose Single Shift (perhaps change it to have a more modular system) self.scaling = Scaling(self.operator.get_attr("scalingFactor"), @@ -440,7 +442,7 @@ class Pooling_ARMCortexM(ExportNode): self.board = board self.library = library self.dataformat = aidge_datatype2dataformat(node.get_operator().get_output(0).dtype()) - self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype()) + self.datatype = aidge2c(node.get_operator().get_output(0).dtype()) self.pool_type = "None" self.activation = "Linear" @@ -564,7 +566,7 @@ class FC_ARMCortexM(ExportNode): self.board = board self.library = library self.dataformat = aidge_datatype2dataformat(node.get_operator().get_output(0).dtype()) - self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype()) + self.datatype = aidge2c(node.get_operator().get_output(0).dtype()) self.scaling = Scaling()("no_scaling") self.activation = "Linear" @@ -675,13 +677,13 @@ class FCScaling_ARMCortexM(FC_ARMCortexM): super(FC_ARMCortexM, self).__init__(node, board, library) # Should do this line but there is a bug while changing the datatype of generic operator... - # self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype()) + # self.datatype = aidge2c(node.get_operator().get_output(0).dtype()) # Do this instead if self.operator.get_attr("quantizedNbBits") == 8: if self.operator.get_attr("isOutputUnsigned"): - self.datatype = aidge_datatype2ctype(aidge_core.DataType.UInt8) + self.datatype = aidge2c(aidge_core.DataType.UInt8) else: - self.datatype = aidge_datatype2ctype(aidge_core.DataType.Int8) + self.datatype = aidge2c(aidge_core.DataType.Int8) # Impose Single Shift (perhaps change it to have a more modular system) self.scaling = Scaling(self.operator.get_attr("scalingFactor"), diff --git a/aidge_export_arm_cortexm/utils/converter.py b/aidge_export_arm_cortexm/utils/converter.py index 426aa69d782417c699f790568fd5881b15ecdb1a..08b8599fd5c8f77e9a862401a87b7251773d9ba1 100644 --- a/aidge_export_arm_cortexm/utils/converter.py +++ b/aidge_export_arm_cortexm/utils/converter.py @@ -19,7 +19,7 @@ def numpy_dtype2ctype(dtype): # Add more dtype mappings as needed else: raise ValueError(f"Unsupported {dtype} dtype") - + def aidge_datatype2ctype(datatype): if datatype == aidge_core.DataType.Int8: @@ -37,7 +37,7 @@ def aidge_datatype2ctype(datatype): # Add more dtype mappings as needed else: raise ValueError(f"Unsupported {datatype} aidge datatype") - + def aidge_datatype2dataformat(datatype): if datatype == aidge_core.DataType.Int8: