forked from OSchip/llvm-project
[mlir][linalg] Vectorize linalg.pad_op source copying (static source shape)
If the source operand of a linalg.pad_op operation has static shape, vectorize the copying of the source. Differential Revision: https://reviews.llvm.org/D103747
This commit is contained in:
parent
98fff5153a
commit
4c2f3d810b
|
@ -673,10 +673,8 @@ static SmallVector<Value> ofrToIndexValues(OpBuilder &builder, Location loc,
|
||||||
|
|
||||||
/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp and
|
/// Rewrite a PadTensorOp into a sequence of InitTensorOp, FillOp and
|
||||||
/// SubTensorInsertOp. For now, only constant padding values are supported.
|
/// SubTensorInsertOp. For now, only constant padding values are supported.
|
||||||
/// Note: This rewrite is not yet a vectorization, but some of the generated ops
|
/// If there is enough static type information, TransferReadOps and
|
||||||
/// may be vectorized down the line (e.g., FillOp).
|
/// TransferWriteOps may be generated instead of SubTensorInsertOps.
|
||||||
/// TODO: If there is enough static shape information, generate TransferReadOps
|
|
||||||
/// and TransferWriteOps instead of SubTensorInsertOp.
|
|
||||||
struct GenericPadTensorOpVectorizationPattern
|
struct GenericPadTensorOpVectorizationPattern
|
||||||
: public OpRewritePattern<PadTensorOp> {
|
: public OpRewritePattern<PadTensorOp> {
|
||||||
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
|
using OpRewritePattern<PadTensorOp>::OpRewritePattern;
|
||||||
|
@ -720,6 +718,20 @@ struct GenericPadTensorOpVectorizationPattern
|
||||||
rewriter.create<FillOp>(padOp.getLoc(), init, padValue).result();
|
rewriter.create<FillOp>(padOp.getLoc(), init, padValue).result();
|
||||||
|
|
||||||
auto sourceType = padOp.getSourceType();
|
auto sourceType = padOp.getSourceType();
|
||||||
|
|
||||||
|
// Copy of source with static shape can be vectorized.
|
||||||
|
if (sourceType.hasStaticShape()) {
|
||||||
|
auto vecType = VectorType::get(sourceType.getShape(),
|
||||||
|
sourceType.getElementType());
|
||||||
|
vectorizeStaticShapeSource(rewriter, padOp, fill, vecType);
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: Vectorize dynamic source but static destination.
|
||||||
|
|
||||||
|
// Neither source type nor PadTensorOp result type have static shape. Such
|
||||||
|
// PadTensorOps cannot be vectorized. Generate a SubTensorInsertOp instead.
|
||||||
|
|
||||||
// Compute size of source of PadTensorOp.
|
// Compute size of source of PadTensorOp.
|
||||||
SmallVector<OpFoldResult> srcSizes;
|
SmallVector<OpFoldResult> srcSizes;
|
||||||
for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
|
for (unsigned dim = 0; dim < sourceType.getRank(); ++dim) {
|
||||||
|
@ -738,6 +750,25 @@ struct GenericPadTensorOpVectorizationPattern
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Vectorize the copying of a PadTensorOp's source that has static shape.
|
||||||
|
void vectorizeStaticShapeSource(PatternRewriter &rewriter, PadTensorOp padOp,
|
||||||
|
Value dest, VectorType vecType) const {
|
||||||
|
// Generate TransferReadOp.
|
||||||
|
SmallVector<Value> readIndices(
|
||||||
|
vecType.getRank(), rewriter.create<ConstantIndexOp>(padOp.getLoc(), 0));
|
||||||
|
auto read = rewriter.create<vector::TransferReadOp>(
|
||||||
|
padOp.getLoc(), vecType, padOp.source(), readIndices);
|
||||||
|
|
||||||
|
// Generate TransferWriteOp. The destination dimensions may be dynamic, but
|
||||||
|
// the write cannot be out-of-bounds. (A large enough destination tensor is
|
||||||
|
// allocated in this pattern.)
|
||||||
|
auto writeIndices = ofrToIndexValues(
|
||||||
|
rewriter, padOp.getLoc(), padOp.getMixedLowPad());
|
||||||
|
SmallVector<bool> inBounds(vecType.getRank(), true);
|
||||||
|
rewriter.replaceOpWithNewOp<vector::TransferWriteOp>(
|
||||||
|
padOp, read, dest, writeIndices, inBounds);
|
||||||
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
/// Base pattern for rewriting PadTensorOps whose result is consumed by a given
|
/// Base pattern for rewriting PadTensorOps whose result is consumed by a given
|
||||||
|
|
|
@ -532,6 +532,27 @@ func @pad_static(%arg0: tensor<2x?x2xf32>, %pad_value: f32) -> tensor<2x3x4xf32>
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @pad_static_source(
|
||||||
|
// CHECK-SAME: %[[ARG0:.*]]: tensor<2x5x2xf32>, %[[PAD:.*]]: f32
|
||||||
|
// CHECK-NOT: linalg.pad_tensor
|
||||||
|
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
|
||||||
|
// CHECK-DAG: %[[C2:.*]] = constant 2 : index
|
||||||
|
// CHECK: %[[INIT:.*]] = linalg.init_tensor [2, 6, 4] : tensor<2x6x4xf32>
|
||||||
|
// CHECK: %[[VEC:.*]] = vector.broadcast %[[PAD]] : f32 to vector<2x6x4xf32>
|
||||||
|
// CHECK: %[[FILL:.*]] = vector.transfer_write %[[VEC]], %[[INIT]][%[[C0]], %[[C0]], %[[C0]]] {in_bounds = [true, true, true]} : vector<2x6x4xf32>, tensor<2x6x4xf32>
|
||||||
|
// CHECK: %[[READ:.*]] = vector.transfer_read %[[ARG0]][%[[C0]], %[[C0]], %[[C0]]], %{{.*}} {in_bounds = [true, true, true]} : tensor<2x5x2xf32>, vector<2x5x2xf32>
|
||||||
|
// CHECK: %[[WRITE:.*]] = vector.transfer_write %[[READ]], %[[FILL]][%[[C0]], %[[C0]], %[[C2]]] {in_bounds = [true, true, true]} : vector<2x5x2xf32>, tensor<2x6x4xf32>
|
||||||
|
// CHECK: return %[[WRITE]]
|
||||||
|
func @pad_static_source(%arg0: tensor<2x5x2xf32>, %pad_value: f32) -> tensor<2x6x4xf32> {
|
||||||
|
%0 = linalg.pad_tensor %arg0 low[0, 0, 2] high[0, 1, 0] {
|
||||||
|
^bb0(%arg1: index, %arg2: index, %arg3: index):
|
||||||
|
linalg.yield %pad_value : f32
|
||||||
|
} : tensor<2x5x2xf32> to tensor<2x6x4xf32>
|
||||||
|
return %0 : tensor<2x6x4xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @pad_static_dynamic(
|
// CHECK-LABEL: func @pad_static_dynamic(
|
||||||
// CHECK-SAME: %[[SRC:.*]]: tensor<1x2x2x?xf32>, %[[LOW:.*]]: index, %[[HIGH:.*]]: index
|
// CHECK-SAME: %[[SRC:.*]]: tensor<1x2x2x?xf32>, %[[LOW:.*]]: index, %[[HIGH:.*]]: index
|
||||||
// CHECK-NOT: linalg.pad_tensor
|
// CHECK-NOT: linalg.pad_tensor
|
||||||
|
|
Loading…
Reference in New Issue