forked from OSchip/llvm-project
[mlir] Support unranked types in func signature conversion in BufferPlacement.
Currently, only ranked tensor args and results can be converted to memref types. Differential Revision: https://reviews.llvm.org/D83324
This commit is contained in:
parent
1143f09678
commit
1a2ed71a8a
|
@ -700,15 +700,19 @@ BufferAssignmentPlacer::computeAllocPosition(OpResult result) {
|
|||
BufferAssignmentTypeConverter::BufferAssignmentTypeConverter() {
|
||||
// Keep all types unchanged.
|
||||
addConversion([](Type type) { return type; });
|
||||
// A type conversion that converts ranked-tensor type to memref type.
|
||||
// Convert RankedTensorType to MemRefType.
|
||||
addConversion([](RankedTensorType type) {
|
||||
return (Type)MemRefType::get(type.getShape(), type.getElementType());
|
||||
});
|
||||
// Convert UnrankedTensorType to UnrankedMemRefType.
|
||||
addConversion([](UnrankedTensorType type) {
|
||||
return (Type)UnrankedMemRefType::get(type.getElementType(), 0);
|
||||
});
|
||||
}
|
||||
|
||||
/// Checks if `type` has been converted from non-memref type to memref.
|
||||
bool BufferAssignmentTypeConverter::isConvertedMemref(Type type, Type before) {
|
||||
return type.isa<MemRefType>() && !before.isa<MemRefType>();
|
||||
return type.isa<BaseMemRefType>() && !before.isa<BaseMemRefType>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -64,6 +64,15 @@ func @simple_signature_conversion(%arg0: tensor<4x8xf32>) -> tensor<4x8xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @func_with_unranked_arg_and_result
|
||||
func @func_with_unranked_arg_and_result(%arg0: tensor<*xf32>) -> tensor<*xf32> {
|
||||
return %arg0 : tensor<*xf32>
|
||||
}
|
||||
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32>
|
||||
// CHECK-NEXT: return [[ARG]] : memref<*xf32>
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @func_and_block_signature_conversion
|
||||
func @func_and_block_signature_conversion(%arg0 : tensor<2xf32>, %cond : i1, %arg1: tensor<4x4xf32>) -> tensor<4x4xf32>{
|
||||
cond_br %cond, ^bb1, ^bb2
|
||||
|
|
|
@ -284,3 +284,9 @@ func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
|
|||
// CHECK: %[[Y1:.*]] = call @callee(%[[X0]], %[[Y0]])
|
||||
// CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
|
||||
// CHECK: return
|
||||
|
||||
// CHECK-LABEL: func @func_with_unranked_arg
|
||||
func @func_with_unranked_arg(%arg0: tensor<*xf32>) {
|
||||
return
|
||||
}
|
||||
// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>)
|
||||
|
|
Loading…
Reference in New Issue