forked from OSchip/llvm-project
[MLIR][Shape] Canonicalize casted dynamic extent tensor
Differential Revision: https://reviews.llvm.org/D99161
This commit is contained in:
parent
c6e5c4654b
commit
630afc61a8
|
@ -987,11 +987,43 @@ struct ShapeOfWithTensor : public OpRewritePattern<shape::ShapeOfOp> {
|
|||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
// Canonicalize
|
||||
// ```
|
||||
// %0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
|
||||
// %1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
|
||||
// ```
|
||||
// to
|
||||
// ```
|
||||
// %1 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
|
||||
// ```
|
||||
struct ShapeOfCastedExtentTensor : public OpRewritePattern<tensor::CastOp> {
|
||||
using OpRewritePattern<tensor::CastOp>::OpRewritePattern;
|
||||
|
||||
LogicalResult matchAndRewrite(tensor::CastOp op,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto ty = op.getType().dyn_cast<RankedTensorType>();
|
||||
if (!ty || ty.getRank() != 1)
|
||||
return failure();
|
||||
|
||||
auto shapeOfOp = op.source().getDefiningOp<ShapeOfOp>();
|
||||
if (!shapeOfOp)
|
||||
return failure();
|
||||
|
||||
// Argument type must be ranked and must not conflict.
|
||||
auto argTy = shapeOfOp.arg().getType().dyn_cast<RankedTensorType>();
|
||||
if (!argTy || (!ty.isDynamicDim(0) && ty.getDimSize(0) != argTy.getRank()))
|
||||
return failure();
|
||||
|
||||
rewriter.replaceOpWithNewOp<ShapeOfOp>(op, ty, shapeOfOp.arg());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void ShapeOfOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
|
||||
MLIRContext *context) {
|
||||
patterns.add<ShapeOfWithTensor>(context);
|
||||
patterns.add<ShapeOfCastedExtentTensor, ShapeOfWithTensor>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -648,7 +648,7 @@ func @f() {
|
|||
// CHECK: shape.cstr_broadcastable
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%cs0 = shape.const_shape [8, 1] : !shape.shape
|
||||
%cs0 = shape.const_shape [8, 1] : !shape.shape
|
||||
%cs1 = shape.const_shape [1, 8] : !shape.shape
|
||||
%cs2 = shape.const_shape [1, -1] : !shape.shape
|
||||
%0 = shape.cstr_broadcastable %cs0, %cs1, %cs2 : !shape.shape, !shape.shape, !shape.shape
|
||||
|
@ -1144,3 +1144,47 @@ func @broadcast_on_single_operand(%a : tensor<3xindex>) {
|
|||
"use"(%0) : (tensor<?xindex>) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @casted_extent_tensor
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<?xindex>
|
||||
func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<?xindex> {
|
||||
// CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<?xindex>
|
||||
// CHECK: return %[[RESULT]] : tensor<?xindex>
|
||||
%0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<3xindex>
|
||||
%1 = tensor.cast %0 : tensor<3xindex> to tensor<?xindex>
|
||||
return %1 : tensor<?xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @casted_extent_tensor
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<?x?x?xf32>) -> tensor<3xindex>
|
||||
func @casted_extent_tensor(%arg : tensor<?x?x?xf32>) -> tensor<3xindex> {
|
||||
// CHECK: %[[RESULT:.*]] = shape.shape_of %[[ARG]] : tensor<?x?x?xf32> -> tensor<3xindex>
|
||||
// CHECK: return %[[RESULT]] : tensor<3xindex>
|
||||
%0 = shape.shape_of %arg : tensor<?x?x?xf32> -> tensor<?xindex>
|
||||
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
|
||||
return %1 : tensor<3xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @casted_extent_tensor
|
||||
func @casted_extent_tensor(%arg : tensor<?x?x?x?xf32>) -> tensor<3xindex> {
|
||||
// CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
|
||||
%0 = shape.shape_of %arg : tensor<?x?x?x?xf32> -> tensor<?xindex>
|
||||
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
|
||||
return %1 : tensor<3xindex>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: @casted_extent_tensor
|
||||
func @casted_extent_tensor(%arg : tensor<*xf32>) -> tensor<3xindex> {
|
||||
// CHECK: tensor.cast %{{.*}} : tensor<?xindex> to tensor<3xindex>
|
||||
%0 = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
|
||||
%1 = tensor.cast %0 : tensor<?xindex> to tensor<3xindex>
|
||||
return %1 : tensor<3xindex>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue