Skip to content
Snippets Groups Projects

[Upd] standardization of some files

Merged Maxence Naud requested to merge small-standard-upd into dev
16 files
+ 259
176
Compare changes
  • Side-by-side
  • Inline
Files
16
@@ -21,7 +21,7 @@ ACCURACIES = (95.4, 94.4) # (97.9, 97.7)
@@ -21,7 +21,7 @@ ACCURACIES = (95.4, 94.4) # (97.9, 97.7)
NB_BITS = 4
NB_BITS = 4
# --------------------------------------------------------------
# --------------------------------------------------------------
# UTILS
# UTILS
# --------------------------------------------------------------
# --------------------------------------------------------------
def propagate(model, scheduler, sample):
def propagate(model, scheduler, sample):
@@ -50,7 +50,7 @@ def compute_accuracy(model, samples, labels):
@@ -50,7 +50,7 @@ def compute_accuracy(model, samples, labels):
# --------------------------------------------------------------
# --------------------------------------------------------------
class test_ptq(unittest.TestCase):
class test_ptq(unittest.TestCase):
def setUp(self):
def setUp(self):
# load the samples / labels (numpy)
# load the samples / labels (numpy)
@@ -70,19 +70,20 @@ class test_ptq(unittest.TestCase):
@@ -70,19 +70,20 @@ class test_ptq(unittest.TestCase):
def tearDown(self):
def tearDown(self):
pass
pass
def test_model(self):
def test_model(self):
Log.set_console_level(Level.Info)
Log.set_console_level(Level.Info)
# compute the base accuracy
# compute the base accuracy
accuracy = compute_accuracy(self.model, self.samples[0:NB_SAMPLES], self.labels)
accuracy = compute_accuracy(self.model, self.samples[0:NB_SAMPLES], self.labels)
self.assertAlmostEqual(accuracy * 100, ACCURACIES[0], msg='base accuracy does not meet the baseline !', delta=0.1)
self.assertAlmostEqual(accuracy * 100, ACCURACIES[0], msg='base accuracy does not meet the baseline !', delta=0.1)
def test_quant_model(self):
def test_quant_model(self):
Log.set_console_level(Level.Info)
Log.set_console_level(Level.Debug)
# create the calibration dataset
# create the calibration dataset
 
tensors = []
tensors = []
for sample in self.samples[0:NB_SAMPLES]:
for sample in self.samples[0:NB_SAMPLES]:
sample = prepare_sample(sample)
sample = prepare_sample(sample)
@@ -91,14 +92,13 @@ class test_ptq(unittest.TestCase):
@@ -91,14 +92,13 @@ class test_ptq(unittest.TestCase):
# quantize the model
# quantize the model
aidge_quantization.quantize_network(
aidge_quantization.quantize_network(
self.model,
self.model,
NB_BITS,
NB_BITS,
tensors,
tensors,
clipping_mode=aidge_quantization.Clipping.MSE,
clipping_mode=aidge_quantization.Clipping.MSE,
no_quantization=False,
no_quantization=False,
optimize_signs=True,
optimize_signs=True,
single_shift=False
single_shift=False
)
)
Loading