Skip to content
Snippets Groups Projects
Commit e87dd5bd authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Working 2D tiling

parent 693cb6ff
No related branches found
No related tags found
3 merge requests!118v0.4.0,!108v0.4.0,!105Add MatMulTiling recipe
Pipeline #58687 passed
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "aidge/operator/Producer.hpp" #include "aidge/operator/Producer.hpp"
#include "aidge/graph/OpArgs.hpp" #include "aidge/graph/OpArgs.hpp"
#include "aidge/scheduler/SequentialScheduler.hpp" #include "aidge/scheduler/SequentialScheduler.hpp"
#include "aidge/graph/Matching.hpp"
#include "aidge/utils/TensorUtils.hpp" #include "aidge/utils/TensorUtils.hpp"
using namespace Aidge; using namespace Aidge;
...@@ -79,10 +80,22 @@ TEST_CASE("[MatMulTiling]") { ...@@ -79,10 +80,22 @@ TEST_CASE("[MatMulTiling]") {
// Tiling // Tiling
fmt::println("Tiling"); fmt::println("Tiling");
matMulTiling(matmul1, {16, 16}); matMulTiling(matmul1, {16, 16});
removeIdentity(g1);
g1->setBackend("cpu"); g1->setBackend("cpu");
g1->save("MatMulSplitting_graph_split"); g1->save("MatMulSplitting_graph_split");
auto gm = SinglePassGraphMatching(g1);
gm.addNodeLambda("16x16", [](const NodePtr& node) {
const auto op =
std::static_pointer_cast<OperatorTensor>(node->getOperator());
const auto dims = op->getOutput(0)->dims();
return (dims.end()[-2] == 16 && dims.end()[-1] == 16);
});
const auto results = gm.match("MatMul[16x16]");
REQUIRE(results.size() == 25);
// Check result // Check result
fmt::println("Schedule forward tiled graph"); fmt::println("Schedule forward tiled graph");
s1 = SequentialScheduler(g1); s1 = SequentialScheduler(g1);
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment