
[core] FuseBatchnorm does not fuse all the BatchNorms

Context

I'm attempting to quantize a MobileNetV2 from the ONNX zoo.

In order to perform post-training quantization (PTQ), the model must first be prepared: all the BatchNorm layers must be removed (more precisely, fused into the preceding Conv layers).
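For reference, this fusion boils down to rescaling the convolution weights and bias with the BatchNorm statistics. Here is a minimal NumPy sketch of the expected folding (an illustration only, not the Aidge implementation; the fold_batchnorm helper and its weight layout are hypothetical):

import numpy as np

def fold_batchnorm(W, b, gamma, beta, mean, var, eps=1e-5):
    # W: (out_ch, in_ch, kH, kW) conv weights, b: (out_ch,) conv bias
    scale = gamma / np.sqrt(var + eps)        # per-output-channel scale
    W_fused = W * scale[:, None, None, None]  # rescale each output filter
    b_fused = (b - mean) * scale + beta       # shift the bias accordingly
    return W_fused, b_fused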

Issue

The problem is that while most of the BatchNorm layers are removed, some of them remain in the model.

It appears that the remaining BatchNorms are the ones that have multiple child nodes (see the diagnostic sketch after the reproduction snippet below).

I suspect that the issue comes from the replace function, which I already had some trouble with when used on nodes that have forking outputs.

This issue might reveal a much larger problem, as the replace function is used in many situations.

Code

To reproduce the problem, one can download the model here:

https://github.com/onnx/models/blob/main/validated/vision/classification/mobilenet/model/mobilenetv2-7.onnx

and use the following snippet:

import aidge_core
import aidge_onnx

# Load the model
aidge_model = aidge_onnx.load_onnx("mobilenetv2-7.onnx", verbose=False)

# Prepare the model
aidge_core.remove_flatten(aidge_model)
aidge_core.fuse_batchnorm(aidge_model) # <- HERE

# Count the nb of remaining batchnorms
nb_batchnorms = sum((n.type() == 'BatchNorm') for n in aidge_model.get_nodes())
print(' NB BATCHNORMS : ', nb_batchnorms)
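To back the observation about multiple child nodes, a small diagnostic can be run right after fuse_batchnorm. Note that name() and get_children() are assumed to be available on the Node bindings (only get_nodes() and type() appear in the snippet above), so the exact calls may need adjusting:

# List the BatchNorm nodes that survived the fusion, together with their
# number of child nodes (assumes name() and get_children() exist on Node)
for n in aidge_model.get_nodes():
    if n.type() == 'BatchNorm':
        children = n.get_children()
        print(f"Remaining BatchNorm '{n.name()}' has {len(children)} child node(s)")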

EDIT: It appears that this has nothing to do with the replace function; it is in fact related to the GraphRegex ...
