Skip to content
Snippets Groups Projects
Commit 009d227d authored by Olivier BICHLER's avatar Olivier BICHLER
Browse files

Fixed minor issues

parent 9090eb3c
No related branches found
No related tags found
3 merge requests!279v0.4.0,!253v0.4.0,!244Add MatMulTiling recipe
...@@ -56,6 +56,8 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) ...@@ -56,6 +56,8 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims)
slice00_ends->addChild(slice00, 0, 2); slice00_ends->addChild(slice00, 0, 2);
auto slice00_axes = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{0, 1}}), "", true); auto slice00_axes = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{0, 1}}), "", true);
slice00_axes->addChild(slice00, 0, 3); slice00_axes->addChild(slice00, 0, 3);
auto slice00_steps = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{1, 1}}), "", true);
slice00_steps->addChild(slice00, 0, 4);
auto matMul00 = MatMul(); auto matMul00 = MatMul();
auto identity1 = Identity(); auto identity1 = Identity();
auto slice01 = Slice(); auto slice01 = Slice();
...@@ -65,6 +67,8 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) ...@@ -65,6 +67,8 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims)
slice01_ends->addChild(slice01, 0, 2); slice01_ends->addChild(slice01, 0, 2);
auto slice01_axes = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{0, 1}}), "", true); auto slice01_axes = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{0, 1}}), "", true);
slice01_axes->addChild(slice01, 0, 3); slice01_axes->addChild(slice01, 0, 3);
auto slice01_steps = Producer(std::make_shared<Tensor>(Vector<DimSize_t>{{1, 1}}), "", true);
slice01_steps->addChild(slice01, 0, 4);
auto matMul01 = MatMul(); auto matMul01 = MatMul();
auto concat0 = Concat(2, axis); auto concat0 = Concat(2, axis);
...@@ -82,7 +86,7 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) ...@@ -82,7 +86,7 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims)
auto g = std::make_shared<GraphView>(); auto g = std::make_shared<GraphView>();
g->add({identity0, identity1}); g->add({identity0, identity1});
g->add({slice00, matMul00, matMul01, slice01, concat0}); g->add({slice00, slice00_starts, slice00_ends, slice00_axes, slice00_steps, matMul00, matMul01, slice01, slice01_starts, slice01_ends, slice01_axes, slice01_steps, concat0});
g->save("micrograph"); g->save("micrograph");
auto replaced = GraphView::replace(gMatMul, g); auto replaced = GraphView::replace(gMatMul, g);
...@@ -91,8 +95,8 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) ...@@ -91,8 +95,8 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims)
g->forwardDims(); g->forwardDims();
// Recursive tiling // Recursive tiling
MatMulTiling(matMul00, maxDims); matMulTiling(matMul00, maxDims);
MatMulTiling(matMul01, maxDims); matMulTiling(matMul01, maxDims);
} }
else { else {
Log::warn("Unable to split MatMul {}", matMul->name()); Log::warn("Unable to split MatMul {}", matMul->name());
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment