Skip to content
Snippets Groups Projects
Commit c64a6b94 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Change fuseMulAdd

parent ca69594f
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!170Change fuseMulAdd
...@@ -65,7 +65,7 @@ class test_recipes(unittest.TestCase): ...@@ -65,7 +65,7 @@ class test_recipes(unittest.TestCase):
graph_view.add(b1) graph_view.add(b1)
old_nodes = graph_view.get_nodes() old_nodes = graph_view.get_nodes()
aidge_core.fuse_mul_add(graph_view) aidge_core.matmul_to_fc(graph_view)
self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 2) self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 2)
self.assertTrue("MatMul0" not in [i.name() for i in graph_view.get_nodes()]) self.assertTrue("MatMul0" not in [i.name() for i in graph_view.get_nodes()])
......
...@@ -31,18 +31,14 @@ void constantFolding(std::shared_ptr<GraphView> graph); ...@@ -31,18 +31,14 @@ void constantFolding(std::shared_ptr<GraphView> graph);
* *
* @param nodes Strict set of Node to merge. * @param nodes Strict set of Node to merge.
*/ */
//void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); void matMulToFC(std::shared_ptr<Node> matmul, std::shared_ptr<Node> add = nullptr);
void fuseMulAdd(std::shared_ptr<MatchSolution> solution);
void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add);
/** /**
* @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node. * @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node.
* *
* @param graphView Graph view to use graph matching on, in order to apply transformations. * @param graphView Graph view to use graph matching on, in order to apply transformations.
*/ */
void fuseMulAdd(std::shared_ptr<GraphView> graphView); void matMulToFC(std::shared_ptr<GraphView> graphView);
/** /**
* @brief Remove a node type. * @brief Remove a node type.
......
...@@ -25,14 +25,14 @@ void init_Recipes(py::module &m) ...@@ -25,14 +25,14 @@ void init_Recipes(py::module &m)
{ {
m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter( m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter(
Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
:param graph_view: Graph view on which we want to apply the recipe :param graph_view: Graph view on which we want to apply the recipe
:type graph_view: :py:class:`aidge_core.GraphView` :type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter"); )mydelimiter");
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter( // m.def("matmul_to_fc", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(matMulToFC), py::arg("nodes"), R"mydelimiter(
// recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator. // recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// :param nodes: The MatMul and Add nodes to fuse. // :param nodes: The MatMul and Add nodes to fuse.
...@@ -84,13 +84,6 @@ void init_Recipes(py::module &m) ...@@ -84,13 +84,6 @@ void init_Recipes(py::module &m)
// :type nodes: list of :py:class:`aidge_core.Node` // :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter"); // )mydelimiter");
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// :param nodes: The MatMul and Add nodes to fuse.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter( m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter(
Recipe to remove a flatten operator. Recipe to remove a flatten operator.
......
...@@ -22,28 +22,27 @@ ...@@ -22,28 +22,27 @@
#include "aidge/operator/GenericOperator.hpp" #include "aidge/operator/GenericOperator.hpp"
#include "aidge/utils/ErrorHandling.hpp" #include "aidge/utils/ErrorHandling.hpp"
#include "aidge/operator/MatMul.hpp" #include "aidge/operator/MatMul.hpp"
#include "aidge/graph/Matching.hpp"
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
void Aidge::matMulToFC(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<Aidge::Node> addNode) {
void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<Aidge::Node> addNode) { //std::set<std::shared_ptr<Node>> nodes){
// Fuse Mulmat & Add into FC // Fuse Mulmat & Add into FC
// Inputs : old nodes (pointers on mul & add) // Inputs : old nodes (pointers on mul & add)
AIDGE_ASSERT((matmulNode->type() == "MatMul" && (addNode == nullptr || addNode->type() == "Add")), "Wrong type for the nodes to replace");
assert((matmulNode->type() == "MatMul" && addNode->type() == "Add") && "Wrong type for the nodes to replace");
// Step 1 : Create FC // Step 1 : Create FC
// Fetch the output dimension throught the bias size // Fetch the output dimension throught the bias size
std::shared_ptr<Node> bias = nullptr; std::shared_ptr<Node> bias = nullptr;
if (addNode->getParent(0) == matmulNode) { if (addNode) {
AIDGE_ASSERT(addNode->getParent(1), "No bias detected to produce the fuseMulAdd recipe."); if (addNode->getParent(0) == matmulNode) {
bias = addNode->getParent(1); AIDGE_ASSERT(addNode->getParent(1), "No bias detected to produce the matMulToFC recipe.");
} bias = addNode->getParent(1);
else if (addNode->getParent(1) == matmulNode) { }
AIDGE_ASSERT(addNode->getParent(0), "No bias detected to produce the fuseMulAdd recipe."); else if (addNode->getParent(1) == matmulNode) {
bias = addNode->getParent(0); AIDGE_ASSERT(addNode->getParent(0), "No bias detected to produce the matMulToFC recipe.");
bias = addNode->getParent(0);
}
} }
std::shared_ptr<Node> weight = nullptr; std::shared_ptr<Node> weight = nullptr;
...@@ -75,24 +74,9 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< ...@@ -75,24 +74,9 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
} }
AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator."); AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator.");
// TODO: find another way to get OutChannels for FC operator.
// This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output
DimSize_t outSize = 0;
AIDGE_ASSERT(addNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto& op = std::static_pointer_cast<OperatorTensor>(addNode->getOperator());
for (size_t i = 0; i < op->nbInputs(); i++)
{
const auto& inTensor = op->getInput(i);
if(inTensor->nbDims() > 0) {
outSize = inTensor->dims()[inTensor->nbDims()-1];
break;
}
}
AIDGE_ASSERT(outSize, "Could not get output number of channels for FC operator.");
// Instanciate FC // Instanciate FC
std::string fcName = matmulNode->name(); std::string fcName = matmulNode->name();
if (!addNode->name().empty()) { if (addNode && !addNode->name().empty()) {
fcName += "_" + addNode->name(); fcName += "_" + addNode->name();
} }
...@@ -105,7 +89,6 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< ...@@ -105,7 +89,6 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
bias->cloneSharedOperators()->addChild(fc, 0, 2); bias->cloneSharedOperators()->addChild(fc, 0, 2);
} }
// Step 3 : Update all graphviews that contains at least one node to replace // Step 3 : Update all graphviews that contains at least one node to replace
// Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output // Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output
// Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview // Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview
...@@ -115,33 +98,11 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr< ...@@ -115,33 +98,11 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
} }
void Aidge::matMulToFC(std::shared_ptr<Aidge::GraphView> graphView){
const auto matches = SinglePassGraphMatching(graphView).match("MatMul->Add#?");
void Aidge::fuseMulAdd(std::shared_ptr<Aidge::MatchSolution> solution){ for (const auto& match : matches) {
const auto it = match.anchors.find("Add");
assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n"); matMulToFC(match.startNode, (it != match.anchors.end()) ? it->second.at("#") : nullptr);
assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n");
for (const auto& matmulNode : solution->at("MatMul")) {
for (const auto& addNode : solution->at("Add")) {
fuseMulAdd(matmulNode,addNode);
}
}
}
void Aidge::fuseMulAdd(std::shared_ptr<Aidge::GraphView> graphView){
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("Add","getType($) =='Add'");
regex->setNodeKey("MatMul","getType($) =='MatMul'");
regex->addQuery("MatMul -> Add ;");
for (const auto& solution : regex->match(graphView)) {
fuseMulAdd(solution);
} }
} }
...@@ -189,7 +189,7 @@ TEST_CASE("GraphRegexUser") { ...@@ -189,7 +189,7 @@ TEST_CASE("GraphRegexUser") {
kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'"); kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'");
kitchenBook->setNodeKey("FC","getType($) =='FC'"); kitchenBook->setNodeKey("FC","getType($) =='FC'");
kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd)); //kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd));
kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten)); kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten));
kitchenBook->appliedRecipes(g); kitchenBook->appliedRecipes(g);
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
namespace Aidge { namespace Aidge {
TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") { TEST_CASE("[cpu/recipes] MatMulToFC", "[MatMulToFC][recipes]") {
// generate the original GraphView // generate the original GraphView
auto matmul0 = MatMul("matmul0"); auto matmul0 = MatMul("matmul0");
auto add0 = Add(2, "add0"); auto add0 = Add(2, "add0");
...@@ -60,7 +60,7 @@ TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") { ...@@ -60,7 +60,7 @@ TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") {
REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1))); REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1)));
// Transform GraphView inplace // Transform GraphView inplace
fuseMulAdd(g); matMulToFC(g);
// Check new GraphView // Check new GraphView
std::set<std::shared_ptr<Node>> newNodes = g->getNodes(); std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment