diff --git a/include/aidge/backend/cpu/data/TensorImpl.hpp b/include/aidge/backend/cpu/data/TensorImpl.hpp
index d2fca20751e18ead1de55a07b1dfe80b697f8391..9390fe5860b5d3523886856d9b2a40752d338af5 100644
--- a/include/aidge/backend/cpu/data/TensorImpl.hpp
+++ b/include/aidge/backend/cpu/data/TensorImpl.hpp
@@ -119,30 +119,17 @@ private:
 template <typename T>
 const std::string TensorImpl_cpu<T>::Backend = "cpu";
 
-namespace {
-static Registrar<Tensor> registrarTensorImpl_cpu_Float64(
-    {"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_Float32(
-    {"cpu", DataType::Float32}, Aidge::TensorImpl_cpu<float>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_Float16(
-    {"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_Int64(
-    {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_Int32(
-    {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_Int16(
-    {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_Int8(
-    {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_UInt64(
-    {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_UInt32(
-    {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_UInt16(
-    {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
-static Registrar<Tensor> registrarTensorImpl_cpu_UInt8(
-    {"cpu", DataType::UInt8}, Aidge::TensorImpl_cpu<uint8_t>::create);
-}  // namespace
+REGISTRAR(Tensor, {"cpu", DataType::Float64}, Aidge::TensorImpl_cpu<double>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Float32}, Aidge::TensorImpl_cpu<float>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Float16}, Aidge::TensorImpl_cpu<half_float::half>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int64}, Aidge::TensorImpl_cpu<int64_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int32}, Aidge::TensorImpl_cpu<int32_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int16}, Aidge::TensorImpl_cpu<int16_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::Int8}, Aidge::TensorImpl_cpu<int8_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt64}, Aidge::TensorImpl_cpu<uint64_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt32}, Aidge::TensorImpl_cpu<uint32_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt16}, Aidge::TensorImpl_cpu<uint16_t>::create);
+REGISTRAR(Tensor, {"cpu", DataType::UInt8}, Aidge::TensorImpl_cpu<uint8_t>::create);
 } // namespace Aidge
 
 #endif /* AIDGE_CPU_DATA_TENSORIMPL_H_ */
diff --git a/include/aidge/recipes/Recipes.hpp b/include/aidge/recipes/Recipes.hpp
index a9b9213e914811ccff7d1e6d8efe4fdd8a505b87..82ecc7d28b723d2b3e268f4fb6fbf20d595240ff 100644
--- a/include/aidge/recipes/Recipes.hpp
+++ b/include/aidge/recipes/Recipes.hpp
@@ -134,6 +134,23 @@ void explicitTranspose(std::shared_ptr<GraphView> graphView);
 */
 void expandMetaOps(std::shared_ptr<GraphView> graph, bool recursive = false);
 
+/**
+ * @brief Tile any :cpp:function:`Aidge::MatMul` operator to several fixed size matrix multiplications.
+ * For instance, for a MatMul of size 80x80 and a tiling of 16x16, this will tile
+ * the MatMul operator to 25 (5 by 5) MatMul operators of size 16x16, with Slice
+ * operators inserted at the inputs and Concat operators inserted at the outputs.
+ *
+ * This is especially useful when matrix multiplication must be mapped to fixed
+ * maximum size hardware TPU (Tensor Processing Unit) or MMA (Matrix Multiplication
+ * Accelerator). This recipe can be combined with the :cpp:function:`Aidge::convToMatMul` recipe in
+ * order to convert convolutions to matrix multiplication beforehand, and
+ * :cpp:function:`Aidge::constantFolding` recipe to fold sliced constant tensors.
+ *
+ * @param matMul MatMul operator to be tiled.
+ * @param maxDims Maximum output dimensions of the tiled MatMul operators.
+ */
+void matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims);
+
 /**
  * Fuse each sub-graph matching a query in a Meta Operator.
  * @param graph Graph to manipulate
diff --git a/src/operator/Concat.cpp b/src/operator/Concat.cpp
index 55efdd51d56f7db4f64880b967def661e5354af5..27b9d1cf151c1d12aa4395a3b24673a2f2a4ad3c 100644
--- a/src/operator/Concat.cpp
+++ b/src/operator/Concat.cpp
@@ -49,7 +49,9 @@ std::shared_ptr<Aidge::Operator> Aidge::Concat_Op::clone() const {
 
 void Aidge::Concat_OpImpl::forward() {
     const Concat_Op& op = dynamic_cast<const Concat_Op&>(mOp);
-    const DimSize_t axis = op.axis();
+    auto axis = op.axis();
+    const auto nbDimsInput0 = op.getInput(0)->nbDims();
+    axis = (axis < 0) ? axis + static_cast<std::int32_t>(nbDimsInput0) : axis;
 
     assert(op.getInput(0) && "missing input in Concat operator");
     for (IOIndex_t i = 1; i < mOp.nbInputs(); ++i) {
diff --git a/src/operator/MatMul.cpp b/src/operator/MatMul.cpp
index 668ffd04b7acb0e72b4a3313805fa89ca3466f32..8fd2aa068c91dfebd6d1a3a47900c3aa9b0f9585 100644
--- a/src/operator/MatMul.cpp
+++ b/src/operator/MatMul.cpp
@@ -71,7 +71,7 @@ bool Aidge::MatMul_Op::forwardDims(bool /*allowDataDependency*/) {
 
         std::vector<std::size_t> outDims = std::vector<std::size_t>(dims_size-2, 1);
         for (std::size_t i = 0; i < dims_size-2; ++i) {
-            AIDGE_ASSERT((dims0[i] == dims1[i]) || (dims0[i] == 1) || (dims1[i] == 1), "Bad vector dimension.");
+            AIDGE_ASSERT((dims0[i] == dims1[i]) || (dims0[i] == 1) || (dims1[i] == 1), "Bad dimension {}: {} != {} for input #0 {} and #1 {}.", i, dims0[i], dims1[i], dims0, dims1);
             outDims[i] = std::max(dims0[i], dims1[i]);
         }
diff --git a/src/recipes/MatMulTiling.cpp b/src/recipes/MatMulTiling.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..cfc0b191a2f8b47aec92d1dec5ca8a44c95db5db
--- /dev/null
+++ b/src/recipes/MatMulTiling.cpp
@@ -0,0 +1,133 @@
+/********************************************************************************
+ * Copyright (c) 2023 CEA-List
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Eclipse Public License 2.0 which is available at
+ * http://www.eclipse.org/legal/epl-2.0.
+ *
+ * SPDX-License-Identifier: EPL-2.0
+ *
+ ********************************************************************************/
+#include <cassert>
+#include <memory>
+#include <set>
+#include <string>
+
+#include "aidge/data/Tensor.hpp"
+#include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Node.hpp"
+#include "aidge/operator/Producer.hpp"
+#include "aidge/operator/MatMul.hpp"
+#include "aidge/operator/Slice.hpp"
+#include "aidge/operator/Identity.hpp"
+#include "aidge/operator/Concat.hpp"
+#include "aidge/recipes/Recipes.hpp"
+#include "aidge/utils/ErrorHandling.hpp"
+#include "aidge/utils/Types.h"
+
+// see https://en.wikipedia.org/wiki/Matrix_multiplication_algorithm
+void Aidge::matMulTiling(NodePtr matMul, const std::vector<DimSize_t>& maxDims) {
+    if (matMul->getOperator()->type() != "MatMul") {
+        AIDGE_INTERNAL_ASSERT("Operator should be a MatMul.");
+    }
+    AIDGE_ASSERT(matMul->getOperator()->operatorType() == OperatorType::Tensor, "Operator must be of Tensor type.");
+    const auto& op = std::static_pointer_cast<OperatorTensor>(matMul->getOperator());
+    if (!op->dimsForwarded()) {
+        AIDGE_INTERNAL_ASSERT("Dimensions must be forwarded before any tiling");
+    }
+
+    const auto& in0Tensor = op->getInput(0);
+    const auto& in1Tensor = op->getInput(1);
+    const auto& outTensor = op->getOutput(0);
+    const auto& input0Dims = in0Tensor->dims();
+    const auto& input1Dims = in1Tensor->dims();
+    const auto& outputDims = outTensor->dims();
+    const std::vector<std::size_t> outputMatDims(outputDims.end() - 2, outputDims.end());
+
+    // Only tile when the output matrix exceeds the requested maximum tile size.
+    if (outputMatDims[0] > maxDims[0] || outputMatDims[1] > maxDims[1]) {
+        const auto sliceDims = (outputMatDims[0] > maxDims[0]) ? input0Dims : input1Dims;
+        std::int32_t axis;
+        std::int64_t splitIndex0_end = static_cast<std::int64_t>(sliceDims.end()[-2]);
+        std::int64_t splitIndex0_start = 0;
+        std::int64_t splitIndex1_end = static_cast<std::int64_t>(sliceDims.end()[-1]);
+        std::int64_t splitIndex1_start = 0;
+
+        if (outputMatDims[0] > maxDims[0]) {
+            splitIndex0_end = maxDims[0];
+            splitIndex0_start = maxDims[0];
+            axis = -2;
+        }
+        else {
+            splitIndex1_end = maxDims[1];
+            splitIndex1_start = maxDims[1];
+            axis = -1;
+        }
+
+        // Identity nodes duplicate each original input towards the two tiled branches.
+        auto identity0 = Identity();
+        auto sliceX0 = Slice();
+        auto sliceX0_starts = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{0, 0}}), "", true);
+        sliceX0_starts->addChild(sliceX0, 0, 1);
+        auto sliceX0_ends = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{splitIndex0_end, splitIndex1_end}}), "", true);
+        sliceX0_ends->addChild(sliceX0, 0, 2);
+        auto sliceX0_axes = Producer(std::make_shared<Tensor>(Vector<std::int8_t>{{-2, -1}}), "", true);
+        sliceX0_axes->addChild(sliceX0, 0, 3);
+        auto sliceX0_steps = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{1, 1}}), "", true);
+        sliceX0_steps->addChild(sliceX0, 0, 4);
+        auto matMulX0 = MatMul();
+        auto identity1 = Identity();
+        auto sliceX1 = Slice();
+        auto sliceX1_starts = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{splitIndex0_start, splitIndex1_start}}), "", true);
+        sliceX1_starts->addChild(sliceX1, 0, 1);
+        auto sliceX1_ends = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{static_cast<std::int64_t>(sliceDims.end()[-2]), static_cast<std::int64_t>(sliceDims.end()[-1])}}), "", true);
+        sliceX1_ends->addChild(sliceX1, 0, 2);
+        auto sliceX1_axes = Producer(std::make_shared<Tensor>(Vector<std::int8_t>{{-2, -1}}), "", true);
+        sliceX1_axes->addChild(sliceX1, 0, 3);
+        auto sliceX1_steps = Producer(std::make_shared<Tensor>(Vector<std::int64_t>{{1, 1}}), "", true);
+        sliceX1_steps->addChild(sliceX1, 0, 4);
+        auto matMulX1 = MatMul();
+        auto concat = Concat(2, axis);
+
+        if (outputMatDims[0] > maxDims[0]) {
+            identity0->addChild(sliceX0, 0, 0);
+            identity0->addChild(sliceX1, 0, 0);
+            identity1->addChild(matMulX0, 0, 1);
+            identity1->addChild(matMulX1, 0, 1);
+            sliceX0->addChild(matMulX0, 0, 0);
+            sliceX1->addChild(matMulX1, 0, 0);
+        }
+        else {
+            identity0->addChild(matMulX0, 0, 0);
+            identity0->addChild(matMulX1, 0, 0);
+            identity1->addChild(sliceX0, 0, 0);
+            identity1->addChild(sliceX1, 0, 0);
+            sliceX0->addChild(matMulX0, 0, 1);
+            sliceX1->addChild(matMulX1, 0, 1);
+        }
+
+        matMulX0->addChild(concat, 0, 0);
+        matMulX1->addChild(concat, 0, 1);
+
+        auto gMatMul = std::make_shared<GraphView>();
+        gMatMul->add({matMul});
+
+        auto g = std::make_shared<GraphView>();
+        g->add({identity0});
+        g->add({identity1});
+        g->add({sliceX0, sliceX0_starts, sliceX0_ends, sliceX0_axes, sliceX0_steps, matMulX0, matMulX1, sliceX1, sliceX1_starts, sliceX1_ends, sliceX1_axes, sliceX1_steps, concat});
+
+        auto replaced = GraphView::replace(gMatMul, g);
+
+        if (replaced) {
+            g->forwardDims({}, true);
+
+            // Recursive tiling
+            matMulTiling(matMulX1, maxDims);
+            matMulTiling(matMulX0, maxDims);
+        }
+        else {
+            Log::warn("Unable to split MatMul {}", matMul->name());
+        }
+    }
+}
diff --git a/src/recipes/RemoveNode.cpp b/src/recipes/RemoveNode.cpp
index a09c67991409dfe491d46b4ad739f9ddf5b72aef..3a1bac588ee8a1bb38f74fee441c9eff07b4ef6e 100644
--- a/src/recipes/RemoveNode.cpp
+++ b/src/recipes/RemoveNode.cpp
@@ -13,24 +13,15 @@
 #include "aidge/graph/Node.hpp"
 #include "aidge/graph/GraphView.hpp"
+#include "aidge/graph/Matching.hpp"
 #include "aidge/recipes/Recipes.hpp"
-
-//Graph Regex
-#include "aidge/graphRegex/GraphRegex.hpp"
-
 size_t Aidge::removeNode(std::shared_ptr<GraphView> graphView, const std::string& type, bool incProducers) {
-    std::shared_ptr<GraphRegex> regex = std::make_shared<GraphRegex>();
-    regex->setNodeKey(type, "getType($) =='" + type + "'");
-    regex->addQuery(type + "#");
-
-    const auto matches = regex->match(graphView);
-    for (const auto& solution : matches) {
-        assert(solution->at(type).size() == 1 && "Wrong number of nodes to replace\n");
-
-        std::set<NodePtr> nodesToRemove = solution->at(type);
+    auto matches = SinglePassGraphMatching(graphView).match(type);
+    for (const auto& match : matches) {
+        std::set<NodePtr> nodesToRemove = {match.graph->rootNode()};
         if (incProducers) {
-            for (const auto& nodePtr: (*solution->at(type).begin())->getParents()) {
+            for (const auto& nodePtr: match.graph->rootNode()->getParents()) {
                 if (nodePtr != nullptr && nodePtr->type() == "Producer") {
                     nodesToRemove.insert(nodePtr);
                 }