forked from OSchip/llvm-project
[mlir] Vectorize linalg.pad_tensor consumed by subtensor_insert

Vectorize linalg.pad_tensor without generating a linalg.init_tensor when its
result is consumed by a subtensor_insert.

Differential Revision: https://reviews.llvm.org/D103780

parent: b1b822714d · commit: b1fd8a13cc
@ -650,6 +650,27 @@ mlir::linalg::vectorizeLinalgOp(OpBuilder &b, Operation *op,
|
||||||
// Misc. vectorization patterns.
//----------------------------------------------------------------------------//
||||||
|
|
||||||
|
/// Helper function that retrieves the value of an IntegerAttr.
|
||||||
|
static int64_t getIntFromAttr(Attribute attr) {
|
||||||
|
return attr.cast<IntegerAttr>().getInt();
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Given an ArrayRef of OpFoldResults, return a vector of Values. IntegerAttrs
|
||||||
|
/// are converted to ConstantIndexOps. Other attribute types are not supported.
|
||||||
|
static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
|
||||||
|
ArrayRef<OpFoldResult> ofrs) {
|
||||||
|
SmallVector<Value> result;
|
||||||
|
llvm::for_each(ofrs, [&](auto o) {
|
||||||
|
if (auto val = o.template dyn_cast<Value>()) {
|
||||||
|
result.push_back(val);
|
||||||
|
} else {
|
||||||
|
result.push_back(builder.create<ConstantIndexOp>(
|
||||||
|
loc, getIntFromAttr(o.template get<Attribute>())));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
/// Rewrite a PadTensorOp into a sequence of InitTensorOp, TransferReadOp and
/// TransferWriteOp. For now, this only applies when all low and high paddings
/// are determined to be zero.
|
@ -763,12 +784,87 @@ struct PadTensorOpVectorizationWithTransferReadPattern
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
/// Rewrite use of PadTensorOp result in SubtensorInsertOp. E.g.:
|
||||||
|
/// ```
|
||||||
|
/// %0 = linalg.pad_tensor %src ... : tensor<?x?xf32> to tensor<17x5xf32>
|
||||||
|
/// %r = subtensor_insert %0 into %dest[%a, %b, 0, 0] [1, 1, 17, 5] [1, 1, 1, 1]
|
||||||
|
/// : tensor<17x5xf32> into tensor<?x?x17x5xf32>
|
||||||
|
/// ```
|
||||||
|
/// is rewritten to:
|
||||||
|
/// ```
|
||||||
|
/// %0 = vector.transfer_read %src[%c0, %c0], %padding
|
||||||
|
/// : tensor<?x?xf32>, vector<17x5xf32>
|
||||||
|
/// %r = vector.transfer_write %0, %dest[%a, %b, %c0, %c0]
|
||||||
|
/// {in_bounds = [true, true]} : vector<17x5xf32>, tensor<?x?x17x5xf32>
|
||||||
|
/// ```
|
||||||
|
///
|
||||||
|
/// This rewrite is possible if:
|
||||||
|
/// - Low padding is static 0.
|
||||||
|
/// - `padOp` result shape is static.
|
||||||
|
/// - The entire padded tensor is inserted.
|
||||||
|
/// (Implies that sizes of `insertOp` are all static.)
|
||||||
|
/// - Only unit strides in `insertOp`.
|
||||||
|
/// - Single, scalar padding value.
|
||||||
|
struct PadTensorOpVectorizationWithSubTensorInsertPattern
|
||||||
|
: public VectorizePadTensorOpUserPattern<SubTensorInsertOp> {
|
||||||
|
using VectorizePadTensorOpUserPattern<
|
||||||
|
SubTensorInsertOp>::VectorizePadTensorOpUserPattern;
|
||||||
|
|
||||||
|
LogicalResult rewriteUser(PatternRewriter &rewriter, PadTensorOp padOp,
|
||||||
|
SubTensorInsertOp insertOp) const override {
|
||||||
|
// Low padding must be static 0.
|
||||||
|
if (!padOp.hasZeroLowPad()) return failure();
|
||||||
|
// Only unit stride supported.
|
||||||
|
if (!insertOp.hasUnitStride()) return failure();
|
||||||
|
// Pad value must be a constant.
|
||||||
|
auto padValue = padOp.getConstantPaddingValue();
|
||||||
|
if (!padValue)
|
||||||
|
return failure();
|
||||||
|
// Dynamic shapes not supported.
|
||||||
|
if (!padOp.result().getType().cast<ShapedType>().hasStaticShape())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto vecType = VectorType::get(padOp.getType().getShape(),
|
||||||
|
padOp.getType().getElementType());
|
||||||
|
unsigned vecRank = vecType.getRank();
|
||||||
|
unsigned tensorRank = insertOp.getType().getRank();
|
||||||
|
|
||||||
|
// Check if sizes match: Insert the entire tensor into most minor dims.
|
||||||
|
// (No permutations allowed.)
|
||||||
|
SmallVector<int64_t> expectedSizes(tensorRank - vecRank, 1);
|
||||||
|
expectedSizes.append(vecType.getShape().begin(), vecType.getShape().end());
|
||||||
|
if (!llvm::all_of(
|
||||||
|
llvm::zip(insertOp.getMixedSizes(), expectedSizes),
|
||||||
|
[](auto it) { return isEqualConstantInt(std::get<0>(it),
|
||||||
|
std::get<1>(it)); }))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
// Generate TransferReadOp: Read entire source tensor and add high padding.
|
||||||
|
SmallVector<Value> readIndices(
|
||||||
|
vecRank, rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
|
||||||
|
auto read = rewriter.create<vector::TransferReadOp>(
|
||||||
|
padOp.getLoc(), vecType, padOp.source(), readIndices, padValue);
|
||||||
|
|
||||||
|
// Generate TransferWriteOp: Write to SubTensorInsertOp's dest tensor at
|
||||||
|
// specified offsets. Write is fully in-bounds because a SubTensorInsertOp's
|
||||||
|
// source must fit into the destination at the specified offsets.
|
||||||
|
auto writeIndices =
|
||||||
|
ofrToIndexValues(rewriter, padOp.getLoc(), insertOp.getMixedOffsets());
|
||||||
|
SmallVector<bool> inBounds(vecRank, true);
|
||||||
|
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
||||||
|
insertOp, read, insertOp.dest(), writeIndices, inBounds);
|
||||||
|
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
void mlir::linalg::populatePadTensorOpVectorizationPatterns(
|
void mlir::linalg::populatePadTensorOpVectorizationPatterns(
|
||||||
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
|
RewritePatternSet &patterns, PatternBenefit baseBenefit) {
|
||||||
patterns.add<GenericPadTensorOpVectorizationPattern>(
|
patterns.add<GenericPadTensorOpVectorizationPattern>(
|
||||||
patterns.getContext(), baseBenefit);
|
patterns.getContext(), baseBenefit);
|
||||||
// Try these specialized patterns first before resorting to the generic one.
|
// Try these specialized patterns first before resorting to the generic one.
|
||||||
patterns.add<PadTensorOpVectorizationWithTransferReadPattern>(
|
patterns.add<PadTensorOpVectorizationWithTransferReadPattern,
|
||||||
|
PadTensorOpVectorizationWithSubTensorInsertPattern>(
|
||||||
patterns.getContext(), baseBenefit.getBenefit() + 1);
|
patterns.getContext(), baseBenefit.getBenefit() + 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -580,6 +580,28 @@ func @pad_and_transfer_read(%arg0: tensor<5x6xf32>) -> vector<7x9xf32> {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @pad_and_subtensor_insert
//  CHECK-SAME:   %[[ARG0:.*]]: tensor<5x6xf32>, %[[ARG1:.*]]: tensor<12x13xf32>
//   CHECK-NOT:   linalg.pad_tensor
//   CHECK-DAG:   %[[C0:.*]] = constant 0 : index
//   CHECK-DAG:   %[[C5:.*]] = constant 5.0
//       CHECK:   %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]]], %[[C5]] : tensor<5x6xf32>, vector<7x9xf32>
//       CHECK:   %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[ARG1]][%[[C0]], %[[C0]]] {in_bounds = [true, true]} : vector<7x9xf32>, tensor<12x13xf32>
//       CHECK:   return %[[WRITE]]
func @pad_and_subtensor_insert(
    %arg0: tensor<5x6xf32>, %arg1: tensor<12x13xf32>) -> tensor<12x13xf32> {
  %c0 = constant 0 : index
  %c5 = constant 5.0 : f32
  // Pad %arg0 from 5x6 to 7x9 with the constant 5.0.
  %0 = linalg.pad_tensor %arg0 low[0, 0] high[2, 3] {
    ^bb0(%arg2: index, %arg3: index):
      linalg.yield %c5 : f32
  } : tensor<5x6xf32> to tensor<7x9xf32>
  // Insert the entire padded tensor into %arg1 with unit strides.
  %r = subtensor_insert %0 into %arg1[0, 0][7, 9][1, 1] : tensor<7x9xf32> into tensor<12x13xf32>
  return %r : tensor<12x13xf32>
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-DAG: #[[$M0:.*]] = affine_map<(d0, d1) -> (d0, d1, 0)>

// CHECK-LABEL: func @sum_exp
Loading…
Reference in New Issue