/********************************************************************************
 * Copyright (c) 2023 CEA-List
 *
 * This program and the accompanying materials are made available under the
 * terms of the Eclipse Public License 2.0 which is available at
 * http://www.eclipse.org/legal/epl-2.0.
 *
 * SPDX-License-Identifier: EPL-2.0
 *
 ********************************************************************************/

#include <cassert>
#include <cstddef>
#include <cstdio>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "aidge/operator/FC.hpp"
#include "aidge/recipies/Recipies.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/GenericOperator.hpp"
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"

using namespace Aidge;

void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
    // Fuse Mulmat & Add into FC
    // Inputs : old nodes (pointers on mul & add)
Cyril Moineau's avatar
Cyril Moineau committed
    assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
    // Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ?
Cyril Moineau's avatar
Cyril Moineau committed
    // Step 0 : Assert the nodes types are correct to be fused
    std::shared_ptr<Node> add;
    std::shared_ptr<Node> matmul;
    for (const auto& element : nodes) {
        assert((element->type() == "MatMul" || element->type() == "Add") && "Wrong type for the nodes to replace");
        if (element->type() == "MatMul"){
            matmul = element;
        }
        else if (element->type() == "Add") {
            add = element;
        }
    }

    // Step 1 : Create FC
    // Fetch the output dimension throught the bias size
    auto producer_add_bias = add->input(1);
    Tensor& bias_tensor = (producer_add_bias.first)->getOperator()->output(0);

Cyril Moineau's avatar
Cyril Moineau committed
    //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc");
    std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(bias_tensor.dims()[0], false));

    // Step 2 : Branch existing producers & create the others
    // link weights & bias
    if (matmul->getParent(1)==nullptr) {
        matmul->getParent(0)->addChild(fc, 0, 1);
        printf("MatMul out[1] == nullptr !\n");
Cyril Moineau's avatar
Cyril Moineau committed
    } else {
        printf("MatMul out[1] != nullptr !\n");
        if (matmul->getParent(0)!=nullptr)
            matmul->getParent(0)->addChild(fc, 0, 0);
        matmul->input(1).first->addChild(fc, 0, 1);
Cyril Moineau's avatar
Cyril Moineau committed
    }
    (producer_add_bias.first)->addChild(fc,0,2);


    // Step 3 : Update all graphviews that contains at least one node to replace
        // Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output
        // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview
        // Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory ?
    auto nodeToReplace = std::make_shared<GraphView>();
    nodeToReplace->add(nodes, false);
Cyril Moineau's avatar
Cyril Moineau committed
    nodeToReplace->replaceWith({fc});

void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){

    std::map<std::string,NodeRegex*> nodesRegex ;
    nodesRegex["MatMul"] = new NodeRegex("MatMul");
    nodesRegex["Add"] = new NodeRegex("Add");
    std::vector<std::string> seqRegex;
    seqRegex.push_back("MatMul -> Add;");
    GRegex GReg(nodesRegex, seqRegex);
    Match matches = GReg.match(graphView);
    std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
    for (size_t i = 0; i < matches.getNbMatch(); ++i) {
        fuseMulAdd(matchNodes[i]);
    }
}