[mlir] add pad_tensor(tensor.cast) -> pad_tensor canonicalizer
This canonicalization pattern complements the tensor.cast(pad_tensor) one in
propagating constant type information when possible. It contributes to the
feasibility of pad hoisting.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D110343
parent 751be2a064
commit 3f89e339bb
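As a sketch of the new fold before the diff (value names %arg0, %p, %cst are hypothetical; the shapes mirror the @cast_of_pad_more_static test below): a tensor.cast that only adds static information to a pad result is absorbed into the pad op, which then produces the static type directly.

// Before: the cast refines the padded result to a static shape.
%0 = linalg.pad_tensor %arg0 low[%p, %p] high[0, 0] {
^bb0(%i: index, %j: index):
  linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<?x?xf32>
%1 = tensor.cast %0 : tensor<?x?xf32> to tensor<32x32xf32>

// After: the pad op itself carries the static result type.
%1 = linalg.pad_tensor %arg0 low[%p, %p] high[0, 0] {
^bb0(%i: index, %j: index):
  linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<32x32xf32>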
@@ -53,6 +53,10 @@ SmallVector<Range, 8> getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
 namespace mlir {
 namespace tensor {
 
+/// Returns true if `target` is a ranked tensor type that preserves static
+/// information available in the `source` ranked tensor type.
+bool preservesStaticInformation(Type source, Type target);
+
 /// Determines whether tensor::CastOp casts to a more dynamic version of the
 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
 /// implement canonicalization patterns for ops in different dialects that may
@@ -1482,11 +1482,41 @@ struct FoldSourceTensorCast : public OpRewritePattern<PadTensorOp> {
     return success();
   }
 };
+
+// Fold CastOp using the result of PadTensorOp back into the latter if it adds
+// static information.
+struct FoldTargetTensorCast : public OpRewritePattern<PadTensorOp> {
+  using OpRewritePattern<PadTensorOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(PadTensorOp padTensorOp,
+                                PatternRewriter &rewriter) const override {
+    if (!padTensorOp.result().hasOneUse())
+      return failure();
+    auto tensorCastOp =
+        dyn_cast<tensor::CastOp>(*padTensorOp->getUsers().begin());
+    if (!tensorCastOp)
+      return failure();
+    if (!tensor::preservesStaticInformation(padTensorOp.result().getType(),
+                                            tensorCastOp.dest().getType()))
+      return failure();
+
+    auto replacementOp = rewriter.create<PadTensorOp>(
+        padTensorOp.getLoc(), tensorCastOp.dest().getType(),
+        padTensorOp.source(), padTensorOp.low(), padTensorOp.high(),
+        padTensorOp.static_low(), padTensorOp.static_high());
+    replacementOp.region().takeBody(padTensorOp.region());
+
+    rewriter.replaceOp(padTensorOp, replacementOp.result());
+    rewriter.replaceOp(tensorCastOp, replacementOp.result());
+    return success();
+  }
+};
 } // namespace
 
 void PadTensorOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
   results.add<FoldStaticZeroPadding, FoldSourceTensorCast>(context);
+  results.add<FoldTargetTensorCast>(context);
 }
 
 /// Return the padding value of the PadTensorOp if it constant. In this context,
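For contrast, the pre-existing FoldSourceTensorCast pattern registered above handles the opposite arrangement, where a cast feeds the pad's source; a sketch mirroring the existing @pad_tensor_of_cast test (%cst hypothetical):

// Before: the cast discards static information from the source.
%0 = tensor.cast %t : tensor<8x?xf32> to tensor<?x?xf32>
%1 = linalg.pad_tensor %0 low[0, 0] high[0, %s] {
^bb0(%i: index, %j: index):
  linalg.yield %cst : f32
} : tensor<?x?xf32> to tensor<8x32xf32>

// After: the pad consumes the more static source directly.
%1 = linalg.pad_tensor %t low[0, 0] high[0, %s] {
^bb0(%i: index, %j: index):
  linalg.yield %cst : f32
} : tensor<8x?xf32> to tensor<8x32xf32>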
@@ -31,6 +31,34 @@ Operation *TensorDialect::materializeConstant(OpBuilder &builder,
 // CastOp
 //===----------------------------------------------------------------------===//
 
+/// Returns true if `target` is a ranked tensor type that preserves static
+/// information available in the `source` ranked tensor type.
+bool mlir::tensor::preservesStaticInformation(Type source, Type target) {
+  auto sourceType = source.dyn_cast<RankedTensorType>();
+  auto targetType = target.dyn_cast<RankedTensorType>();
+
+  // Requires RankedTensorType.
+  if (!sourceType || !targetType)
+    return false;
+
+  // Requires same elemental type.
+  if (sourceType.getElementType() != targetType.getElementType())
+    return false;
+
+  // Requires same rank.
+  if (sourceType.getRank() != targetType.getRank())
+    return false;
+
+  // If cast is towards more static sizes along any dimension, don't fold.
+  for (auto t : llvm::zip(sourceType.getShape(), targetType.getShape())) {
+    if (!ShapedType::isDynamic(std::get<0>(t)) &&
+        ShapedType::isDynamic(std::get<1>(t)))
+      return false;
+  }
+
+  return true;
+}
+
 /// Determines whether tensor::CastOp casts to a more dynamic version of the
 /// source tensor. This is useful to fold a tensor.cast into a consuming op and
 /// implement canonicalization patterns for ops in different dialects that may
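For illustration (example types, not from the patch): preservesStaticInformation(tensor<2x?xf32>, tensor<2x3xf32>) holds, since the target only adds static information, whereas preservesStaticInformation(tensor<2x3xf32>, tensor<2x?xf32>) does not, because the static extent 3 is dropped; mismatched element types or ranks likewise fail the predicate.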
@@ -57,30 +85,10 @@ bool mlir::tensor::canFoldIntoConsumerOp(CastOp castOp) {
   if (!castOp)
     return false;
 
-  RankedTensorType sourceType =
-      castOp.source().getType().dyn_cast<RankedTensorType>();
-  RankedTensorType resultType = castOp.getType().dyn_cast<RankedTensorType>();
-
-  // Requires RankedTensorType.
-  if (!sourceType || !resultType)
-    return false;
-
-  // Requires same elemental type.
-  if (sourceType.getElementType() != resultType.getElementType())
-    return false;
-
-  // Requires same rank.
-  if (sourceType.getRank() != resultType.getRank())
-    return false;
-
-  // If cast is towards more static sizes along any dimension, don't fold.
-  for (auto t : llvm::zip(sourceType.getShape(), resultType.getShape())) {
-    if (ShapedType::isDynamic(std::get<0>(t)) &&
-        !ShapedType::isDynamic(std::get<1>(t)))
-      return false;
-  }
-
-  return true;
+  // Can fold if the source of cast has at least as much static information as
+  // its results.
+  return preservesStaticInformation(castOp.getType(),
+                                    castOp.source().getType());
 }
 
 /// Performs folding of any operand of `op` if it comes from a tensor::CastOp
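Note that canFoldIntoConsumerOp is now expressed via the new predicate with the arguments swapped: the cast folds away exactly when its source type preserves all static information of its result type.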
@@ -696,6 +696,39 @@ func @pad_tensor_of_cast(%t: tensor<8x?xf32>, %s: index) -> tensor<8x32xf32> {
 
 // -----
 
+// CHECK-LABEL: @cast_of_pad_more_static
+func @cast_of_pad_more_static(%arg0: tensor<?x?xf32>, %padding: index) -> tensor<32x32xf32> {
+  %cst = constant 0.000000e+00 : f32
+  // CHECK: %[[PAD:.*]] = linalg.pad_tensor
+  // CHECK: tensor<?x?xf32> to tensor<32x32xf32>
+  %padded = linalg.pad_tensor %arg0 low[%padding, %padding] high[0, 0] {
+  ^bb0(%arg1: index, %arg2: index):
+    linalg.yield %cst : f32
+  } : tensor<?x?xf32> to tensor<?x?xf32>
+  // CHECK-NOT: tensor.cast
+  %casted = tensor.cast %padded : tensor<?x?xf32> to tensor<32x32xf32>
+  // CHECK: return %[[PAD]]
+  return %casted : tensor<32x32xf32>
+}
+
+// -----
+
+// CHECK-LABEL: @cast_of_pad_less_static
+func @cast_of_pad_less_static(%arg0: tensor<32x?x?xf32>, %padding: index) -> tensor<?x32x32xf32> {
+  %cst = constant 0.000000e+00 : f32
+  // CHECK: linalg.pad_tensor
+  %padded = linalg.pad_tensor %arg0 low[%padding, %padding, %padding] high[0, 0, 0] {
+  ^bb0(%arg1: index, %arg2: index, %arg3: index):
+    linalg.yield %cst : f32
+  } : tensor<32x?x?xf32> to tensor<32x?x?xf32>
+  // CHECK: %[[CAST:.*]] = tensor.cast
+  %casted = tensor.cast %padded : tensor<32x?x?xf32> to tensor<?x32x32xf32>
+  // CHECK: return %[[CAST]]
+  return %casted : tensor<?x32x32xf32>
+}
+
+// -----
+
 func @propogate_casts(%arg0 : tensor<?x?xf32>, %arg1 : f32, %arg2 : index,
   %arg3 : index) -> tensor<?x?xf32> {
   %c0 = constant 0 : index
@@ -140,8 +140,7 @@ func @static_mixed_data_low_high_pad(%arg0 : tensor<4x5xf32>, %pad : f32)
 // CHECK:   } else {
 // CHECK:     %[[SUBTENSOR:.*]] = tensor.extract_slice %[[ARG0]][%{{.*}}, 4] [%{{.*}}, 1] [1, 1] : tensor<?x5xf32> to tensor<?x1xf32>
 // CHECK:     %[[PADTENSOR:.*]] = linalg.pad_tensor %[[SUBTENSOR]] low[0, 0] high[%{{.*}}, 3]
-// CHECK:     %[[CAST:.*]] = tensor.cast %[[PADTENSOR]] : tensor<?x4xf32> to tensor<3x4xf32>
-// CHECK:     scf.yield %[[CAST]]
+// CHECK:     scf.yield %[[PADTENSOR]]
 // CHECK:   }
 // CHECK:   return %[[RESULT]]
 func @dynamic_high_pad(%arg0 : tensor<?x5xf32>, %h1: index, %pad : f32) -> tensor<3x4xf32> {
@@ -289,7 +289,6 @@ func @conv_tensors_dynamic(%input: tensor<?x?x?x?xf32>, %filter: tensor<?x?x?x?x
 // CHECK:   else
 // CHECK:     tensor.extract_slice
 // CHECK:     linalg.pad_tensor
-// CHECK:     tensor.cast
 // CHECK:   tensor.extract_slice
 // CHECK:   tensor.extract_slice
 // CHECK:   linalg.generic
@@ -111,8 +111,7 @@ func @static_pad_tensor(%input_tensor: tensor<7x9xf32>,
 // TILE1:   else
 // TILE1:     %[[SLICE:.*]] = tensor.extract_slice %arg0[0, %{{.*}}] [7, %{{.*}}] [1, 1] : tensor<7x9xf32> to tensor<7x?xf32>
 // TILE1:     %[[PAD:.*]] = linalg.pad_tensor %[[SLICE]] low[0, 0] high[7, %{{.*}}]
-// TILE1:     %[[CAST:.*]] = tensor.cast %[[PAD]] : tensor<14x?xf32> to tensor<14x3xf32>
-// TILE1:     scf.yield %[[CAST]] : tensor<14x3xf32>
+// TILE1:     scf.yield %[[PAD]] : tensor<14x3xf32>
 // TILE1:   %[[R3:.*]] = tensor.insert_slice %[[R2]] into %[[INNER_OUT]][0, %[[IV]]] [14, 3] [1, 1] : tensor<14x3xf32> into tensor<14x15xf32>
 // TILE1:   scf.yield %[[R3]] : tensor<14x15xf32>
 // TILE1: return %[[RESULT]] : tensor<14x15xf32>