[MLIR][Shape] Canonicalize casted dynamic extent tensor

Differential Revision: https://reviews.llvm.org/D99161
Frederik Gossen 2021-03-29 13:44:03 +02:00
parent c6e5c4654b
commit 630afc61a8
2 changed files with 78 additions and 2 deletions


@@ -987,11 +987,43 @@ struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
    return success();
  }
};

// Canonicalize
// ```
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
// ```
// to
// ```
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
// ```
struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
  using OpRewritePattern<tensor::CastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(tensor::CastOp op,
                                PatternRewriter &rewriter) const override {
    // The result of the cast must be a rank-1 extent tensor.
    auto ty = op.getType().dyn_cast<RankedTensorType>();
    if (!ty || ty.getRank() != 1)
      return failure();

    // The cast must consume the result of a `shape.shape_of` op.
    auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
    if (!shapeOfOp)
      return failure();

    // Argument type must be ranked and must not conflict with the number of
    // extents in the result type.
    auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
    if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
      return failure();

    rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
    return success();
  }
};
} // namespace

void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
                                            MLIRContext *context) {
  patterns.add<ShapeOfCastedExtentTensor, ShapeOfWithTensor>(context);
}
//===----------------------------------------------------------------------===//
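For context, canonicalization patterns such as these are picked up by MLIR's greedy pattern rewrite driver, e.g. via the -canonicalize pass. Below is a minimal sketch of applying the ShapeOfOp pattern set directly; the helper name canonicalizeShapeOf and the surrounding setup are illustrative assumptions, not part of this commit.

#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Sketch (hypothetical helper): collect the ShapeOfOp canonicalization
// patterns and apply them greedily to a function until a fixed point.
static void canonicalizeShapeOf(FuncOp func, MLIRContext *context) {
  RewritePatternSet patterns(context);
  ShapeOfOp::getCanonicalizationPatterns(patterns, context);
  if (failed(applyPatternsAndFoldGreedily(func, std::move(patterns))))
    func.emitError("pattern application did not converge");
}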


@@ -648,7 +648,7 @@ func @f() {
  // CHECK: shape.cstr_broadcastable
  // CHECK-NEXT: consume.witness
  // CHECK-NEXT: return
  %cs0 = shape.const_shape [8, 1] : !shape.shape
  %cs1 = shape.const_shape [1, 8] : !shape.shape
  %cs2 = shape.const_shape [1, -1] : !shape.shape
  %0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
@@ -1144,3 +1144,47 @@ func @broadcast_on_single_operand(%a : tensor<3xindex>) {
"use"(%0) : (tensor<?xindex>) -> ()
return
}
// -----
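// A cast towards a more dynamic extent tensor type can be folded into the shape.shape_of op.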
// CHECK-LABEL: @casted_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
  // CHECK: return %[[RESULT]] : tensor<?xindex>
  %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
  %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
  return %1 : tensor<?xindex>
}
// -----
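// The fold also applies towards a more static type when the operand rank (here 3) matches the requested number of extents.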
// CHECK-LABEL: @casted_extent_tensor
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
  // CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
  // CHECK: return %[[RESULT]] : tensor<3xindex>
  %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
  return %1 : tensor<3xindex>
}
// -----
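// Negative case: the operand rank (4) conflicts with the requested number of extents (3), so the cast must be preserved.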
// CHECK-LABEL: @casted_extent_tensor
func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
  // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
  %0 = shape.shape_of %arg : tensor<?x?x?x?xf32> -> tensor<?xindex>
  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
  return %1 : tensor<3xindex>
}
// -----
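// Negative case: the operand is unranked, so the number of extents is unknown and the cast must be preserved.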
// CHECK-LABEL: @casted_extent_tensor
func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
  // CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
  %0 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
  %1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
  return %1 : tensor<3xindex>
}
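For reference, the cases above follow the usual lit conventions of the Shape dialect canonicalization tests and are driven by a RUN line along these lines (the exact flags in the actual test file may differ):

// RUN: mlir-opt -split-input-file -allow-unregistered-dialect -canonicalize %s | FileCheck %s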