diff --git a/src/recipes/MatMulTiling.cpp b/src/recipes/MatMulTiling.cpp index 21b143606704e4a0b8ab1c27df8238d08c3f9046..66d4974c64508573c90f693cb1bad96936e1841c 100644 --- a/src/recipes/MatMulTiling.cpp +++ b/src/recipes/MatMulTiling.cpp @@ -35,7 +35,7 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) if (!op->dimsForwarded()) { AIDGE_INTERNAL_ASSERT("Dimensions must be forwarded before any tiling"); } - + const auto& in0Tensor = op->getInput(0); const auto& in1Tensor = op->getInput(1); const auto& outTensor = op->getOutput(0); @@ -44,42 +44,68 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) const auto& outputDims = outTensor->dims(); const auto& outputMatDims = std::vector<std::size_t>(outputDims.end() - 2, outputDims.end());; - if (outputMatDims[0] > maxDims[0]) { - const std::int32_t axis = -2; - const std::int64_t splitIndex = maxDims[0]; + if (outputMatDims[0] > maxDims[0] || outputMatDims[1] > maxDims[1]) { + const auto sliceDims = (outputMatDims[0] > maxDims[0]) ? input0Dims : input1Dims; + std::int32_t axis; + std::int64_t splitIndex0_end = static_cast<std::int64_t>(sliceDims.end()[-2]); + std::int64_t splitIndex0_start = 0; + std::int64_t splitIndex1_end = static_cast<std::int64_t>(sliceDims.end()[-1]); + std::int64_t splitIndex1_start = 0; + + if (outputMatDims[0] > maxDims[0]) { + splitIndex0_end = maxDims[0]; + splitIndex0_start = maxDims[0]; + axis = -2; + } + else { + splitIndex1_end = maxDims[1]; + splitIndex1_start = maxDims[1]; + axis = -1; + } auto identity0 = Identity(); - auto slice00 = Slice(); - auto slice00_starts = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{0, 0}}), "", true); - slice00_starts->addChild(slice00, 0, 1); - auto slice00_ends = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{splitIndex, static_cast<std::int64_t>(input0Dims.end()[-1])}}), "", true); - slice00_ends->addChild(slice00, 0, 2); - auto slice00_axes = Producer(std::make_shared<Tensor>(Vector<std::int8_t>{{-2, -1}}), "", true); - slice00_axes->addChild(slice00, 0, 3); - auto slice00_steps = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{1, 1}}), "", true); - slice00_steps->addChild(slice00, 0, 4); - auto matMul00 = MatMul(); + auto sliceX0 = Slice(); + auto sliceX0_starts = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{0, 0}}), "", true); + sliceX0_starts->addChild(sliceX0, 0, 1); + auto sliceX0_ends = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{splitIndex0_end, splitIndex1_end}}), "", true); + sliceX0_ends->addChild(sliceX0, 0, 2); + auto sliceX0_axes = Producer(std::make_shared<Tensor>(Vector<std::int8_t>{{-2, -1}}), "", true); + sliceX0_axes->addChild(sliceX0, 0, 3); + auto sliceX0_steps = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{1, 1}}), "", true); + sliceX0_steps->addChild(sliceX0, 0, 4); + auto matMulX0 = MatMul(); auto identity1 = Identity(); - auto slice01 = Slice(); - auto slice01_starts = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{splitIndex, 0}}), "", true); - slice01_starts->addChild(slice01, 0, 1); - auto slice01_ends = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{static_cast<std::int64_t>(input0Dims.end()[-2]), static_cast<std::int64_t>(input0Dims.end()[-1])}}), "", true); - slice01_ends->addChild(slice01, 0, 2); - auto slice01_axes = Producer(std::make_shared<Tensor>(Vector<std::int8_t>{{-2, -1}}), "", true); - slice01_axes->addChild(slice01, 0, 3); - auto slice01_steps = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{1, 1}}), "", true); - slice01_steps->addChild(slice01, 0, 4); - auto matMul01 = MatMul(); - auto concat0 = Concat(2, axis); + auto sliceX1 = Slice(); + auto sliceX1_starts = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{splitIndex0_start, splitIndex1_start}}), "", true); + sliceX1_starts->addChild(sliceX1, 0, 1); + auto sliceX1_ends = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{static_cast<std::int64_t>(sliceDims.end()[-2]), static_cast<std::int64_t>(sliceDims.end()[-1])}}), "", true); + sliceX1_ends->addChild(sliceX1, 0, 2); + auto sliceX1_axes = Producer(std::make_shared<Tensor>(Vector<std::int8_t>{{-2, -1}}), "", true); + sliceX1_axes->addChild(sliceX1, 0, 3); + auto sliceX1_steps = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{1, 1}}), "", true); + sliceX1_steps->addChild(sliceX1, 0, 4); + auto matMulX1 = MatMul(); + auto concat = Concat(2, axis); - identity0->addChild(slice00, 0, 0); - identity0->addChild(slice01, 0, 0); - identity1->addChild(matMul00, 0, 1); - identity1->addChild(matMul01, 0, 1); - slice00->addChild(matMul00, 0, 0); - slice01->addChild(matMul01, 0, 0); - matMul00->addChild(concat0, 0, 0); - matMul01->addChild(concat0, 0, 1); + if (outputMatDims[0] > maxDims[0]) { + identity0->addChild(sliceX0, 0, 0); + identity0->addChild(sliceX1, 0, 0); + identity1->addChild(matMulX0, 0, 1); + identity1->addChild(matMulX1, 0, 1); + sliceX0->addChild(matMulX0, 0, 0); + sliceX1->addChild(matMulX1, 0, 0); + } + else { + identity0->addChild(matMulX0, 0, 0); + identity0->addChild(matMulX1, 0, 0); + identity1->addChild(sliceX0, 0, 0); + identity1->addChild(sliceX1, 0, 0); + sliceX0->addChild(matMulX0, 0, 1); + sliceX1->addChild(matMulX1, 0, 1); + } + + matMulX0->addChild(concat, 0, 0); + matMulX1->addChild(concat, 0, 1); auto gMatMul = std::make_shared<GraphView>(); gMatMul->add({matMul}); @@ -87,7 +113,7 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) auto g = std::make_shared<GraphView>(); g->add({identity0}); g->add({identity1}); - g->add({slice00, slice00_starts, slice00_ends, slice00_axes, slice00_steps, matMul00, matMul01, slice01, slice01_starts, slice01_ends, slice01_axes, slice01_steps, concat0}); + g->add({sliceX0, sliceX0_starts, sliceX0_ends, sliceX0_axes, sliceX0_steps, matMulX0, matMulX1, sliceX1, sliceX1_starts, sliceX1_ends, sliceX1_axes, sliceX1_steps, concat}); auto replaced = GraphView::replace(gMatMul, g); @@ -96,14 +122,11 @@ void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) g->save("micrograph"); // Recursive tiling - matMulTiling(matMul01, maxDims); - // TODO: other dimension + matMulTiling(matMulX1, maxDims); + matMulTiling(matMulX0, maxDims); } else { Log::warn("Unable to split MatMul {}", matMul->name()); } } - else if (outputMatDims[1] > maxDims[1]) { - // TODO - } } diff --git a/src/recipes/RemoveNode.cpp b/src/recipes/RemoveNode.cpp index a09c67991409dfe491d46b4ad739f9ddf5b72aef..3a1bac588ee8a1bb38f74fee441c9eff07b4ef6e 100644 --- a/src/recipes/RemoveNode.cpp +++ b/src/recipes/RemoveNode.cpp @@ -13,24 +13,15 @@ #include "aidge/graph/Node.hpp" #include "aidge/graph/GraphView.hpp" +#include "aidge/graph/Matching.hpp" #include "aidge/recipes/Recipes.hpp" - -//Graph Regex -#include "aidge/graphRegex/GraphRegex.hpp" - size_t Aidge::removeNode(std::shared_ptr<GraphView> graphView, const std::string& type, bool incProducers) { - std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>(); - regex->setNodeKey(type, "getType($) =='" + type + "'"); - regex->addQuery(type + "#"); - - const auto matches = regex->match(graphView); - for (const auto& solution : matches) { - assert(solution->at(type).size() == 1 && "Wrong number of nodes to replace\n"); - - std::set<NodePtr> nodesToRemove = solution->at(type); + auto matches = SinglePassGraphMatching(graphView).match(type); + for (const auto& match : matches) { + std::set<NodePtr> nodesToRemove = {match.graph->rootNode()}; if (incProducers) { - for (const auto& nodePtr: (*solution->at(type).begin())->getParents()) { + for (const auto& nodePtr: match.graph->rootNode()->getParents()) { if (nodePtr != nullptr && nodePtr->type() == "Producer") { nodesToRemove.insert(nodePtr); }