forked from OSchip/llvm-project
[MLIR][Shape] Limit shape to SCF lowering patterns to their supported types
Differential Revision: https://reviews.llvm.org/D84444
This commit is contained in:
parent
1aaf8aa53d
commit
5fc34fafa7
|
@ -121,7 +121,7 @@ LogicalResult
|
|||
ReduceOpConverter::matchAndRewrite(shape::ReduceOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// For now, this lowering is only defined on `tensor<?xindex>` operands.
|
||||
if (!op.shape().getType().isa<RankedTensorType>())
|
||||
if (op.shape().getType().isa<ShapeType>())
|
||||
return failure();
|
||||
|
||||
auto loc = op.getLoc();
|
||||
|
@ -171,12 +171,15 @@ public:
|
|||
LogicalResult
|
||||
ShapeOfOpConverter::matchAndRewrite(ShapeOfOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
ShapeOfOp::Adaptor transformed(operands);
|
||||
Value arg = transformed.arg();
|
||||
Type argTy = arg.getType();
|
||||
// For now, this lowering supports only error-free arguments.
|
||||
if (op.getType().isa<ShapeType>())
|
||||
return failure();
|
||||
|
||||
// For ranked tensors `shape_of` lowers to `std` and the pattern can be
|
||||
// found in the corresponding pass.
|
||||
ShapeOfOp::Adaptor transformed(operands);
|
||||
Value arg = transformed.arg();
|
||||
Type argTy = arg.getType();
|
||||
if (argTy.isa<RankedTensorType>())
|
||||
return failure();
|
||||
|
||||
|
|
|
@ -24,21 +24,32 @@ func @shape_reduce(%shape : tensor<?xindex>) -> index {
|
|||
|
||||
// -----
|
||||
|
||||
// Don't lower `shape_of` for result type of `shape.shape`.
|
||||
// CHECK-LABEL: @shape_of
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
|
||||
func @shape_of(%arg : tensor<*xf32>) {
|
||||
// CHECK: shape.shape
|
||||
%shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Lower `shape_of` for unranked tensors.
|
||||
// CHECK-LABEL: @shape_of_unranked
|
||||
// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>)
|
||||
func @shape_of_unranked(%arg : tensor<*xf32>) {
|
||||
// CHECK-DAG: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
|
||||
// CHECK-DAG: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
|
||||
// CHECK-DAG: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK-DAG: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
|
||||
// CHECK-DAG: %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
|
||||
// CHECK-DAG: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
|
||||
// CHECK-DAG: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
|
||||
// CHECK: }
|
||||
// CHECK-DAG: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
|
||||
// CHECK-DAG: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
|
||||
// CHECK: %[[RANK:.*]] = rank %[[ARG]] : tensor<*xf32>
|
||||
// CHECK: %[[SHAPE_MEM:.*]] = alloca(%[[RANK]]) : memref<?xi64>
|
||||
// CHECK: %[[C0:.*]] = constant 0 : index
|
||||
// CHECK: %[[C1:.*]] = constant 1 : index
|
||||
// CHECK: scf.for %[[I:.*]] = %[[C0]] to %[[RANK]] step %[[C1]] {
|
||||
// CHECK: %[[DIM:.]] = dim %[[ARG]], %[[I]] : tensor<*xf32>
|
||||
// CHECK: %[[DIM_INT:.*]] = index_cast %[[DIM]] : index to i64
|
||||
// CHECK: store %[[DIM_INT]], %[[SHAPE_MEM]][%[[I]]] : memref<?xi64>
|
||||
// CHECK: }
|
||||
// CHECK: %[[SHAPE_INT:.*]] = tensor_load %[[SHAPE_MEM]] : memref<?xi64>
|
||||
// CHECK: %[[SHAPE:.*]] = index_cast %[[SHAPE_INT]] : tensor<?xi64> to tensor<?xindex>
|
||||
%shape = shape.shape_of %arg : tensor<*xf32> -> tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue