forked from OSchip/llvm-project
[mlir][tensor] Add canonicalization for tensor.cast from extract_slice
Propagate static size information into extract_slice producer if possible. Differential Revision: https://reviews.llvm.org/D125972
This commit is contained in:
parent
d640442518
commit
f2676b151d
|
@ -229,11 +229,58 @@ struct ChainedTensorCast : public OpRewritePattern<CastOp> {
|
|||
}
|
||||
};
|
||||
|
||||
/// Fold tensor.cast into tensor.extract_slice producer.
|
||||
/// Example:
|
||||
/// ```
|
||||
/// %0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] :
|
||||
/// tensor<128x512xf32> to tensor<?x512xf32>
|
||||
/// %1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
|
||||
/// ```
|
||||
/// ->
|
||||
/// ```
|
||||
/// %1 = tensor.extract_slice %arg0[%o, 0] [16, 512] [1, 1] :
|
||||
/// tensor<128x512xf32> to tensor<16x512xf32>
|
||||
/// ```
|
||||
struct TensorCastExtractSlice : public OpRewritePattern<CastOp> {
  using OpRewritePattern<CastOp>::OpRewritePattern;

  /// Rewrites `tensorCast` into a new tensor.extract_slice whose result type
  /// is the cast's result type, replacing slice sizes with the static extents
  /// found in that type. Fails when the operand is not produced by an
  /// extract_slice, when the cast cannot be folded into its producer, or when
  /// the cast does not change the shape.
  LogicalResult matchAndRewrite(CastOp tensorCast,
                                PatternRewriter &rewriter) const final {
    // The cast must be fed directly by an extract_slice.
    auto sliceOp = tensorCast.getOperand().getDefiningOp<ExtractSliceOp>();
    if (!sliceOp)
      return failure();

    // The shapes are only inspected once the cast is known to be foldable.
    if (!canFoldIntoProducerOp(tensorCast))
      return failure();

    // A cast that leaves the shape untouched has nothing to propagate.
    auto castShape = tensorCast.getType().getShape();
    auto sourceShape =
        tensorCast.source().getType().cast<RankedTensorType>().getShape();
    if (castShape == sourceShape)
      return failure();

    // Start from the slice's current sizes and overwrite the entries for
    // which the cast result type provides a static extent.
    SmallVector<OpFoldResult, 4> newSizes = sliceOp.getMixedSizes();
    auto droppedDims = computeRankReductionMask(
        extractFromI64ArrayAttr(sliceOp.static_sizes()),
        sliceOp.getType().getShape());

    // `resultDim` walks the dimensions of the cast result type; it advances
    // only for slice dimensions that are not in the rank-reduction mask.
    size_t resultDim = 0;
    for (size_t sizeIdx = 0; sizeIdx < newSizes.size(); ++sizeIdx) {
      const bool isDropped = droppedDims && droppedDims->count(sizeIdx);
      if (isDropped)
        continue;
      const int64_t extent = castShape[resultDim++];
      if (!ShapedType::isDynamic(extent))
        newSizes[sizeIdx] = rewriter.getIndexAttr(extent);
    }

    // Offsets and strides are reused unchanged from the original slice.
    rewriter.replaceOpWithNewOp<ExtractSliceOp>(
        tensorCast, tensorCast.getType().cast<RankedTensorType>(),
        sliceOp.source(), sliceOp.getMixedOffsets(), newSizes,
        sliceOp.getMixedStrides());
    return success();
  }
};
|
||||
|
||||
} // namespace
|
||||
|
||||
/// Adds the canonicalization patterns for tensor.cast to `results`:
/// ChainedTensorCast and TensorCastExtractSlice. Each pattern is registered
/// exactly once.
void CastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ChainedTensorCast, TensorCastExtractSlice>(context);
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -30,9 +30,6 @@ func.func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2:
|
|||
return %3 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// CHECK: #[[BOUND2_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 2)>
|
||||
// CHECK: #[[BOUND4_MAP:.+]] = affine_map<(d0)[s0] -> (-d0 + s0, 4)>
|
||||
|
||||
// CHECK: func @matmul_tensors(
|
||||
// CHECK-SAME: %[[A:[0-9a-z]*]]: tensor<?x?xf32>
|
||||
// CHECK-SAME: %[[B:[0-9a-z]*]]: tensor<?x?xf32>
|
||||
|
@ -40,26 +37,20 @@ func.func @matmul_tensors(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>, %arg2:
|
|||
|
||||
// CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.*]] = arith.constant 1 : index
|
||||
// CHECK-DAG: %[[dA0:.*]] = tensor.dim %[[A]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[dA1:.*]] = tensor.dim %[[A]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[dB0:.*]] = tensor.dim %[[B]], %[[C0]] : tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[dB1:.*]] = tensor.dim %[[B]], %[[C1]] : tensor<?x?xf32>
|
||||
// CHECK: scf.for %[[I:[0-9a-z]*]]
|
||||
// CHECK: %[[sizeA0:.*]] = affine.min #[[BOUND2_MAP]](%[[I]])[%[[dA0]]]
|
||||
// CHECK: %[[stA:.*]] = tensor.extract_slice %[[A]][%[[I]], 0] [%[[sizeA0]], %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||
// CHECK: %[[castA:.*]] = tensor.cast %[[stA]] : tensor<?x?xf32> to tensor<2x?xf32>
|
||||
// CHECK: %[[stA:.*]] = tensor.extract_slice %[[A]][%[[I]], 0] [2, %[[dA1]]] [1, 1] : tensor<?x?xf32> to tensor<2x?xf32>
|
||||
// CHECK: scf.for %[[J:[0-9a-z]*]]
|
||||
// CHECK-NEXT: scf.for %[[K:[0-9a-z]*]] {{.*}} iter_args(%[[RES:[0-9a-z]*]]
|
||||
// CHECK-DAG: %[[stB1:.*]] = tensor.extract_slice %[[B]][%[[K]], %[[J]]] [4, 3] [1, 1] : tensor<?x?xf32> to tensor<4x3xf32>
|
||||
// CHECK-DAG: %[[stF:.*]] = tensor.extract_slice %[[RES]][%[[I]], %[[J]]] [2, 3] [1, 1] : tensor<?x?xf32> to tensor<2x3xf32>
|
||||
//
|
||||
// slices of the producing matmul.
|
||||
// CHECK: %[[sizeB1:.*]] = affine.min #[[BOUND4_MAP]](%[[K]])[%[[dB1]]]
|
||||
// CHECK: %[[stB2:.*]] = tensor.extract_slice %[[B]][0, %[[K]]] [%[[dB0]], %[[sizeB1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||
// CHECK: %[[stC:.*]] = tensor.extract_slice %[[C]][%[[I]], %[[K]]] [%[[sizeA0]], %[[sizeB1]]] [1, 1] : tensor<?x?xf32> to tensor<?x?xf32>
|
||||
// CHECK-DAG: %[[castC:.+]] = tensor.cast %[[stC]] : tensor<?x?xf32> to tensor<2x4xf32>
|
||||
// CHECK-DAG: %[[castB:.+]] = tensor.cast %[[stB2]] : tensor<?x?xf32> to tensor<?x4xf32>
|
||||
// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[castA]], %[[castB]] : tensor<2x?xf32>, tensor<?x4xf32>) outs(%[[castC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK-DAG: %[[stB2:.*]] = tensor.extract_slice %[[B]][0, %[[K]]] [%[[dB0]], 4] [1, 1] : tensor<?x?xf32> to tensor<?x4xf32>
|
||||
// CHECK-DAG: %[[stC:.*]] = tensor.extract_slice %[[C]][%[[I]], %[[K]]] [2, 4] [1, 1] : tensor<?x?xf32> to tensor<2x4xf32>
|
||||
// CHECK: %[[stD:.*]] = linalg.matmul ins(%[[stA]], %[[stB2]] : tensor<2x?xf32>, tensor<?x4xf32>) outs(%[[stC]] : tensor<2x4xf32>) -> tensor<2x4xf32>
|
||||
// CHECK-NEXT: %[[stG:.*]] = linalg.matmul ins(%[[stD]], %[[stB1]] : tensor<2x4xf32>, tensor<4x3xf32>) outs(%[[stF]] : tensor<2x3xf32>) -> tensor<2x3xf32>
|
||||
// CHECK-NEXT: tensor.insert_slice %[[stG]] into %[[RES]][%[[I]], %[[J]]]
|
||||
|
||||
|
|
|
@ -1401,3 +1401,27 @@ func.func @insert_slice_cast(%arg0 : tensor<1x?xf32>, %arg1 : tensor<?x?xf32>, %
|
|||
// CHECK: return %[[RES]] : tensor<?x?xf32>
|
||||
return %1 : tensor<?x?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @cast_extract_slice
|
||||
// Checks that a tensor.cast applied to an extract_slice result is folded into
// the extract_slice: the dynamic size %s is replaced by the static size 16
// and the slice yields tensor<16x512xf32> directly.
func.func @cast_extract_slice(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
|
||||
-> tensor<16x512xf32> {
|
||||
// CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 512] [1, 1] : tensor<128x512xf32> to tensor<16x512xf32>
|
||||
%0 = tensor.extract_slice %arg0[%o, 0] [%s, 512] [1, 1] : tensor<128x512xf32> to tensor<?x512xf32>
|
||||
%1 = tensor.cast %0 : tensor<?x512xf32> to tensor<16x512xf32>
|
||||
// CHECK: return %[[E]] : tensor<16x512xf32>
|
||||
return %1 : tensor<16x512xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @cast_extract_slice_rank_reduce
|
||||
// Rank-reducing variant: the extract_slice has sizes [%s, 1] but a 1-D result.
// Checks that the cast is folded so the sizes become [16, 1] and the slice
// yields tensor<16xf32> directly.
func.func @cast_extract_slice_rank_reduce(%arg0 : tensor<128x512xf32>, %s : index, %o : index)
|
||||
-> tensor<16xf32> {
|
||||
// CHECK: %[[E:.*]] = tensor.extract_slice %{{.*}}[%{{.*}}, 0] [16, 1] [1, 1] : tensor<128x512xf32> to tensor<16xf32>
|
||||
%0 = tensor.extract_slice %arg0[%o, 0] [%s, 1] [1, 1] : tensor<128x512xf32> to tensor<?xf32>
|
||||
%1 = tensor.cast %0 : tensor<?xf32> to tensor<16xf32>
|
||||
// CHECK: return %[[E]] : tensor<16xf32>
|
||||
return %1 : tensor<16xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue