Skip to content
Snippets Groups Projects
Commit b331fa1a authored by Cyril Moineau's avatar Cyril Moineau
Browse files

[Recipies] Extend recipies so that they can also use schedulers as input.

parent a0bf5e8d
No related branches found
No related tags found
No related merge requests found
......@@ -16,12 +16,12 @@
#include "aidge/graph/GraphView.hpp"
namespace Aidge{
void fuseMulAdd(std::set<std::shared_ptr<Node>> nodes);
void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
void fuseMulAdd(std::shared_ptr<GraphView> graphView);
void removeFlatten(std::set<std::shared_ptr<Node>> nodes);
void removeFlatten(std::shared_ptr<GraphView> graphView);
}
#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
\ No newline at end of file
#endif /* AIDGE_CORE_UTILS_RECIPIES_H_ */
......@@ -20,24 +20,30 @@ namespace py = pybind11;
namespace Aidge {
void init_Recipies(py::module &m) {
m.def("fuse_mul_add", &fuseMulAdd, py::arg("nodes"), R"mydelimiter(
Recipie to Fuse MatMul and Add operators into an `aidge.FC` operator.
Parameters
----------
m.def("fuse_mul_add", static_cast<void(*)(std::shared_ptr<GraphView>)>(fuseMulAdd), py::arg("graph_view"), R"mydelimiter(
Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("fuse_mul_add", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(fuseMulAdd), py::arg("nodes"), R"mydelimiter(
Recipie to Fuse MatMul and Add operators into an :py:class:`aidge_core.FC` operator.
:param nodes: The MatMul and Add nodes to fuse.
:type nodes: list of `aidge.node`
:type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter");
m.def("remove_flatten", static_cast<void(*)(std::shared_ptr<GraphView>)>(removeFlatten), py::arg("graph_view"), R"mydelimiter(
Recipie to remove a flatten operator.
:param graph_view: Graph view on which we want to apply the recipie
:type graph_view: :py:class:`aidge_core.GraphView`
)mydelimiter");
m.def("remove_flatten", &removeFlatten, py::arg("nodes"), R"mydelimiter(
m.def("remove_flatten", static_cast<void(*)(std::set<std::shared_ptr<Node>>)>(removeFlatten), py::arg("nodes"), R"mydelimiter(
Recipie to remove a flatten operator.
Parameters
----------
:param nodes: The flatten operator to remove.
:type nodes: list of `aidge.node`
:param nodes: The flatten operator to remove.
:type nodes: list of :py:class:`aidge_core.Node`
)mydelimiter");
}
} // namespace Aidge
......@@ -25,16 +25,16 @@ using namespace Aidge;
/**
* @brief Merge MatMul and Add Node into FC.
*
*
* @param nodes Strict set of Node to merge.
*/
void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
// Fuse Mulmat & Add into FC
// Inputs : old nodes (pointers on mul & add)
assert(nodes.size() == 2 && "Wrong number of nodes to replace\n");
// Too bad we lose information on the type after matching, how to keep the information after matching (not only for the type) ?
// Step 0 : Assert the nodes types are correct to be fused
std::shared_ptr<Node> add;
std::shared_ptr<Node> matmul;
......@@ -53,7 +53,7 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
auto producer_add_bias = add->input(1);
Tensor& bias_tensor = (producer_add_bias.first)->getOperator()->output(0);
// Instanciate FC
// Instanciate FC
//std::shared_ptr<Node> fc = FC(dim[0], false, "Fc");
std::shared_ptr<Node> fc = std::make_shared<Node>(std::make_shared<FC_Op>(bias_tensor.dims()[0], false));
......@@ -77,4 +77,5 @@ void Aidge::fuseMulAdd(std::set<std::shared_ptr<Node>> nodes){
nodeToReplace->add(nodes);
nodeToReplace->replaceWith({fc});
}
\ No newline at end of file
}
......@@ -15,10 +15,28 @@
#include "aidge/graph/GraphView.hpp"
#include "aidge/utils/Recipies.hpp"
// Graph Regex
#include "aidge/graphmatching/GRegex.hpp"
#include "aidge/graphmatching/NodeRegex.hpp"
namespace Aidge {
void removeFlatten(std::set<std::shared_ptr<Node>> nodes) {
auto g = std::make_shared<GraphView>();
g->add(std::set<std::shared_ptr<Node>>({nodes}));
g->replaceWith({});
}
}
\ No newline at end of file
void removeFlatten(std::shared_ptr<GraphView> graphView){
std::map<std::string,NodeRegex*> nodesRegex ;
nodesRegex["Flatten"] = new NodeRegex("Flatten");
std::vector<std::string> seqRegex;
seqRegex.push_back("Flatten;");
GRegex GReg(nodesRegex, seqRegex);
Match matches = GReg.match(graphView);
std::vector<std::set<std::shared_ptr<Node>>> matchNodes = matches.getMatchNodes();
for (size_t i = 0; i < matches.getNbMatch(); ++i) {
removeFlatten(matchNodes[i]);
}
}
}
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment