[mlir] add pad_tensor(tensor.cast) -> pad_tensor canonicalizer

This canonicalization pattern complements the tensor.cast(pad_tensor) one in
propagating constant type information when possible. It contributes to the
feasibility of pad hoisting.
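
Concretely, the new pattern rewrites a tensor.cast that consumes the result of a
linalg.pad_tensor so that the more static result type is carried by the pad
itself. A minimal before/after sketch, mirroring the @cast_of_pad_more_static
test added below (%arg0, %padding, and %cst as in that test):

  // Before: the pad result is fully dynamic; a cast refines it.
  %padded = linalg.pad_tensor %arg0 low[%padding, %padding] high[0, 0] {
  ^bb0(%i: index, %j: index):
    linalg.yield %cst : f32
  } : tensor<?x?xf32> to tensor<?x?xf32>
  %casted = tensor.cast %padded : tensor<?x?xf32> to tensor<32x32xf32>

  // After: the static type is attached to the pad and the cast is gone.
  %padded = linalg.pad_tensor %arg0 low[%padding, %padding] high[0, 0] {
  ^bb0(%i: index, %j: index):
    linalg.yield %cst : f32
  } : tensor<?x?xf32> to tensor<32x32xf32>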

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110343
Author: Alex Zinenko
Date:   2021-09-23 18:37:28 +02:00
Commit: 3f89e339bb (parent: 751be2a064)
7 changed files with 101 additions and 29 deletions

@@ -53,6 +53,10 @@ SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
 namespace mlir {
 namespace tensor {
+
+/// Returns true if `target` is a ranked tensor type that preserves static
+/// information available in the `source` ranked tensor type.
+bool preservesStaticInformation(Type source, Type target);

 /// Determines whether tensor::CastOp casts to a more dynamic version of the
 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
 /// implement canonicalization patterns for ops in different dialects that may
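
For intuition, preservesStaticInformation(source, target) holds exactly when
source and target are ranked tensors of the same rank and element type and
target keeps every dimension that is static in source. Some hypothetical type
pairs (not taken from the patch):

  // preservesStaticInformation(source, target):
  //   tensor<?x8xf32> -> tensor<2x8xf32>  : true  (target only adds static sizes)
  //   tensor<2x8xf32> -> tensor<?x8xf32>  : false (target drops the static 2)
  //   tensor<2x8xf32> -> tensor<2x8xi32>  : false (element types differ)
  //   tensor<8xf32>   -> tensor<1x8xf32>  : false (ranks differ)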

@@ -1482,11 +1482,41 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
     return success();
   }
 };
+
+// Fold CastOp using the result of PadTensorOp back into the latter if it adds
+// static information.
+struct FoldTargetTensorCast : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    if (!padTensorOp.result().hasOneUse())
+      return failure();
+    auto tensorCastOp =
+        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
+    if (!tensorCastOp)
+      return failure();
+    if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
+                                            tensorCastOp.dest().getType()))
+      return failure();
+
+    auto replacementOp = rewriter.create<PadTensorOp>(
+        padTensorOp.getLoc(), tensorCastOp.dest().getType(),
+        padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
+        padTensorOp.static_low(), padTensorOp.static_high());
+    replacementOp.region().takeBody(padTensorOp.region());
+
+    rewriter.replaceOp(padTensorOp, replacementOp.result());
+    rewriter.replaceOp(tensorCastOp, replacementOp.result());
+    return success();
+  }
+};
 } // namespace

 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast>(context);
+  results.add<FoldTargetTensorCast>(context);
 }

 /// Return the padding value of the PadTensorOp if it constant. In this context,
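
Two details of the new pattern are worth noting: it bails out unless the cast is
the only user of the pad result, since any other user would still require the
original, less static type; and it moves the padding-value region into the
replacement op with takeBody() instead of cloning it, which is safe because the
original pad_tensor is replaced in the same rewrite.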

@@ -31,6 +31,34 @@ Operation *TensorDialect::materializeConstant(OpBuilder &builder,
 // CastOp
 //===----------------------------------------------------------------------===//

+/// Returns true if `target` is a ranked tensor type that preserves static
+/// information available in the `source` ranked tensor type.
+bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
+  auto sourceType = source.dyn_cast<RankedTensorType>();
+  auto targetType = target.dyn_cast<RankedTensorType>();
+
+  // Requires RankedTensorType.
+  if (!sourceType || !targetType)
+    return false;
+
+  // Requires same elemental type.
+  if (sourceType.getElementType() != targetType.getElementType())
+    return false;
+
+  // Requires same rank.
+  if (sourceType.getRank() != targetType.getRank())
+    return false;
+
+  // If cast is towards more static sizes along any dimension, don't fold.
+  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
+    if (!ShapedType::isDynamic(std::get<0>(t)) &&
+        ShapedType::isDynamic(std::get<1>(t)))
+      return false;
+  }
+
+  return true;
+}
+
 /// Determines whether tensor::CastOp casts to a more dynamic version of the
 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
 /// implement canonicalization patterns for ops in different dialects that may
@@ -57,30 +85,10 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
   if (!castOp)
     return false;

-  RankedTensorType sourceType =
-      castOp.source().getType().dyn_cast<RankedTensorType>();
-  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
-
-  // Requires RankedTensorType.
-  if (!sourceType || !resultType)
-    return false;
-
-  // Requires same elemental type.
-  if (sourceType.getElementType() != resultType.getElementType())
-    return false;
-
-  // Requires same rank.
-  if (sourceType.getRank() != resultType.getRank())
-    return false;
-
-  // If cast is towards more static sizes along any dimension, don't fold.
-  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
-    if (ShapedType::isDynamic(std::get<0>(t)) &&
-        !ShapedType::isDynamic(std::get<1>(t)))
-      return false;
-  }
-
-  return true;
+  // Can fold if the source of cast has at least as much static information as
+  // its results.
+  return preservesStaticInformation(castOp.getType(),
+                                    castOp.source().getType());
 }

 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
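
To make the refactored check concrete, here is a hypothetical fold that
canFoldIntoConsumerOp enables (%source assumed defined elsewhere). The cast goes
from a more static type to a more dynamic one, so a consumer may use the cast's
source directly:

  %cast = tensor.cast %source : tensor<4x5xf32> to tensor<?x?xf32>
  %slice = tensor.extract_slice %cast[0, 0] [2, 2] [1, 1]
      : tensor<?x?xf32> to tensor<2x2xf32>

  // canFoldIntoConsumerOp returns true for %cast, so canonicalization may
  // produce:
  %slice = tensor.extract_slice %source[0, 0] [2, 2] [1, 1]
      : tensor<4x5xf32> to tensor<2x2xf32>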

@@ -696,6 +696,39 @@ func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {

 // -----

+// CHECK-LABEL: @cast_of_pad_more_static
+func @cast_of_pad_more_static(%arg0: tensor<?x?xf32>, %padding: index) -> tensor<32x32xf32> {
+  %cst = constant 0.000000e+00 : f32
+  // CHECK: %[[PAD:.*]] = linalg.pad_tensor
+  // CHECK: tensor<?x?xf32> to tensor<32x32xf32>
+  %padded = linalg.pad_tensor %arg0 low[%padding, %padding] high[0, 0] {
+  ^bb0(%arg1: index, %arg2: index):
+    linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<?x?xf32>
+  // CHECK-NOT: tensor.cast
+  %casted = tensor.cast %padded : tensor<?x?xf32> to tensor<32x32xf32>
+  // CHECK: return %[[PAD]]
+  return %casted : tensor<32x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @cast_of_pad_less_static
+func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> tensor<?x32x32xf32> {
+  %cst = constant 0.000000e+00 : f32
+  // CHECK: linalg.pad_tensor
+  %padded = linalg.pad_tensor %arg0 low[%padding, %padding, %padding] high[0, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index):
+    linalg.yield %cst : f32
+  } : tensor<32x?x?xf32> to tensor<32x?x?xf32>
+  // CHECK: %[[CAST:.*]] = tensor.cast
+  %casted = tensor.cast %padded : tensor<32x?x?xf32> to tensor<?x32x32xf32>
+  // CHECK: return %[[CAST]]
+  return %casted : tensor<?x32x32xf32>
+}
+
+// -----
+
 func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
     %arg3 : index) -> tensor<?x?xf32> {
   %c0 = constant 0 : index
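
These tests exercise the pattern through the file's existing canonicalization
RUN line; assuming the usual setup of this test file, it is along the lines of:

  // RUN: mlir-opt %s -split-input-file -canonicalize | FileCheck %s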

@@ -140,8 +140,7 @@ func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
 // CHECK:   } else {
 // CHECK:     %[[SUBTENSOR:.*]] = tensor.extract_slice %[[ARG0]][%{{.*}}, 4] [%{{.*}}, 1] [1, 1] : tensor<?x5xf32> to tensor<?x1xf32>
 // CHECK:     %[[PADTENSOR:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[0, 0] high[%{{.*}}, 3]
-// CHECK:     %[[CAST:.*]] = tensor.cast %[[PADTENSOR]] : tensor<?x4xf32> to tensor<3x4xf32>
-// CHECK:     scf.yield %[[CAST]]
+// CHECK:     scf.yield %[[PADTENSOR]]
 // CHECK:   }
 // CHECK:   return %[[RESULT]]
 func @dynamic_high_pad(%arg0 : tensor<?x5xf32>, %h1: index, %pad : f32) -> tensor<3x4xf32> {

@@ -289,7 +289,6 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
 // CHECK:     else
 // CHECK:       tensor.extract_slice
 // CHECK:       linalg.pad_tensor
-// CHECK:       tensor.cast
 // CHECK:     tensor.extract_slice
 // CHECK:     tensor.extract_slice
 // CHECK:     linalg.generic

@@ -111,8 +111,7 @@ func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
 // TILE1:   else
 // TILE1:     %[[SLICE:.*]] = tensor.extract_slice %arg0[0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
 // TILE1:     %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[0, 0] high[7, %{{.*}}]
-// TILE1:     %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32>
-// TILE1:     scf.yield %[[CAST]] : tensor<14x3xf32>
+// TILE1:     scf.yield %[[PAD]] : tensor<14x3xf32>
 // TILE1:   %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
 // TILE1:   scf.yield %[[R3]] : tensor<14x15xf32>
 // TILE1:   return %[[RESULT]] : tensor<14x15xf32>