FuseBatchNorm only works for 2D
The current `fuseBatchNorm` only works for the 2D case.
This is due to the GraphRegex, but also to the inner logic of the code:
```cpp
// TODO: Find a way to remove the template
// A feature map with 2 dimensions is assumed
const std::shared_ptr<BatchNorm_Op<2>> batchOp =
    std::static_pointer_cast<BatchNorm_Op<2>>(batchnormNode->getOperator());
```
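A `static_pointer_cast` performs no runtime check, so applying it to a 1D BatchNorm silently produces a pointer of the wrong type (undefined behaviour) instead of failing. The sketch below uses hypothetical stand-in types (`Operator`, `BatchNormOp<DIM>`, `asBatchNorm`), not the real Aidge classes, to show that a `dynamic_pointer_cast` at least detects the dimension mismatch:

```cpp
#include <memory>

// Hypothetical stand-ins for the operator hierarchy (not the real Aidge classes).
struct Operator {
    virtual ~Operator() = default;  // polymorphic base, required for dynamic_pointer_cast
};

template <unsigned DIM>
struct BatchNormOp : Operator {};

// Returns the operator viewed as BatchNormOp<DIM>, or nullptr when its dynamic
// type is a BatchNorm of another dimension. Unlike static_pointer_cast, which
// performs no check, dynamic_pointer_cast inspects the actual type at runtime.
template <unsigned DIM>
std::shared_ptr<BatchNormOp<DIM>> asBatchNorm(const std::shared_ptr<Operator>& op) {
    return std::dynamic_pointer_cast<BatchNormOp<DIM>>(op);
}
```

This only turns a silent error into a detectable one; it does not make the fusion work for 1D.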
I think that removing the templating on the number of dimensions for operators, as proposed in aidge_onnx#72 (closed), could solve this issue!
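One reason a non-templated approach seems feasible: the BatchNorm fold itself only acts per output channel, so it does not need to know how many spatial dimensions the kernel has. The sketch below assumes the standard folding formulas and uses hypothetical flat parameter structs (`ConvParams`, `BatchNormParams`, `foldBatchNorm`), not Aidge's actual tensors or attribute names:

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical, dimension-agnostic view of the parameters involved in the fusion.
// The kernel is stored flat as [outChannels][rest], where "rest" is
// inChannels * k (1D) or inChannels * kH * kW (2D).
struct ConvParams {
    std::vector<float> weight;  // size: outChannels * rest
    std::vector<float> bias;    // size: outChannels
};

struct BatchNormParams {
    std::vector<float> scale, shift, mean, var;  // size: outChannels each
    float epsilon = 1e-5f;
};

// Folds BatchNorm into the preceding convolution (standard formulas):
//   w'[c] = w[c] * scale[c] / sqrt(var[c] + eps)
//   b'[c] = (b[c] - mean[c]) * scale[c] / sqrt(var[c] + eps) + shift[c]
// Only the output channel index matters; the spatial layout of "rest" does not.
void foldBatchNorm(ConvParams& conv, const BatchNormParams& bn) {
    const std::size_t outChannels = conv.bias.size();
    assert(outChannels > 0 && conv.weight.size() % outChannels == 0);
    const std::size_t rest = conv.weight.size() / outChannels;
    for (std::size_t c = 0; c < outChannels; ++c) {
        const float factor = bn.scale[c] / std::sqrt(bn.var[c] + bn.epsilon);
        for (std::size_t i = 0; i < rest; ++i) {
            conv.weight[c * rest + i] *= factor;
        }
        conv.bias[c] = (conv.bias[c] - bn.mean[c]) * factor + bn.shift[c];
    }
}
```

The same function handles a 1D kernel of size 2 or a 2D kernel of size 1x2, since both flatten to two weights per output channel.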
Another solution would be to templatize `fuseBatchNorm` on the number of dimensions and match a different regex for each case:
```cpp
// 2D case
auto matches2D = SinglePassGraphMatching(graphView).match("(Conv2D|ConvDepthWise2D|PaddedConv2D|PaddedConvDepthWise2D)->BatchNorm2D");
for (const auto& match : matches2D) {
    auto rootNode = match.graph->rootNode();
    fuseBatchNorm<2>(rootNode, *rootNode->getChildren().begin());
}
// 1D case (distinct variable name: redeclaring `matches` in the same scope does not compile)
auto matches1D = SinglePassGraphMatching(graphView).match("(Conv1D|ConvDepthWise1D|PaddedConv1D|PaddedConvDepthWise1D)->BatchNorm1D");
for (const auto& match : matches1D) {
    auto rootNode = match.graph->rootNode();
    fuseBatchNorm<1>(rootNode, *rootNode->getChildren().begin());
}
```
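To avoid duplicating the match block for every dimension, the pattern could be derived from the template parameter. A minimal sketch, where `fuseBatchNormPattern` is a hypothetical helper and the operator type names are copied from the snippet above:

```cpp
#include <string>

// Hypothetical helper: builds the match pattern for a given number of spatial
// dimensions, e.g. DIM = 2 yields
// "(Conv2D|ConvDepthWise2D|PaddedConv2D|PaddedConvDepthWise2D)->BatchNorm2D".
template <unsigned DIM>
std::string fuseBatchNormPattern() {
    const std::string d = std::to_string(DIM) + "D";
    return "(Conv" + d + "|ConvDepthWise" + d + "|PaddedConv" + d +
           "|PaddedConvDepthWise" + d + ")->BatchNorm" + d;
}
```

A templated function could then pass `fuseBatchNormPattern<DIM>()` to `SinglePassGraphMatching(graphView).match(...)` and call `fuseBatchNorm<DIM>` in its loop, so supporting a new dimension would only require one more instantiation.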
Edited by Cyril Moineau