Skip to content
Snippets Groups Projects

upd according to core changes

Merged Maxence Naud requested to merge ui_parameters into main
2 files
+ 19
19
Compare changes
  • Side-by-side
  • Inline
Files
2
@@ -29,22 +29,22 @@ labels = np.load(gzip.GzipFile('assets/mnist_labels.npy.gz', "r"))
# --------------------------------------------------------------
# Create the Producer node
input_array = np.zeros(784).astype('float32')
input_array = np.zeros(784).astype('float32')
input_tensor = aidge_core.Tensor(input_array)
input_node = aidge_core.Producer(input_tensor, "X")
# Configuration for the inputs
input_node.get_operator().set_datatype(aidge_core.DataType.Float32)
input_node.get_operator().set_datatype(aidge_core.dtype.float32)
input_node.get_operator().set_backend("cpu")
# Link Producer to the Graph
input_node.add_child(aidge_model)
# Configuration for the model
aidge_model.set_datatype(aidge_core.DataType.Float32)
aidge_model.set_datatype(aidge_core.dtype.float32)
aidge_model.set_backend("cpu")
# Create the Scheduler
# Create the Scheduler
scheduler = aidge_core.SequentialScheduler(aidge_model)
# --------------------------------------------------------------
@@ -55,7 +55,7 @@ def propagate(model, scheduler, sample):
# Setup the input
input_tensor = aidge_core.Tensor(sample)
input_node.get_operator().set_output(0, input_tensor)
# Run the inference
# Run the inference
scheduler.forward(verbose=False)
# Gather the results
output_node = model.get_output_nodes().pop()
@@ -64,7 +64,7 @@ def propagate(model, scheduler, sample):
def bake_sample(sample):
sample = np.reshape(sample, (1, 1, 28, 28))
return sample.astype('float32')
return sample.astype('float32')
print('\n EXAMPLE INFERENCES :')
for i in range(10):
@@ -126,12 +126,12 @@ print('\n EXAMPLE QUANTIZED INFERENCES :')
for i in range(10):
input_array = bake_sample(samples[i])
output_array = propagate(aidge_model, scheduler, input_array)
print(labels[i] , ' -> ', np.round(output_array, 2))
print(labels[i] , ' -> ', np.round(output_array, 2))
# --------------------------------------------------------------
# COMPUTE THE MODEL ACCURACY
# --------------------------------------------------------------
accuracy = compute_accuracy(aidge_model, samples[0:NB_SAMPLES], labels)
print(f'\n QUANTIZED MODEL ACCURACY : {accuracy * 100:.3f}%')
Loading