Commit d12fa33d authored by Nicolas Vasilache

[mlir] Add a TensorLoadToMemref canonicalization

A folder for `tensor_load + tensor_to_memref` exists, but it only applies when
the source and destination memref types are the same.

This revision adds a canonicalization of `tensor_load + tensor_to_memref` to
`memref_cast` when a type mismatch prevents the folder from kicking in.
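
As a minimal before/after sketch of the rewrite (SSA names %m, %t and %b are
illustrative; the memref types are taken from the new test case below):

  // Before: tensor_load + tensor_to_memref with different, but cast-compatible,
  // memref types.
  %t = tensor_load %m : memref<?xf32, offset: 3, strides: [1]>
  %b = tensor_to_memref %t : memref<?xf32, offset: ?, strides: [1]>

  // After -canonicalize: the pair collapses into a single memref_cast.
  %b = memref_cast %m : memref<?xf32, offset: 3, strides: [1]>
                     to memref<?xf32, offset: ?, strides: [1]>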

Differential Revision: https://reviews.llvm.org/D97038
parent 0d829802
@@ -3838,11 +3838,34 @@ struct TensorCastToMemref : public OpRewritePattern<TensorToMemrefOp> {
    return success();
  }
};

/// Canonicalize tensor_load + tensor_to_memref to memref_cast when type
/// mismatches prevent `TensorToMemrefOp::fold` from kicking in.
struct TensorLoadToMemref : public OpRewritePattern<TensorToMemrefOp> {
  using OpRewritePattern<TensorToMemrefOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(TensorToMemrefOp tensorToMemRef,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = tensorToMemRef.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + tensor_to_memref with different
    // types. `TensorToMemrefOp::fold` handles the same-type case.
    if (!tensorLoad ||
        tensorLoad.memref().getType() == tensorToMemRef.getType())
      return failure();
    // If the types are not cast-compatible, bail.
    if (!MemRefCastOp::areCastCompatible(tensorLoad.memref().getType(),
                                         tensorToMemRef.getType()))
      return failure();
    rewriter.replaceOpWithNewOp<MemRefCastOp>(
        tensorToMemRef, tensorToMemRef.getType(), tensorLoad.memref());
    return success();
  }
};
} // namespace

void TensorToMemrefOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<TensorCastToMemref>(context);
+  results.insert<TensorCastToMemref, TensorLoadToMemref>(context);
}
//===----------------------------------------------------------------------===//
......
-// RUN: mlir-opt %s -canonicalize | FileCheck %s
+// RUN: mlir-opt %s -canonicalize --split-input-file | FileCheck %s
// -----
// Test case: Basic folding of tensor_load(tensor_to_memref(t)) -> t
// CHECK-LABEL: func @tensor_load_of_tensor_to_memref(
@@ -10,6 +12,8 @@ func @tensor_load_of_tensor_to_memref(%arg0: tensor<?xf32>) -> tensor<?xf32> {
return %1 : tensor<?xf32>
}
// -----
// Test case: Basic folding of tensor_to_memref(tensor_load(m)) -> m
// CHECK-LABEL: func @tensor_to_memref_of_tensor_load(
// CHECK-SAME: %[[MEMREF:.*]]: memref<?xf32>) -> memref<?xf32> {
@@ -20,7 +24,11 @@ func @tensor_to_memref_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
return %1 : memref<?xf32>
}
// -----
-// Test case: If the memrefs are not the same type, don't fold them.
+// Test case: If the memrefs are not cast-compatible (e.g. different address space),
+// don't canonicalize them either.
// CHECK-LABEL: func @no_fold_tensor_to_memref_of_tensor_load(
// CHECK-SAME: %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>) -> memref<?xf32, 7> {
// CHECK: %[[TENSOR:.*]] = tensor_load %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2>
@@ -32,6 +40,28 @@ func @no_fold_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, 2>) -> memref
return %1 : memref<?xf32, 7>
}
// -----
// CHECK-DAG: #[[$OFF_3:[a-z0-9]+]] = affine_map<(d0) -> (d0 + 3)>
// CHECK-DAG: #[[$OFF_UNK:[a-z0-9]+]] = affine_map<(d0)[s0] -> (d0 + s0)>
// Test case: If the memrefs are cast-compatible, canonicalize.
// CHECK-LABEL: func @canonicalize_tensor_to_memref_of_tensor_load(
// CHECK-SAME: %[[M:.*]]: memref<?xf32, #[[$OFF_3]]>) -> memref<?xf32, #[[$OFF_UNK]]> {
// CHECK-NOT: tensor_load
// CHECK-NOT: tensor_to_memref
// CHECK: %[[R:.*]] = memref_cast %[[M]] : memref<?xf32, #[[$OFF_3]]> to memref<?xf32, #[[$OFF_UNK]]>
// CHECK: return %[[R]]
func @canonicalize_tensor_to_memref_of_tensor_load(%arg0: memref<?xf32, offset: 3, strides: [1]>)
-> memref<?xf32, offset: ?, strides: [1]>
{
%0 = tensor_load %arg0 : memref<?xf32, offset: 3, strides: [1]>
%1 = tensor_to_memref %0 : memref<?xf32, offset: ?, strides: [1]>
return %1 : memref<?xf32, offset: ?, strides: [1]>
}
// -----
// Test case: Basic folding of dim(tensor_load(m)) -> dim(m).
// CHECK-LABEL: func @dim_of_tensor_load(
// CHECK-SAME: %[[MEMREF:[0-9a-z]*]]: memref<?xf32>
@@ -45,6 +75,8 @@ func @dim_of_tensor_load(%arg0: memref<?xf32>) -> index {
return %1 : index
}
// -----
// Test case: Folding of load(tensor_to_memref(%v, %idxs))
// -> tensor.extract(%v, %idx)
// CHECK-LABEL: func @load_from_tensor_to_memref(
@@ -59,6 +91,8 @@ func @load_from_tensor_to_memref(%arg0: index, %arg1: index, %arg2: tensor<?x?xf
return %1 : f32
}
// -----
// Test case: Folding of dim(tensor.generate %idx) -> %idx
// CHECK-LABEL: func @dim_of_tensor.generate(
// CHECK-SAME: %[[IDX0:[0-9a-z]+]]: index, %[[IDX1:[0-9a-z]+]]: index
@@ -74,6 +108,8 @@ func @dim_of_tensor.generate(%arg0: index, %arg1: index) -> index {
return %1 : index
}
// -----
// Test case: Folding of comparisons with equal operands.
// CHECK-LABEL: @cmpi_equal_operands
// CHECK-DAG: %[[T:.*]] = constant true
@@ -96,6 +132,8 @@ func @cmpi_equal_operands(%arg0: i64)
: i1, i1, i1, i1, i1, i1, i1, i1, i1, i1
}
// -----
// Test case: Folding of dim(memref_reshape %v %shp, %idx) -> load %shp[%idx]
// CHECK-LABEL: func @dim_of_memref_reshape(
// CHECK-SAME: %[[MEM:[0-9a-z]+]]: memref<*xf32>,
@@ -116,6 +154,8 @@ func @dim_of_memref_reshape(%arg0: memref<*xf32>, %arg1: memref<?xindex>)
return %1 : index
}
// -----
// Test case: Folding dim(tensor.cast %0, %idx) -> dim %0, %idx
// CHECK-LABEL: func @fold_dim_of_tensor.cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x?xf32>
@@ -132,6 +172,8 @@ func @fold_dim_of_tensor.cast(%arg0 : tensor<4x?xf32>) -> (index, index) {
return %1, %2: index, index
}
// -----
// CHECK-LABEL: func @tensor_cast_to_memref
// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
// CHECK: %[[M:.+]] = tensor_to_memref %[[ARG0]] : memref<4x6x16x32xi8>
@@ -144,6 +186,8 @@ func @tensor_cast_to_memref(%arg0 : tensor<4x6x16x32xi8>) ->
return %1 : memref<?x?x16x32xi8>
}
// -----
// CHECK-LABEL: func @subview_of_memcast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: memref<4x6x16x32xi8>
// CHECK: %[[S:.+]] = subview %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : memref<4x6x16x32xi8> to memref<16x32xi8, #{{.*}}>
@@ -158,6 +202,8 @@ func @subview_of_memcast(%arg : memref<4x6x16x32xi8>) ->
return %1 : memref<16x32xi8, affine_map<(d0, d1)[s0] -> (d0 * 32 + d1 + s0)>>
}
// -----
// CHECK-LABEL: func @trivial_subtensor
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK-NOT: subtensor
@@ -167,6 +213,8 @@ func @trivial_subtensor(%arg0 : tensor<4x6x16x32xi8>) -> tensor<4x6x16x32xi8> {
return %0 : tensor<4x6x16x32xi8>
}
// -----
// CHECK-LABEL: func @trivial_subtensor_insert
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK-NOT: subtensor
@@ -176,6 +224,8 @@ func @trivial_subtensor_insert(%arg0 : tensor<4x6x16x32xi8>, %arg1 : tensor<4x6x
return %0 : tensor<4x6x16x32xi8>
}
// -----
// CHECK-LABEL: func @rank_reducing_tensor_of_cast
// CHECK-SAME: %[[ARG0:.[a-z0-9A-Z_]+]]: tensor<4x6x16x32xi8>
// CHECK: %[[S:.+]] = subtensor %arg0[0, 1, 0] [1, 1, 16] [1, 1, 1] : tensor<4x6x16x32xi8> to tensor<16x32xi8>
......