Skip to content
Snippets Groups Projects
Commit 0e48b653 authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Merge branch 'matmul_to_fc' into 'dev'

Change fuseMulAdd

See merge request !170
parents ca69594f 7d6be902
No related branches found
No related tags found
2 merge requests!212Version 0.3.0,!170Change fuseMulAdd
Pipeline #53623 passed
......@@ -65,7 +65,7 @@ class test_recipes(unittest.TestCase):
graph_view.add(b1)
old_nodes = graph_view.get_nodes()
aidge_core.fuse_mul_add(graph_view)
aidge_core.matmul_to_fc(graph_view)
self.assertTrue(len(graph_view.get_nodes()) == len(old_nodes) - 2)
self.assertTrue("MatMul0" not in [i.name() for i in graph_view.get_nodes()])
......
......@@ -31,18 +31,14 @@ void constantFolding(std::shared_ptr<GraphView> graph);
*
* @param nodes Strict set of Node to merge.
*/
//void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);
void fuseMulAdd(std::shared_ptr<MatchSolution> solution);
void fuseMulAdd(std::shared_ptr<Node> matmul,std::shared_ptr<Node> add);
void matMulToFC(std::shared_ptr<Node> matmul, std::shared_ptr<Node> add = nullptr);
/**
* @brief Merge ``MatMul`` and :cpp:function:`Aidge::Add` Node into a :cpp:function:`Aidge::FC` Node.
*
* @param graphView Graph view to use graph matching on, in order to apply transformations.
*/
void fuseMulAdd(std::shared_ptr<GraphView> graphView);
void matMulToFC(std::shared_ptr<GraphView> graphView);
/**
* @brief Remove a node type.
......
......@@ -25,14 +25,14 @@ void init_Recipes(py::module &m)
{
m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter(
m.def("matmul_to_fc", static_cast<void(*)(std::shared_ptr<GraphView>)>(matMulToFC), py::arg("graph_view"), R"mydelimiter(
Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
:param graph_view: Graph view on which we want to apply the recipe
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// m.def("matmul_to_fc", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(matMulToFC), py::arg("nodes"), R"mydelimiter(
// recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// :param nodes: The MatMul and Add nodes to fuse.
......@@ -84,13 +84,6 @@ void init_Recipes(py::module &m)
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
// m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
// Recipe to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
// :param nodes: The MatMul and Add nodes to fuse.
// :type nodes: list of :py:class:`aidge_core.Node`
// )mydelimiter");
m.def("fuse_batchnorm", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseBatchNorm), py::arg("graph_view"), R"mydelimiter(
Recipe to remove a flatten operator.
......
......@@ -22,28 +22,29 @@
#include "aidge/operator/GenericOperator.hpp"
#include "aidge/utils/ErrorHandling.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/graph/Matching.hpp"
//Graph Regex
#include "aidge/graphRegex/GraphRegex.hpp"
void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<Aidge::Node> addNode) { //std::set<std::shared_ptr<Node>> nodes){
void Aidge::matMulToFC(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<Aidge::Node> addNode) {
// Fuse Mulmat & Add into FC
// Inputs : old nodes (pointers on mul & add)
assert((matmulNode->type() == "MatMul" && addNode->type() == "Add") && "Wrong type for the nodes to replace");
AIDGE_ASSERT((matmulNode->type() == "MatMul" && (addNode == nullptr || addNode->type() == "Add")),
"Wrong type for the nodes to replace: {} and {}",
matmulNode->type(), (addNode) ? addNode->type() : "nullptr");
// Step 1 : Create FC
// Fetch the output dimension throught the bias size
std::shared_ptr<Node> bias = nullptr;
if (addNode->getParent(0) == matmulNode) {
AIDGE_ASSERT(addNode->getParent(1), "No bias detected to produce the fuseMulAdd recipe.");
bias = addNode->getParent(1);
}
else if (addNode->getParent(1) == matmulNode) {
AIDGE_ASSERT(addNode->getParent(0), "No bias detected to produce the fuseMulAdd recipe.");
bias = addNode->getParent(0);
if (addNode) {
if (addNode->getParent(0) == matmulNode) {
AIDGE_ASSERT(addNode->getParent(1), "No bias detected to produce the matMulToFC recipe.");
bias = addNode->getParent(1);
}
else if (addNode->getParent(1) == matmulNode) {
AIDGE_ASSERT(addNode->getParent(0), "No bias detected to produce the matMulToFC recipe.");
bias = addNode->getParent(0);
}
}
std::shared_ptr<Node> weight = nullptr;
......@@ -75,24 +76,9 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
}
AIDGE_ASSERT(weight != nullptr, "Could not deduce weight input for MatMul operator.");
// TODO: find another way to get OutChannels for FC operator.
// This poor fix supposes that one of Add inputs is a const and has the same outChannels as the output
DimSize_t outSize = 0;
AIDGE_ASSERT(addNode->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
const auto& op = std::static_pointer_cast<OperatorTensor>(addNode->getOperator());
for (size_t i = 0; i < op->nbInputs(); i++)
{
const auto& inTensor = op->getInput(i);
if(inTensor->nbDims() > 0) {
outSize = inTensor->dims()[inTensor->nbDims()-1];
break;
}
}
AIDGE_ASSERT(outSize, "Could not get output number of channels for FC operator.");
// Instanciate FC
std::string fcName = matmulNode->name();
if (!addNode->name().empty()) {
if (addNode && !addNode->name().empty()) {
fcName += "_" + addNode->name();
}
......@@ -105,43 +91,26 @@ void Aidge::fuseMulAdd(std::shared_ptr<Aidge::Node> matmulNode, std::shared_ptr<
bias->cloneSharedOperators()->addChild(fc, 0, 2);
}
// Step 3 : Update all graphviews that contains at least one node to replace
// Case 1 : If all nodes are in a graph view : delete old nodes & branch input & output
// Case 2 : If not all nodes are in a graph view : only delete the nodes from the graphview
// Maybe create a central mechanism to update automatically all graph views rather than each node have graphview presence memory?
auto newNodes = std::set<std::shared_ptr<Node>>({fc, fc->getParent(1), fc->getParent(2)});
GraphView::replace({matmulNode, addNode, bias, weight}, newNodes);
}
void Aidge::fuseMulAdd(std::shared_ptr<Aidge::MatchSolution> solution){
assert(solution->at("MatMul").size() == 1 && "Wrong number of nodes MatMul to replace\n");
assert(solution->at("Add").size() == 1 && "Wrong number of nodes Add to replace\n");
for (const auto& matmulNode : solution->at("MatMul")) {
for (const auto& addNode : solution->at("Add")) {
fuseMulAdd(matmulNode,addNode);
}
if (addNode) {
auto newNodes = std::set<std::shared_ptr<Node>>({fc, fc->getParent(1), fc->getParent(2)});
GraphView::replace({matmulNode, addNode, bias, weight}, newNodes);
}
else {
auto newNodes = std::set<std::shared_ptr<Node>>({fc, fc->getParent(1)});
GraphView::replace({matmulNode, weight}, newNodes);
}
}
void Aidge::fuseMulAdd(std::shared_ptr<Aidge::GraphView> graphView){
std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
regex->setNodeKey("Add","getType($) =='Add'");
regex->setNodeKey("MatMul","getType($) =='MatMul'");
regex->addQuery("MatMul -> Add ;");
for (const auto& solution : regex->match(graphView)) {
fuseMulAdd(solution);
}
void Aidge::matMulToFC(std::shared_ptr<Aidge::GraphView> graphView){
const auto matches = SinglePassGraphMatching(graphView).match("MatMul->Add#?");
for (const auto& match : matches) {
const auto it = match.anchors.find("Add");
matMulToFC(match.graph->rootNode(), (it != match.anchors.end()) ? it->second.at("#") : nullptr);
}
}
......@@ -189,7 +189,7 @@ TEST_CASE("GraphRegexUser") {
kitchenBook->setNodeKey("Flatten","getType($) =='Flatten'");
kitchenBook->setNodeKey("FC","getType($) =='FC'");
kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd));
//kitchenBook->addQuery("MatMul->Add",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(fuseMulAdd));
kitchenBook->addQuery("Flatten->FC",static_cast<void(*)(std::shared_ptr<MatchSolution>)>(removeFlatten));
kitchenBook->appliedRecipes(g);
......
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <set>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
namespace Aidge {
TEST_CASE("[cpu/recipes] FuseMulAdd", "[FuseMulAdd][recipes]") {
// generate the original GraphView
auto matmul0 = MatMul("matmul0");
auto add0 = Add(2, "add0");
auto matmul1 = MatMul("matmul1");
auto add1 = Add(2, "add1");
auto b0 = Producer({5}, "B0");
auto w0 = Producer({5, 5}, "W0");
auto b1 = Producer({5}, "B1");
auto w1 = Producer({5,5},"W1");
auto input = Producer({2,5}, "input");
input->addChild(matmul0, 0, 0);
w0->addChild(matmul0, 0, 1);
matmul0->addChild(add0, 0, 0);
b0->addChild(add0, 0, 1);
add0->addChild(matmul1, 0, 1);
w1->addChild(matmul1, 0, 0);
matmul1->addChild(add1, 0, 0);
b1->addChild(add1, 0, 1);
auto g = std::make_shared<GraphView>();
g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1});
// Check original graph
REQUIRE(g->getNodes() ==
std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1}));
REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0)));
REQUIRE(((add0->getParent(0) == matmul0) && (add0->getParent(1) == b0)));
REQUIRE(((matmul1->getParent(1) == add0) && (matmul1->getParent(0) == w1)));
REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1)));
// Transform GraphView inplace
fuseMulAdd(g);
// Check new GraphView
std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1}));
REQUIRE(newNodes.size() == 6);
for (const auto& node : newNodes) {
REQUIRE(((node->type() == "Producer") || (node->type() == "FC")));
}
}
} // namespace Aidge
/********************************************************************************
* Copyright (c) 2023 CEA-List
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0.
*
* SPDX-License-Identifier: EPL-2.0
*
********************************************************************************/
#include <catch2/catch_test_macros.hpp>
#include <set>
#include "aidge/data/Tensor.hpp"
#include "aidge/graph/GraphView.hpp"
#include "aidge/operator/Add.hpp"
#include "aidge/operator/FC.hpp"
#include "aidge/operator/MatMul.hpp"
#include "aidge/operator/Producer.hpp"
#include "aidge/recipes/Recipes.hpp"
namespace Aidge {
TEST_CASE("[cpu/recipes] MatMulToFC", "[MatMulToFC][recipes]") {
SECTION("with Add") {
// generate the original GraphView
auto matmul0 = MatMul("matmul0");
auto add0 = Add(2, "add0");
auto matmul1 = MatMul("matmul1");
auto add1 = Add(2, "add1");
auto b0 = Producer({5}, "B0");
auto w0 = Producer({5, 5}, "W0");
auto b1 = Producer({5}, "B1");
auto w1 = Producer({5,5},"W1");
auto input = Producer({2,5}, "input");
input->addChild(matmul0, 0, 0);
w0->addChild(matmul0, 0, 1);
matmul0->addChild(add0, 0, 0);
b0->addChild(add0, 0, 1);
add0->addChild(matmul1, 0, 1);
w1->addChild(matmul1, 0, 0);
matmul1->addChild(add1, 0, 0);
b1->addChild(add1, 0, 1);
auto g = std::make_shared<GraphView>();
g->add({w0, matmul0, b0, add0, w1, matmul1, b1, add1});
// Check original graph
REQUIRE(g->getNodes() ==
std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1}));
REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0)));
REQUIRE(((add0->getParent(0) == matmul0) && (add0->getParent(1) == b0)));
REQUIRE(((matmul1->getParent(1) == add0) && (matmul1->getParent(0) == w1)));
REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1)));
// Transform GraphView inplace
matMulToFC(g);
// Check new GraphView
std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, b0, add0, w1, matmul1, b1, add1}));
REQUIRE(newNodes.size() == 6);
for (const auto& node : newNodes) {
REQUIRE(((node->type() == "Producer") || (node->type() == "FC")));
}
}
SECTION("without Add") {
// generate the original GraphView
auto matmul0 = MatMul("matmul0");
auto matmul1 = MatMul("matmul1");
auto add1 = Add(2, "add1");
auto w0 = Producer({5, 5}, "W0");
auto b1 = Producer({5}, "B1");
auto w1 = Producer({5,5},"W1");
auto input = Producer({2,5}, "input");
input->addChild(matmul0, 0, 0);
w0->addChild(matmul0, 0, 1);
matmul0->addChild(matmul1, 0, 1);
w1->addChild(matmul1, 0, 0);
matmul1->addChild(add1, 0, 0);
b1->addChild(add1, 0, 1);
auto g = std::make_shared<GraphView>();
g->add({w0, matmul0, w1, matmul1, b1, add1});
// Check original graph
REQUIRE(g->getNodes() ==
std::set<std::shared_ptr<Node>>({w0, matmul0, w1, matmul1, b1, add1}));
REQUIRE(((matmul0->getParent(0) == input) && (matmul0->getParent(1) == w0)));
REQUIRE(((matmul1->getParent(1) == matmul0) && (matmul1->getParent(0) == w1)));
REQUIRE(((add1->getParent(0) == matmul1) && (add1->getParent(1) == b1)));
// Transform GraphView inplace
matMulToFC(g);
// Check new GraphView
std::set<std::shared_ptr<Node>> newNodes = g->getNodes();
REQUIRE(newNodes != std::set<std::shared_ptr<Node>>({w0, matmul0, w1, matmul1, b1, add1}));
REQUIRE(newNodes.size() == 5);
for (const auto& node : newNodes) {
REQUIRE(((node->type() == "Producer") || (node->type() == "FC")));
}
}
}
} // namespace Aidge
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment