[mlir][linalg] Remove generic PadTensorOp vectorization pattern
The generic vectorization pattern handles only those cases where low and high padding is zero. This is already handled by a canonicalization pattern.

Also add a new canonicalization test case to ensure that tensor cast ops are properly inserted.

A more general vectorization pattern will be added in a subsequent commit.

Differential Revision: https://reviews.llvm.org/D103590
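For reference, the canonicalization in question rewrites a linalg.pad_tensor whose low and high padding amounts all fold to zero into a tensor.cast: the pad returns its source unchanged, but the result type may be more static than the source type. A minimal before/after sketch, distilled from the pad_static_zero_cast test added below (%arg0 and %pad_value are assumed function arguments, %c0 an SSA value defined by "constant 0 : index"):

  // Before: every low/high padding amount folds to the constant zero.
  %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
    ^bb0(%i: index, %j: index, %k: index):
      linalg.yield %pad_value : f32
  } : tensor<?x?x?xf32> to tensor<2x3x4xf32>

  // After canonicalization: the pad is a no-op on the data, but a cast is
  // still needed to go from the dynamic source type to the static result type.
  %0 = tensor.cast %arg0 : tensor<?x?x?xf32> to tensor<2x3x4xf32>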
parent 4fa8677860
commit fdb21f0c5e
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
@@ -671,52 +671,6 @@ static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
   return result;
 }
 
-/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
-/// TransferWriteOp. For now, this only applies when all low and high paddings
-/// are determined to be zero.
-struct GenericPadTensorOpVectorizationPattern
-    : public OpRewritePattern<PadTensorOp> {
-  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(PadTensorOp padOp,
-                                PatternRewriter &rewriter) const override {
-    /// Given an OpFoldResult, return true if its value is guaranteed to be a
-    /// zero integer.
-    auto isZeroInt = [&](OpFoldResult ofr) {
-      return isEqualConstantIntOrValue(ofr, rewriter.getIndexAttr(0)); };
-    // Low padding must be static 0.
-    if (!llvm::all_of(padOp.getMixedLowPad(), isZeroInt)) return failure();
-    // High padding must be static 0.
-    if (!llvm::all_of(padOp.getMixedHighPad(), isZeroInt)) return failure();
-    // Pad value must be a constant.
-    auto padValue = padOp.getConstantPaddingValue();
-    if (!padValue) return failure();
-
-    // Bail on non-static shapes.
-    auto resultShapedType = padOp.result().getType().cast<ShapedType>();
-    if (!resultShapedType.hasStaticShape())
-      return failure();
-    VectorType vectorType = extractVectorTypeFromShapedValue(padOp.result());
-    if (!vectorType)
-      return failure();
-
-    // Now we can rewrite as InitTensorOp + TransferReadOp@[0..0] +
-    // TransferWriteOp@[0..0].
-    SmallVector<Value> indices(
-        resultShapedType.getRank(),
-        rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
-    Value read = rewriter.create<vector::TransferReadOp>(
-        padOp.getLoc(), vectorType, padOp.source(), indices, padValue);
-    Value init = rewriter.create<InitTensorOp>(
-        padOp.getLoc(), resultShapedType.getShape(),
-        resultShapedType.getElementType());
-    rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(padOp, read, init,
-                                                         indices);
-
-    return success();
-  }
-};
-
 /// Base pattern for rewriting PadTensorOps whose result is consumed by a given
 /// operation type OpTy.
 template <typename OpTy>
@@ -995,13 +949,14 @@ struct PadTensorOpVectorizationWithSubTensorInsertPattern
 void mlir::linalg::populatePadTensorOpVectorizationPatterns(
     RewritePatternSet &patterns, PatternBenefit baseBenefit) {
-  patterns.add<GenericPadTensorOpVectorizationPattern>(
-      patterns.getContext(), baseBenefit);
+  // TODO: Canonicalizer handles simple cases where low = 0 and high = 0, but a
+  // generic vectorization pattern is still missing.
 
   // Try these specialized patterns first before resorting to the generic one.
   patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
                PadTensorOpVectorizationWithTransferWritePattern,
                PadTensorOpVectorizationWithSubTensorInsertPattern>(
-      patterns.getContext(), baseBenefit.getBenefit() + 1);
+      patterns.getContext(), baseBenefit);
 }
 
 // TODO: cleanup all the convolution vectorization patterns.
mlir/test/Dialect/Linalg/canonicalize.mlir
@@ -1148,3 +1148,21 @@ func @tensor_pad_cast_fold(%arg0: tensor<4x4xf32>) -> tensor<4x4xf32> {
 // CHECK-LABEL: @tensor_pad_cast
 // CHECK-SAME: %[[ARG0:.+]]: tensor<4x4xf32>
 // CHECK: return %[[ARG0]]
+
+// -----
+
+// CHECK-LABEL: func @pad_static_zero_cast(
+// CHECK-SAME: %[[ARG0:.*]]: tensor<?x?x?xf32>
+// CHECK-NOT: linalg.pad_tensor
+// CHECK: %[[RESULT:.*]] = tensor.cast %[[ARG0]] : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+// CHECK: return %[[RESULT]]
+func @pad_static_zero_cast(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
+  %c0 = constant 0 : index
+  %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
+    ^bb0(%arg1: index, %arg2: index, %arg3: index):
+      linalg.yield %pad_value : f32
+  } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
+
+  return %0 : tensor<2x3x4xf32>
+}
+
mlir/test/Dialect/Linalg/vectorization.mlir
@@ -512,27 +512,6 @@ func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12x
 
 // -----
 
-// CHECK-LABEL: func @pad_static
-// CHECK-NOT: linalg.pad_tensor
-func @pad_static(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {
-  // CHECK: %[[C0:.*]] = constant 0 : index
-  // CHECK: %[[READ:.*]] = vector.transfer_read %{{.*}}[%[[C0]], %[[C0]], %[[C0]]]
-  // CHECK-SAME: : tensor<?x?x?xf32>, vector<2x3x4xf32>
-  // CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 3, 4] : tensor<2x3x4xf32>
-  // CHECK: %[[WRITTEN:.*]] = vector.transfer_write %[[READ]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]]
-  // CHECK-SAME: {in_bounds = [true, true, true]} : vector<2x3x4xf32>, tensor<2x3x4xf32>
-  %c0 = constant 0 : index
-  %0 = linalg.pad_tensor %arg0 low[0, %c0, 0] high[0, 0, %c0] {
-    ^bb0(%arg1: index, %arg2: index, %arg3: index):
-      linalg.yield %pad_value : f32
-  } : tensor<?x?x?xf32> to tensor<2x3x4xf32>
-
-  // CHECK: return %[[WRITTEN]] : tensor<2x3x4xf32>
-  return %0 : tensor<2x3x4xf32>
-}
-
-// -----
-
 // CHECK-LABEL: func @pad_static_high_padding
 // CHECK: linalg.pad_tensor
 func @pad_static_high_padding(%arg0: tensor<?x?x?xf32>, %pad_value: f32) -> tensor<2x3x4xf32> {