diff --git a/aidge_quantization/unit_tests/aidge_ptq.py b/aidge_quantization/unit_tests/aidge_ptq.py index 73b5553af3ef00ae6380bf7c735bfb24f85ed7b9..507a24b382fc915665a0fc82005420fd2af21dc5 100644 --- a/aidge_quantization/unit_tests/aidge_ptq.py +++ b/aidge_quantization/unit_tests/aidge_ptq.py @@ -29,22 +29,22 @@ labels = np.load(gzip.GzipFile('assets/mnist_labels.npy.gz', "r")) # -------------------------------------------------------------- # Create the Producer node -input_array = np.zeros(784).astype('float32') +input_array = np.zeros(784).astype('float32') input_tensor = aidge_core.Tensor(input_array) input_node = aidge_core.Producer(input_tensor, "X") # Configuration for the inputs -input_node.get_operator().set_datatype(aidge_core.DataType.Float32) +input_node.get_operator().set_datatype(aidge_core.dtype.float32) input_node.get_operator().set_backend("cpu") # Link Producer to the Graph input_node.add_child(aidge_model) # Configuration for the model -aidge_model.set_datatype(aidge_core.DataType.Float32) +aidge_model.set_datatype(aidge_core.dtype.float32) aidge_model.set_backend("cpu") -# Create the Scheduler +# Create the Scheduler scheduler = aidge_core.SequentialScheduler(aidge_model) # -------------------------------------------------------------- @@ -55,7 +55,7 @@ def propagate(model, scheduler, sample): # Setup the input input_tensor = aidge_core.Tensor(sample) input_node.get_operator().set_output(0, input_tensor) - # Run the inference + # Run the inference scheduler.forward(verbose=False) # Gather the results output_node = model.get_output_nodes().pop() @@ -64,7 +64,7 @@ def propagate(model, scheduler, sample): def bake_sample(sample): sample = np.reshape(sample, (1, 1, 28, 28)) - return sample.astype('float32') + return sample.astype('float32') print('\n EXAMPLE INFERENCES :') for i in range(10): @@ -126,12 +126,12 @@ print('\n EXAMPLE QUANTIZED INFERENCES :') for i in range(10): input_array = bake_sample(samples[i]) output_array = propagate(aidge_model, scheduler, input_array) - print(labels[i] , ' -> ', np.round(output_array, 2)) + print(labels[i] , ' -> ', np.round(output_array, 2)) # -------------------------------------------------------------- # COMPUTE THE MODEL ACCURACY # -------------------------------------------------------------- - + accuracy = compute_accuracy(aidge_model, samples[0:NB_SAMPLES], labels) print(f'\n QUANTIZED MODEL ACCURACY : {accuracy * 100:.3f}%') diff --git a/src/QuantPTQ.cpp b/src/QuantPTQ.cpp index 9930156a669d4e64306533759a7cd8c48a694715..ef28d72f4ffc037415bebd86c8ade16d8a4554a9 100644 --- a/src/QuantPTQ.cpp +++ b/src/QuantPTQ.cpp @@ -430,7 +430,7 @@ void normalizeParameters(std::shared_ptr<GraphView> graphView) std::shared_ptr<Node> scalingNode = getPreviousScalingNode(mergingNode); std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (scalingNode->getOperator()); - scalingOperator->getAttr<float>("scalingFactor") /= rescaling; + scalingOperator->scalingFactor() /= rescaling; accumulatedRatios[mergingNode->name()] /= rescaling; // optional ... } } @@ -493,7 +493,7 @@ std::map<std::string, float> computeRanges(std::shared_ptr<GraphView> graphView, valueRanges[nodeName] = sampleRanges[nodeName]; } } - } + } return valueRanges; } @@ -534,10 +534,10 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st float prevScalingFactor = scalingFactors[prevNode->name()]; // XXX HERE : valueRanges must contains all the scaling nodes !!! - float scalingFactor = valueRanges[node->name()]; + float scalingFactor = valueRanges[node->name()]; std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (node->getOperator()); - scalingOperator->getAttr<float>("scalingFactor") /= (scalingFactor / prevScalingFactor); + scalingOperator->scalingFactor() /= (scalingFactor / prevScalingFactor); scalingFactors[node->name()] = scalingFactor; @@ -584,7 +584,7 @@ void normalizeActivations(std::shared_ptr<GraphView> graphView, std::map<std::st //Log::info(" SCALING NODE : {} {}", scalingNode->type(), scalingNode->name()); std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (scalingNode->getOperator()); - scalingOperator->getAttr<float>("scalingFactor") *= rescaling; + scalingOperator->scalingFactor() *= rescaling; } } } @@ -620,8 +620,8 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ std::shared_ptr<Node> scalingNode = *(node->getChildren().begin()); std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (scalingNode->getOperator()); - scalingOperator->getAttr<float>("scalingFactor") /= signedMax; - scalingOperator->getAttr<std::size_t>("quantizedNbBits") = nbBits; + scalingOperator->scalingFactor() /= signedMax; + scalingOperator->quantizedNbBits() = nbBits; } } @@ -631,7 +631,7 @@ void quantizeNormalizedNetwork(std::shared_ptr<GraphView> graphView, std::uint8_ if (node->type() == "Scaling") { std::shared_ptr<Scaling_Op> scalingOperator = std::static_pointer_cast<Scaling_Op> (node->getOperator()); - scalingOperator->getAttr<std::size_t>("quantizedNbBits") = nbBits; // XXX HERE !!! + scalingOperator->quantizedNbBits() = nbBits; // XXX HERE !!! } } } @@ -703,7 +703,7 @@ float computeBestClipping(std::vector<int> histogram, std::uint8_t nbBits) int signedMax = (1 << (nbBits - 1)) - 1; // Compute the cumulative approximation error : - // At each iteration we test a clipping candidate and loop over + // At each iteration we test a clipping candidate and loop over // the histogram to accumulate the squared error std::vector<float> clippingErrors; @@ -716,7 +716,7 @@ float computeBestClipping(std::vector<int> histogram, std::uint8_t nbBits) { float value = (bin + 0.5) / nbBins; float scaling = signedMax / clipping; - float rounded = std::round(value * scaling) / scaling; + float rounded = std::round(value * scaling) / scaling; float clipped = std::min(clipping, rounded); float approxError = (clipped - value); @@ -736,7 +736,7 @@ float computeBestClipping(std::vector<int> histogram, std::uint8_t nbBits) bestClipping = it / static_cast<float> (nbIter); minClippingError = clippingErrors[it]; } - + return bestClipping; }