Skip to content
Snippets Groups Projects
Commit c3b0ad25 authored by vincent  lorrain's avatar vincent lorrain
Browse files

[recipies][GraphRegex] use GraphRegex in fuseMulAdd

parent 61d71c0c
No related branches found
No related tags found
1 merge request!48Refactor/recipies
Pipeline #34068 failed
......@@ -17,6 +17,8 @@
#include "aidge/graph/Node.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/graphRegex/matchFsm/MatchResult.hpp"
namespace Aidge{
......@@ -27,7 +29,12 @@ namespace Aidge{
*
* @param nodes Strict set of Node to merge.
*/
void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);
//void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);
void fuseMulAdd(std::shared_ptr<MatchSolution> solution);
void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add);
/**
* @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node.
*
......
......@@ -25,27 +25,17 @@
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
using namespace Aidge;
void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
void Aidge::fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add){//std::set<std::shared_ptr<Node>> nodes){
// Fuse Mulmat & Add into FC
// Inputs : old nodes (pointers on mul & add)
assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
// Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ?
assert((matmul->type() == "MatMul" && add->type() == "Add") && "Wrong type for the nodes to replace");
// Step 0 : Assert the nodes types are correct to be fused
std::shared_ptr<Node> add;
std::shared_ptr<Node> matmul;
for (const auto& element : nodes) {
assert((element->type() == "MatMul" || element->type() == "Add") && "Wrong type for the nodes to replace");
if (element->type() == "MatMul"){
matmul = element;
}
else if (element->type() == "Add") {
add = element;
}
}
// Step 1 : Create FC
// Fetch the output dimension throught the bias size
......@@ -78,17 +68,55 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
}
void Aidge::fuseMulAdd(std::shared_ptr<MatchSolution> solution){
assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n");
assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n");
for (const auto& matmul : solution->at("MatMul")) {
for (const auto& add : solution->at("Add")) {
fuseMulAdd(matmul,add);
}
}
}
void Aidge::fuseMulAdd(std::shared_ptr<GraphView> graphView){
std::map<std::string,NodeRegex*> nodesRegex ;
nodesRegex["MatMul"] = new NodeRegex("MatMul");
nodesRegex["Add"] = new NodeRegex("Add");
std::vector<std::string> seqRegex;
seqRegex.push_back("MatMul -> Add;");
GRegex GReg(nodesRegex, seqRegex);
Match matches = GReg.match(graphView);
std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
for (size_t i = 0; i < matches.getNbMatch(); ++i) {
fuseMulAdd(matchNodes[i]);
// std::map<std::string,NodeRegex*> nodesRegex ;
// nodesRegex["MatMul"] = new NodeRegex("MatMul");
// nodesRegex["Add"] = new NodeRegex("Add");
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("Add","getType($) =='Add'");
regex->setNodeKey("MatMul","getType($) =='MatMul'");
regex->addQuery("MatMul -> Add ;");
for (const auto& solution : regex->match(graphView)) {
fuseMulAdd(solution);
// // solution->at("MatMul");
// // solution->at("Add");
// assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n");
// assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n");
// for (const auto& matmul : solution->at("MatMul")) {
// for (const auto& add : solution->at("Add")) {
// fuseMulAdd(matmul,add);
// }
// }
}
// std::vector<std::string> seqRegex;
// seqRegex.push_back("MatMul -> Add;");
// GRegex GReg(nodesRegex, seqRegex);
// Match matches = GReg.match(graphView);
// std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
// for (size_t i = 0; i < matches.getNbMatch(); ++i) {
// fuseMulAdd(matchNodes[i]);
// }
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment