[MLIR][Shape] Limit shape to SCF lowering patterns to their supported types

Differential Revision: https://reviews.llvm.org/D84444
Frederik Gossen 2020-07-29 14:52:27 +00:00
parent 1aaf8aa53d
commit 5fc34fafa7
2 changed files with 29 additions and 15 deletions
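
In effect, the two lowering patterns below now return failure() unless the op uses the types they can actually handle, leaving `shape_of` and `shape.reduce` on the error-carrying `!shape.shape` type untouched for other lowerings. A minimal illustration of the two `shape_of` forms involved, distilled from the tests in this revision (the function name is illustrative):

    func @shape_of_forms(%arg : tensor<*xf32>) {
      // Supported by the SCF lowering: the error-free extent-tensor result.
      %extents = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
      // Not supported: `!shape.shape` can carry an error value, so the
      // patterns now bail out on this form instead of matching it.
      %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
      return
    }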


@@ -121,7 +121,7 @@ LogicalResult
 ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
   // For now, this lowering is only defined on `tensor<?xindex>` operands.
-  if (!op.shape().getType().isa<RankedTensorType>())
+  if (op.shape().getType().isa<ShapeType>())
     return failure();
 
   auto loc = op.getLoc();
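
The guard is now phrased in terms of the unsupported type: rather than requiring the operand to be a ranked tensor, the pattern rejects `!shape.shape` operands directly, which states the intent (extent tensors are fine, shapes are not). A sketch of a reduction that still lowers, with the `shape.reduce` region syntax approximated from the shape dialect of this era (treat the exact assembly format as an assumption):

    func @num_elements(%shape : tensor<?xindex>) -> index {
      %init = constant 1 : index
      // The operand is an error-free `tensor<?xindex>`, so the pattern applies.
      %num = shape.reduce(%shape, %init) : tensor<?xindex> -> index {
        ^bb0(%i : index, %extent : index, %acc : index):
          %next = muli %acc, %extent : index
          shape.yield %next : index
      }
      return %num : index
    }
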
@@ -171,12 +171,15 @@ public:
 LogicalResult
 ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
                                    ConversionPatternRewriter &rewriter) const {
-  ShapeOfOp::Adaptor transformed(operands);
-  Value arg = transformed.arg();
-  Type argTy = arg.getType();
+  // For now, this lowering supports only error-free arguments.
+  if (op.getType().isa<ShapeType>())
+    return failure();
 
   // For ranked tensors `shape_of` lowers to `std` and the pattern can be
   // found in the corresponding pass.
+  ShapeOfOp::Adaptor transformed(operands);
+  Value arg = transformed.arg();
+  Type argTy = arg.getType();
   if (argTy.isa<RankedTensorType>())
     return failure();
 
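
Besides the new `!shape.shape` early-out, the pattern keeps its existing bail-out for ranked operands, whose `shape_of` is handled by the std-oriented patterns mentioned in the comment above. A hypothetical ranked case this SCF pattern deliberately skips:

    func @shape_of_ranked(%arg : tensor<2x?xf32>) {
      // Left to the shape-to-std lowering: with a static rank there is no
      // need for an scf.for loop over the dimensions.
      %shape = shape.shape_of %arg : tensor<2x?xf32> -> tensor<?xindex>
      return
    }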


@@ -24,21 +24,32 @@ func @shape_reduce(%shape : tensor<?xindex>) -> index {
 
 // -----
 
+// Don't lower `shape_of` for result type of `shape.shape`.
+// CHECK-LABEL: @shape_of
+// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
+func @shape_of(%arg : tensor<*xf32>) {
+  // CHECK: shape.shape
+  %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
+  return
+}
+
+// -----
+
 // Lower `shape_of` for unranked tensors.
 // CHECK-LABEL: @shape_of_unranked
 // CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
 func @shape_of_unranked(%arg : tensor<*xf32>) {
-  // CHECK-DAG: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
-  // CHECK-DAG: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
-  // CHECK-DAG: %[[C0:.*]] = constant 0 : index
-  // CHECK-DAG: %[[C1:.*]] = constant 1 : index
-  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
-  // CHECK-DAG: %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
-  // CHECK-DAG: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
-  // CHECK-DAG: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
-  // CHECK: }
-  // CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
-  // CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
+  // CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
+  // CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
+  // CHECK: %[[C0:.*]] = constant 0 : index
+  // CHECK: %[[C1:.*]] = constant 1 : index
+  // CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
+  // CHECK: %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
+  // CHECK: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
+  // CHECK: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
+  // CHECK: }
+  // CHECK: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
+  // CHECK: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
   %shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
   return
 }
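
For reference, the `// -----` separators imply the file is run with -split-input-file; a RUN line along these lines would drive it (the exact pass flag is an assumption here, not part of this diff):

    // RUN: mlir-opt -convert-shape-to-scf -split-input-file %s | FileCheck %s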