/******************************************************************************** * Copyright (c) 2023 CEA-List * * This program and the accompanying materials are made available under the * terms of the Eclipse Public License 2.0 which is available at * http://www.eclipse.org/legal/epl-2.0. * * SPDX-License-Identifier: EPL-2.0 * ********************************************************************************/ #include <set> #include <cassert> #include <memory> #include <string> #include "aidge/operator/FC.hpp" #include "aidge/recipies/Recipies.hpp" #include "aidge/graph/GraphView.hpp" #include "aidge/graph/Node.hpp" #include "aidge/operator/Producer.hpp" #include "aidge/operator/GenericOperator.hpp" #include "aidge/utils/ErrorHandling.hpp" #include "aidge/operator/MatMul.hpp" //Graph Regex #include "aidge/graphRegex/GraphRegex.hpp" void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<Aidge::Node> addNode) { //std::set<std::shared_ptr<Node>> nodes){ // Fuse Mulmat & Add into FC // Inputs : old nodes (pointers on mul & add) assert((matmulNode->type() == "MatMul" && addNode->type() == "Add") && "Wrong type for the nodes to replace"); // Step 1 : Create FC // Fetch the output dimension throught the bias size std::shared_ptr<Node> bias = (addNode->getParent(1)) ? addNode->getParent(1)->cloneSharedOperators() : nullptr; AIDGE_ASSERT(matmulNode->getParent(1), "No weight detected to produce the fuseMulAdd recipe."); std::shared_ptr<Node> weight = matmulNode->getParent(1)->cloneSharedOperators(); // TODO: find another way to get OutChannels for FC operator. // This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output DimSize_t outSize = 0; const auto& op = std::dynamic_pointer_cast<OperatorTensor>(addNode->getOperator()); for (size_t i = 0; i < op->nbInputs(); i++) { const auto& inTensor = op->getInput(i); if(inTensor->nbDims() > 0) { outSize = inTensor->dims()[inTensor->nbDims()-1]; break; } } AIDGE_ASSERT(outSize, "Couldnt get output number of channels for FC operator."); // Instanciate FC //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(outSize, bias ? false : true)); // Step 2 : Branch existing producers & create the others // link weights & bias weight->addChild(fc, 0, 1); if (bias) { bias->addChild(fc, 0, 2); } // Step 3 : Update all graphviews that contains at least one node to replace // Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview // Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory? auto newNodes = std::set<std::shared_ptr<Node>>({fc, weight, fc->getParent(2)}); GraphView::replace({matmulNode, addNode, addNode->getParent(1), matmulNode->getParent(1)}, newNodes); } void Aidge::fuseMulAdd(std::shared_ptr<Aidge::MatchSolution> solution){ assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n"); assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n"); for (const auto& matmulNode : solution->at("MatMul")) { for (const auto& addNode : solution->at("Add")) { fuseMulAdd(matmulNode,addNode); } } } void Aidge::fuseMulAdd(std::shared_ptr<Aidge::GraphView> graphView){ std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); regex->setNodeKey("Add","getType($) =='Add'"); regex->setNodeKey("MatMul","getType($) =='MatMul'"); regex->addQuery("MatMul -> Add ;"); for (const auto& solution : regex->match(graphView)) { fuseMulAdd(solution); } }