Skip to content
Snippets Groups Projects
Commit d12fb1bd authored by Cyril Moineau's avatar Cyril Moineau
Browse files

Use aidge2c to convert datatype instead of aidge_datatype2ctype.

parent 105ac6e4
No related branches found
No related tags found
2 merge requests!17v0.1.0,!12v0.4.0
......@@ -9,8 +9,10 @@ from typing import Tuple, List, Union, Dict
import aidge_core
from aidge_core import ExportNode
from aidge_core.export_utils.code_generation import *
from aidge_core.export_utils.data_conversion import aidge2c
from aidge_export_arm_cortexm.utils import ROOT, operator_register
from aidge_export_arm_cortexm.utils.converter import numpy_dtype2ctype, aidge_datatype2dataformat, aidge_datatype2ctype
from aidge_export_arm_cortexm.utils.converter import numpy_dtype2ctype, aidge_datatype2dataformat, aidge2c
from aidge_export_arm_cortexm.utils.generation import *
##############################################
......@@ -210,7 +212,7 @@ class ReLU_ARMCortexM(ExportNode):
self.board = board
self.library = library
self.dataformat = aidge_datatype2dataformat(node.get_operator().get_output(0).dtype())
self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype())
self.datatype = aidge2c(node.get_operator().get_output(0).dtype())
def export(self, export_folder:Path, list_configs:list):
......@@ -263,7 +265,7 @@ class Conv_ARMCortexM(ExportNode):
self.board = board
self.library = library
self.dataformat = aidge_datatype2dataformat(node.get_operator().get_output(0).dtype())
self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype())
self.datatype = aidge2c(node.get_operator().get_output(0).dtype())
self.scaling = Scaling()("no_scaling")
self.activation = "Linear"
......@@ -386,7 +388,7 @@ class PaddedConv_ARMCortexM(Conv_ARMCortexM):
self.board = board
self.library = library
self.dataformat = aidge_datatype2dataformat(node.get_operator().get_output(0).dtype())
self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype())
self.datatype = aidge2c(node.get_operator().get_output(0).dtype())
self.scaling = Scaling()("no_scaling")
self.activation = "Linear"
......@@ -420,13 +422,13 @@ class ConvReluScaling_ARMCortexM(Conv_ARMCortexM):
self.activation = "Rectifier"
# Should do this line but there is a bug while changing the datatype of generic operator...
# self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype())
# self.datatype = aidge2c(node.get_operator().get_output(0).dtype())
# Do this instead
if self.operator.get_attr("quantizedNbBits") == 8:
if self.operator.get_attr("isOutputUnsigned"):
self.datatype = aidge_datatype2ctype(aidge_core.DataType.UInt8)
self.datatype = aidge2c(aidge_core.DataType.UInt8)
else:
self.datatype = aidge_datatype2ctype(aidge_core.DataType.Int8)
self.datatype = aidge2c(aidge_core.DataType.Int8)
# Impose Single Shift (perhaps change it to have a more modular system)
self.scaling = Scaling(self.operator.get_attr("scalingFactor"),
......@@ -440,7 +442,7 @@ class Pooling_ARMCortexM(ExportNode):
self.board = board
self.library = library
self.dataformat = aidge_datatype2dataformat(node.get_operator().get_output(0).dtype())
self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype())
self.datatype = aidge2c(node.get_operator().get_output(0).dtype())
self.pool_type = "None"
self.activation = "Linear"
......@@ -564,7 +566,7 @@ class FC_ARMCortexM(ExportNode):
self.board = board
self.library = library
self.dataformat = aidge_datatype2dataformat(node.get_operator().get_output(0).dtype())
self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype())
self.datatype = aidge2c(node.get_operator().get_output(0).dtype())
self.scaling = Scaling()("no_scaling")
self.activation = "Linear"
......@@ -675,13 +677,13 @@ class FCScaling_ARMCortexM(FC_ARMCortexM):
super(FC_ARMCortexM, self).__init__(node, board, library)
# Should do this line but there is a bug while changing the datatype of generic operator...
# self.datatype = aidge_datatype2ctype(node.get_operator().get_output(0).dtype())
# self.datatype = aidge2c(node.get_operator().get_output(0).dtype())
# Do this instead
if self.operator.get_attr("quantizedNbBits") == 8:
if self.operator.get_attr("isOutputUnsigned"):
self.datatype = aidge_datatype2ctype(aidge_core.DataType.UInt8)
self.datatype = aidge2c(aidge_core.DataType.UInt8)
else:
self.datatype = aidge_datatype2ctype(aidge_core.DataType.Int8)
self.datatype = aidge2c(aidge_core.DataType.Int8)
# Impose Single Shift (perhaps change it to have a more modular system)
self.scaling = Scaling(self.operator.get_attr("scalingFactor"),
......
......@@ -19,7 +19,7 @@ def numpy_dtype2ctype(dtype):
# Add more dtype mappings as needed
else:
raise ValueError(f"Unsupported {dtype} dtype")
def aidge_datatype2ctype(datatype):
if datatype == aidge_core.DataType.Int8:
......@@ -37,7 +37,7 @@ def aidge_datatype2ctype(datatype):
# Add more dtype mappings as needed
else:
raise ValueError(f"Unsupported {datatype} aidge datatype")
def aidge_datatype2dataformat(datatype):
if datatype == aidge_core.DataType.Int8:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment