Forked from
Eclipse Projects / aidge / aidge_core
2182 commits behind the upstream repository.
-
Maxence Naud authored
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
FuseMulAdd.cpp 3.74 KiB
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <set>
#include <cassert>
#include <memory>
#include <string>
#include "aidge/operator/FC.hpp"
#include "aidge/utils/Recipies.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graph/Node.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/utils/ErrorHandling.hpp"
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
using namespace Aidge;
void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// Fuse Mulmat & Add into FC
// Inputs : old nodes (pointers on mul & add)
assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
// Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ?
// Step 0 : Assert the nodes types are correct to be fused
std::shared_ptr<Node> add;
std::shared_ptr<Node> matmul;
for (const auto& element : nodes) {
assert((element->type() == "MatMul" || element->type() == "Add") && "Wrong type for the nodes to replace");
if (element->type() == "MatMul"){
matmul = element;
}
else if (element->type() == "Add") {
add = element;
}
}
// Step 1 : Create FC
// Fetch the output dimension throught the bias size
std::shared_ptr<Node> bias = (add->getParent(1)) ? add->getParent(1)->cloneSharedOperators() : nullptr;
if (!(matmul->getParent(1))) {
AIDGE_INTERNAL_ASSERT("No weight detected to produce the fuseMulAdd recipe.");
}
std::shared_ptr<Node> weight = matmul->getParent(1)->cloneSharedOperators();
DimSize_t outSize = weight->getOperator()->output(0).dims<2>()[1];
// Instanciate FC
//std::shared_ptr<Node> fc = FC(dim[0], false, "Fc");
std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(outSize, bias ? false : true));
// Step 2 : Branch existing producers & create the others
// link weights & bias
weight->addChild(fc, 0, 1);
if (bias) {
bias->addChild(fc, 0, 2);
}
// Step 3 : Update all graphviews that contains at least one node to replace
// Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output
// Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview
// Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory ?
// auto nodeToReplace = std::make_shared<GraphView>();
// nodeToReplace->add(nodes, false);
// nodeToReplace->replaceWith({fc});
auto newNodes = std::set<std::shared_ptr<Node>>({fc, weight, fc->getParent(2)});
GraphView::replace({matmul, add, add->getParent(1), matmul->getParent(1)}, newNodes);
}
void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){
std::map<std::string,NodeRegex*> nodesRegex ;
nodesRegex["MatMul"] = new NodeRegex("MatMul");
nodesRegex["Add"] = new NodeRegex("Add");
std::vector<std::string> seqRegex;
seqRegex.push_back("MatMul -> Add;");
GRegex GReg(nodesRegex, seqRegex);
Match matches = GReg.match(graphView);
std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
for (size_t i = 0; i < matches.getNbMatch(); ++i) {
fuseMulAdd(matchNodes[i]);
}
}