
[core] [recipes] FuseMulAdd when there is only MatMul

Problem description

Currently, the FuseMulAdd recipe fuses a MatMul followed by an Add into an FC layer, but a MatMul on its own is never replaced by an FC layer.

For example, a Linear layer created in PyTorch with bias=False is exported to ONNX without an Add node, so the MatMul is never replaced by an FC.
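A bias-free FC computes the same thing as a single MatMul, which is why replacing the lone MatMul is safe. The minimal pure-Python sketch below illustrates this. It assumes the FC follows the PyTorch `Linear` convention, y = x·Wᵀ + b with W of shape (out_features, in_features). The helpers `transpose`, `matmul` and `fc` are hypothetical and are not Aidge APIs. Passing `bias=None` stands in for an FC created with noBias=True.

```python
def transpose(m):
    """Return the transpose of a matrix stored as a list of rows."""
    return [list(col) for col in zip(*m)]


def matmul(a, b):
    """Plain matrix product: a is (n, k), b is (k, m), result is (n, m)."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def fc(x, weight, bias=None):
    """Hypothetical FC: y = x @ weight^T (+ bias), with weight shaped (out, in).

    bias=None models an FC created with noBias=True.
    """
    y = matmul(x, transpose(weight))
    if bias is not None:
        y = [[v + b for v, b in zip(row, bias)] for row in y]
    return y


x = [[1.0, 2.0, 3.0]]            # batch of 1, 3 input features
w = [[0.5, -1.0, 2.0],           # 2 output features x 3 input features
     [1.0, 0.0, -0.5]]

# A bias-free FC and a MatMul against the transposed weight agree.
print(fc(x, w))                  # [[4.5, -0.5]]
print(matmul(x, transpose(w)))   # [[4.5, -0.5]]
```

The only thing the MatMul-to-FC replacement has to get right is the weight layout. The arithmetic is otherwise identical.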

This is problematic because it causes other recipes, such as removeFlatten, to fail.
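The failure can be illustrated with a toy matcher that works on a linear chain of operator types. This is a hedged sketch, not Aidge's GraphRegex. It assumes, without verifying against the Aidge source, that removeFlatten only fires on a `Flatten -> FC` sequence. The helper `find_pattern` and the op-type chains are hypothetical, chosen to resemble the tail of the reproducer below.

```python
def find_pattern(ops, pattern):
    """Return start indices where `pattern` occurs as consecutive op types in `ops`."""
    n = len(pattern)
    return [i for i in range(len(ops) - n + 1) if ops[i:i + n] == pattern]


# Assumed pattern for removeFlatten (hypothetical, for illustration only).
REMOVE_FLATTEN = ["Flatten", "FC"]

# Tail of a network whose Linear layers have no bias: each Linear stays
# a lone MatMul because there is no Add for FuseMulAdd to fuse it with.
without_fc = ["Conv", "ReLU", "Flatten", "MatMul", "ReLU", "MatMul"]

# Same tail once the lone MatMul nodes have been replaced by FC.
with_fc = ["Conv", "ReLU", "Flatten", "FC", "ReLU", "FC"]

print(find_pattern(without_fc, REMOVE_FLATTEN))  # []
print(find_pattern(with_fc, REMOVE_FLATTEN))     # [2]
```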

Reproducible example code

Try importing the following network into Aidge:

import torch

class TorchLeNet(torch.nn.Module):

    def __init__(self):
        super(TorchLeNet, self).__init__()
        c1 = torch.nn.Conv2d(1, 6, 5, bias=False)
        c2 = torch.nn.Conv2d(6, 16, 5, bias=False)
        c3 = torch.nn.Conv2d(16, 120, 5, bias=False)
        l1 = torch.nn.Linear(120, 84, bias=False)
        l2 = torch.nn.Linear(84, 10, bias=False)

        torch.nn.init.uniform_(c1.weight)
        torch.nn.init.uniform_(c2.weight)
        torch.nn.init.uniform_(c3.weight)
        torch.nn.init.uniform_(l1.weight)
        torch.nn.init.uniform_(l2.weight)

        self.layer = torch.nn.Sequential(
            c1,
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(6),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            c2,
            torch.nn.ReLU(),
            torch.nn.BatchNorm2d(16),
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),
            c3,
            torch.nn.ReLU(),
            torch.nn.Flatten(),
            l1,
            torch.nn.ReLU(),
            l2,
        )
        self.sequence = torch.nn.Sequential(
            self.layer,
        )

    def forward(self, x):
        x = self.sequence(x)
        return x

Proposed solution

Add a case for a MatMul that is not followed by an Add, and replace that MatMul with an FC created with noBias=True.

Proposed code change: update lines 81-96 of https://gitlab.eclipse.org/eclipse/aidge/aidge_core/-/blob/main/src/recipies/FuseMulAdd.cpp?ref_type=heads#L81-96 and create a function replaceMatMulByFC that performs the replacement:

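    // Existing query: fuse a MatMul followed by an Add into an FC with bias.
    regex->addQuery("MatMul -> Add ;", static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd));
    // New query: replace a lone MatMul by an FC with noBias=true.
    regex->addQuery("MatMul;", static_cast<void(*)(std::shared_ptr<MatchSolution>)>(replaceMatMulByFC));
    regex->appliedRecipes(graphView);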
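The intended effect of the two queries can be sketched on the same kind of linear op-type chain. `MatMul -> Add` pairs are fused first, and only the MatMul nodes left over become bias-free FC nodes. This is a minimal illustration under stated assumptions. The function `apply_recipes` and the `("FC", no_bias)` tuple representation are hypothetical. The sketch does not model how Aidge's GraphRegex resolves overlapping matches between the two queries.

```python
def apply_recipes(ops):
    """Rewrite a linear chain of op types, mimicking the two proposed queries.

    - "MatMul -> Add" becomes ("FC", False): an FC that keeps its bias.
    - A remaining lone "MatMul" becomes ("FC", True): an FC with no_bias=True.
    """
    out = []
    i = 0
    while i < len(ops):
        if ops[i] == "MatMul" and i + 1 < len(ops) and ops[i + 1] == "Add":
            out.append(("FC", False))   # fuse MatMul + Add (existing behavior)
            i += 2
        elif ops[i] == "MatMul":
            out.append(("FC", True))    # new case: lone MatMul, no bias
            i += 1
        else:
            out.append(ops[i])          # any other op is left untouched
            i += 1
    return out


chain = ["Flatten", "MatMul", "ReLU", "MatMul", "Add"]
print(apply_recipes(chain))
# ['Flatten', ('FC', True), 'ReLU', ('FC', False)]
```

The sketch checks for the MatMul -> Add pair before falling back to the lone-MatMul rule. This prevents the new `MatUl;`-style case from claiming a MatMul that should have been fused with its Add. Whether the real queries need an explicit guard for this depends on how GraphRegex orders and deduplicates matches.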
Edited by Maxence Naud