Skip to content
Snippets Groups Projects
Commit b331fa1a authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Recipies] Extend recipies so that they can also use schedulers as input.

parent a0bf5e8d
No related branches found
No related tags found
1 merge request!9Fuse bn
...@@ -16,12 +16,12 @@ ...@@ -16,12 +16,12 @@
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
namespace Aidge{ namespace Aidge{
void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes); void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);
void removeFlatten(std::set<std::shared_ptr<Node>> nodes); void fuseMulAdd(std::shared_ptr<GraphView> graphView);
void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
void removeFlatten(std::shared_ptr<GraphView> graphView);
} }
#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
\ No newline at end of file
...@@ -20,24 +20,30 @@ namespace py = pybind11; ...@@ -20,24 +20,30 @@ namespace py = pybind11;
namespace Aidge { namespace Aidge {
void init_Recipies(py::module &m) { void init_Recipies(py::module &m) {
m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes"), R"mydelimiter( m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter(
Recipie to Fuse MatMul and Add operators into an `aidge.FC` operator. Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
Parameters :param graph_view: Graph view on which we want to apply the recipie
---------- :type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
:param nodes: The MatMul and Add nodes to fuse. :param nodes: The MatMul and Add nodes to fuse.
:type nodes: list of `aidge.node` :type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator.
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter"); )mydelimiter");
m.def("remove_flatten", &removeFlatten, py::arg("nodes"), R"mydelimiter( m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
Recipie to remove a flatten operator. Recipie to remove a flatten operator.
Parameters
----------
:param nodes: The flatten operator to remove.
:type nodes: list of `aidge.node`
:param nodes: The flatten operator to remove.
:type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter"); )mydelimiter");
} }
} // namespace Aidge } // namespace Aidge
...@@ -25,16 +25,16 @@ using namespace Aidge; ...@@ -25,16 +25,16 @@ using namespace Aidge;
/** /**
* @brief Merge MatMul and Add Node into FC. * @brief Merge MatMul and Add Node into FC.
* *
* @param nodes Strict set of Node to merge. * @param nodes Strict set of Node to merge.
*/ */
void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// Fuse Mulmat & Add into FC // Fuse Mulmat & Add into FC
// Inputs : old nodes (pointers on mul & add) // Inputs : old nodes (pointers on mul & add)
assert(nodes.size() == 2 && "Wrong number of nodes to replace\n"); assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
// Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ? // Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ?
// Step 0 : Assert the nodes types are correct to be fused // Step 0 : Assert the nodes types are correct to be fused
std::shared_ptr<Node> add; std::shared_ptr<Node> add;
std::shared_ptr<Node> matmul; std::shared_ptr<Node> matmul;
...@@ -53,7 +53,7 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ ...@@ -53,7 +53,7 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
auto producer_add_bias = add->input(1); auto producer_add_bias = add->input(1);
Tensor& bias_tensor = (producer_add_bias.first)->getOperator()->output(0); Tensor& bias_tensor = (producer_add_bias.first)->getOperator()->output(0);
// Instanciate FC // Instanciate FC
//std::shared_ptr<Node> fc = FC(dim[0], false, "Fc"); //std::shared_ptr<Node> fc = FC(dim[0], false, "Fc");
std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(bias_tensor.dims()[0], false)); std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(bias_tensor.dims()[0], false));
...@@ -77,4 +77,5 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){ ...@@ -77,4 +77,5 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
nodeToReplace->add(nodes); nodeToReplace->add(nodes);
nodeToReplace->replaceWith({fc}); nodeToReplace->replaceWith({fc});
} }
\ No newline at end of file
...@@ -15,10 +15,28 @@ ...@@ -15,10 +15,28 @@
#include "aidge/graph/GraphView.hpp" #include "aidge/graph/GraphView.hpp"
#include "aidge/utils/Recipies.hpp" #include "aidge/utils/Recipies.hpp"
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
namespace Aidge { namespace Aidge {
void removeFlatten(std::set<std::shared_ptr<Node>> nodes) { void removeFlatten(std::set<std::shared_ptr<Node>> nodes) {
auto g = std::make_shared<GraphView>(); auto g = std::make_shared<GraphView>();
g->add(std::set<std::shared_ptr<Node>>({nodes})); g->add(std::set<std::shared_ptr<Node>>({nodes}));
g->replaceWith({}); g->replaceWith({});
} }
}
\ No newline at end of file void removeFlatten(std::shared_ptr<GraphView> graphView){
std::map<std::string,NodeRegex*> nodesRegex ;
nodesRegex["Flatten"] = new NodeRegex("Flatten");
std::vector<std::string> seqRegex;
seqRegex.push_back("Flatten;");
GRegex GReg(nodesRegex, seqRegex);
Match matches = GReg.match(graphView);
std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
for (size_t i = 0; i < matches.getNbMatch(); ++i) {
removeFlatten(matchNodes[i]);
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment