Skip to content
Snippets Groups Projects

UPD: version 0.3.0 -> 0.3.1

Merged Maxence Naud requested to merge dev into main
37 files
+ 900
568
Compare changes
  • Side-by-side
  • Inline
Files
37
@@ -21,7 +21,7 @@ ACCURACIES = (95.4, 94.4) # (97.9, 97.7)
NB_BITS = 4
# --------------------------------------------------------------
# UTILS
# UTILS
# --------------------------------------------------------------
def propagate(model, scheduler, sample):
@@ -50,7 +50,7 @@ def compute_accuracy(model, samples, labels):
# --------------------------------------------------------------
class test_ptq(unittest.TestCase):
def setUp(self):
# load the samples / labels (numpy)
@@ -70,19 +70,20 @@ class test_ptq(unittest.TestCase):
def tearDown(self):
pass
def test_model(self):
Log.set_console_level(Level.Info)
# compute the base accuracy
accuracy = compute_accuracy(self.model, self.samples[0:NB_SAMPLES], self.labels)
self.assertAlmostEqual(accuracy * 100, ACCURACIES[0], msg='base accuracy does not meet the baseline !', delta=0.1)
def test_quant_model(self):
Log.set_console_level(Level.Info)
Log.set_console_level(Level.Debug)
# create the calibration dataset
tensors = []
for sample in self.samples[0:NB_SAMPLES]:
sample = prepare_sample(sample)
@@ -91,14 +92,13 @@ class test_ptq(unittest.TestCase):
# quantize the model
aidge_quantization.quantize_network(
self.model,
NB_BITS,
tensors,
clipping_mode=aidge_quantization.Clipping.MSE,
self.model,
NB_BITS,
tensors,
clipping_mode=aidge_quantization.Clipping.MSE,
no_quantization=False,
optimize_signs=True,
optimize_signs=True,
single_shift=False
)
Loading