[mlir] Vectorize linalg.pad_tensor consumed by subtensor_insert

Vectorize linalg.pad_tensor without generating a linalg.init_tensor when consumed by a subtensor_insert.

Differential Revision: https://reviews.llvm.org/D103780
This commit is contained in:
Matthias Springer 2021-06-14 09:58:26 +09:00
parent b1b822714d
commit b1fd8a13cc
2 changed files with 119 additions and 1 deletions

View File

@ -650,6 +650,27 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
// Misc. vectorization patterns.
//----------------------------------------------------------------------------//
/// Helper function that retrieves the value of an IntegerAttr.
static int64_t getIntFromAttr(Attribute attr) {
return attr.cast<IntegerAttr>().getInt();
}
/// Given an ArrayRef of OpFoldResults, return a vector of Values. IntegerAttrs
/// are converted to ConstantIndexOps. Other attribute types are not supported.
static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
ArrayRef<OpFoldResult> ofrs) {
SmallVector<Value> result;
llvm::for_each(ofrs, [&](auto o) {
if (auto val = o.template dyn_cast<Value>()) {
result.push_back(val);
} else {
result.push_back(builder.create<ConstantIndexOp>(
loc, getIntFromAttr(o.template get<Attribute>())));
}
});
return result;
}
/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
/// TransferWriteOp. For now, this only applies when all low and high paddings
/// are determined to be zero.
@ -763,12 +784,87 @@ struct PadTensorOpVectorizationWithTransferReadPattern
}
};
/// Rewrite use of PadTensorOp result in SubtensorInsertOp. E.g.:
/// ```
/// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
/// %r = subtensor_insert %0 into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
/// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
/// ```
/// is rewritten to:
/// ```
/// %0 = vector.transfer_read %src[%c0, %c0], %padding
/// : tensor<?x?xf32>, vector<17x5xf32>
/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
/// ```
///
/// This rewrite is possible if:
/// - Low padding is static 0.
/// - `padOp` result shape is static.
/// - The entire padded tensor is inserted.
/// (Implies that sizes of `insertOp` are all static.)
/// - Only unit strides in `insertOp`.
/// - Single, scalar padding value.
struct PadTensorOpVectorizationWithSubTensorInsertPattern
: public VectorizePadTensorOpUserPattern<SubTensorInsertOp> {
using VectorizePadTensorOpUserPattern<
SubTensorInsertOp>::VectorizePadTensorOpUserPattern;
LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
SubTensorInsertOp insertOp) const override {
// Low padding must be static 0.
if (!padOp.hasZeroLowPad()) return failure();
// Only unit stride supported.
if (!insertOp.hasUnitStride()) return failure();
// Pad value must be a constant.
auto padValue = padOp.getConstantPaddingValue();
if (!padValue)
return failure();
// Dynamic shapes not supported.
if (!padOp.result().getType().cast<ShapedType>().hasStaticShape())
return failure();
auto vecType = VectorType::get(padOp.getType().getShape(),
padOp.getType().getElementType());
unsigned vecRank = vecType.getRank();
unsigned tensorRank = insertOp.getType().getRank();
// Check if sizes match: Insert the entire tensor into most minor dims.
// (No permutations allowed.)
SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
if (!llvm::all_of(
llvm::zip(insertOp.getMixedSizes(), expectedSizes),
[](auto it) { return isEqualConstantInt(std::get<0>(it),
std::get<1>(it)); }))
return failure();
// Generate TransferReadOp: Read entire source tensor and add high padding.
SmallVector<Value> readIndices(
vecRank, rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
auto read = rewriter.create<vector::TransferReadOp>(
padOp.getLoc(), vecType, padOp.source(), readIndices, padValue);
// Generate TransferWriteOp: Write to SubTensorInsertOp's dest tensor at
// specified offsets. Write is fully in-bounds because a SubTensorInsertOp's
// source must fit into the destination at the specified offsets.
auto writeIndices =
ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
SmallVector<bool> inBounds(vecRank, true);
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
insertOp, read, insertOp.dest(), writeIndices, inBounds);
return success();
}
};
void mlir::linalg::populatePadTensorOpVectorizationPatterns(
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
patterns.add<GenericPadTensorOpVectorizationPattern>(
patterns.getContext(), baseBenefit);
// Try these specialized patterns first before resorting to the generic one.
patterns.add<PadTensorOpVectorizationWithTransferReadPattern>(
patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
PadTensorOpVectorizationWithSubTensorInsertPattern>(
patterns.getContext(), baseBenefit.getBenefit() + 1);
}

View File

@ -580,6 +580,28 @@ func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
// -----
// CHECK-LABEL: func @pad_and_subtensor_insert
// CHECK-SAME: %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: tensor<12x13xf32>
// CHECK-NOT: linalg.pad_tensor
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
// CHECK-DAG: %[[C5:.*]] = constant 5.0
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
// CHECK: return %[[WRITE]]
func @pad_and_subtensor_insert(
%arg0: tensor<5x6xf32>, %arg1: tensor<12x13xf32>) -> tensor<12x13xf32> {
%c0 = constant 0 : index
%c5 = constant 5.0 : f32
%0 = linalg.pad_tensor %arg0 low[0, 0] high[2, 3] {
^bb0(%arg2: index, %arg3: index):
linalg.yield %c5 : f32
} : tensor<5x6xf32> to tensor<7x9xf32>
%r = subtensor_insert %0 into %arg1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
return %r : tensor<12x13xf32>
}
// -----
// CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)>
// CHECK-LABEL: func @sum_exp