[mlir] Introduce CallOp converter for buffer placement

Add BufferAssignmentCallOpConverter as a pattern rewriter for buffer
placement. It rewrites the operands and results of a caller operation to
match the signature of the callee after the callee has been rewritten by
FunctionAndBlockSignatureConverter.

Differential Revision: https://reviews.llvm.org/D80785
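To illustrate, a condensed before/after sketch of the caller rewrite
(distilled from the tests added in this change; the @callee name and the
5xf32 shapes are taken from those tests):

  // Before: the callee returns its result by value.
  %x = call @callee(%arg0) : (tensor<5xf32>) -> tensor<5xf32>

  // After: the tensor result becomes a freshly allocated output buffer that
  // is appended to the operand list; the allocation is placed at the position
  // computed by the buffer placer.
  %0 = alloc() : memref<5xf32>
  call @callee(%arg0, %0) : (memref<5xf32>, memref<5xf32>) -> ()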
@@ -157,6 +157,21 @@ public:
     return success();
   }
 };
+
+/// Converts `CallOp` to match its operands and results with the signature of
+/// the callee after rewriting the callee with
+/// FunctionAndBlockSignatureConverter.
+class BufferAssignmentCallOpConverter
+    : public BufferAssignmentOpConversionPattern<CallOp> {
+public:
+  using BufferAssignmentOpConversionPattern<
+      CallOp>::BufferAssignmentOpConversionPattern;
+
+  /// Performs the actual `CallOp` conversion step.
+  LogicalResult
+  matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
+                  ConversionPatternRewriter &rewriter) const final;
+};
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_BUFFERPLACEMENT_H
@@ -468,6 +468,57 @@ LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite(
   return success();
 }
+
+//===----------------------------------------------------------------------===//
+// BufferAssignmentCallOpConverter
+//===----------------------------------------------------------------------===//
+
+// Performs `CallOp` conversion to match its operands and results with the
+// signature of the callee after rewriting the callee with
+// FunctionAndBlockSignatureConverter.
+LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
+    CallOp callOp, ArrayRef<Value> operands,
+    ConversionPatternRewriter &rewriter) const {
+
+  Location loc = callOp.getLoc();
+  SmallVector<Value, 2> newOperands, replacingValues;
+  SmallVector<Type, 2> newResultTypes;
+  unsigned numResults = callOp.getNumResults();
+  newOperands.reserve(numResults + operands.size());
+  newOperands.append(operands.begin(), operands.end());
+  newResultTypes.reserve(numResults);
+  replacingValues.reserve(numResults);
+
+  // For each memref result of `CallOp` that was not a memref before type
+  // conversion, a new buffer is allocated and passed to the operand list of
+  // the new `CallOp`. Otherwise, it remains as a caller result.
+  for (Value result : callOp.getResults()) {
+    Type currType = result.getType();
+    Type newType = converter->convertType(result.getType());
+    if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
+      OpBuilder::InsertionGuard guard(rewriter);
+      rewriter.restoreInsertionPoint(
+          bufferAssignment->computeAllocPosition(result.dyn_cast<OpResult>()));
+      Value alloc =
+          rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
+      newOperands.push_back(alloc);
+      replacingValues.push_back(alloc);
+    } else {
+      newResultTypes.push_back(currType);
+
+      // No replacement is required.
+      replacingValues.push_back(nullptr);
+    }
+  }
+
+  // Create the new `CallOp`.
+  rewriter.create<CallOp>(loc, callOp.getCallee(), newResultTypes, newOperands);
+
+  // Replace the results of the old `CallOp`.
+  rewriter.replaceOp(callOp, replacingValues);
+
+  return success();
+}
 
 //===----------------------------------------------------------------------===//
 // BufferAssignmentTypeConverter
 //===----------------------------------------------------------------------===//
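The tests below exercise both paths of the conversion loop. As a condensed
sketch of the second path (shapes taken from the second test case below), a
result that was already a memref before type conversion is not turned into an
output buffer and stays a result of the call:

  // Before: one tensor result and one result that is already a memref.
  %x0, %x1 = call @callee(%arg0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)

  // After: the tensor result is replaced by an allocated output buffer,
  // while the pre-existing memref result remains a call result.
  %0 = alloc() : memref<5xf32>
  %r = call @callee(%arg0, %0) : (memref<5xf32>, memref<5xf32>) -> memref<2xf32>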
@@ -195,3 +195,92 @@ func @compute_allocs_position(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{
 // CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC6]]
 // CHECK: %[[ALLOC7:.*]] = alloc()
 // CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC6]], %[[ALLOC7]]
+
+// -----
+
+// Test case: Checking BufferAssignmentCallOpConverter,
+// FunctionAndBlockSignatureConverter, and BufferAssignmentReturnOpConverter
+// together. The signature of `callee` after signature conversion would be:
+
+// func @callee(%arg0: memref<5xf32>, %arg1: memref<5xf32>) -> ()
+
+// The operands and results of the caller and return operations must be
+// matched accordingly.
+
+#map0 = affine_map<(d0) -> (d0)>
+
+// CHECK-LABEL: func @callee
+func @callee(%arg1: tensor<5xf32>) -> tensor<5xf32> {
+  %0 = linalg.generic {
+    args_in = 1 : i64,
+    args_out = 1 : i64,
+    indexing_maps = [#map0, #map0],
+    iterator_types = ["parallel"]
+  } %arg1 {
+  ^bb0(%gen1_arg0: f32):
+    %tmp1 = exp %gen1_arg0 : f32
+    linalg.yield %tmp1 : f32
+  }: tensor<5xf32> -> tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>, %[[CALLEE_RESULT:.*]]: memref<5xf32>)
+// CHECK: %[[ALLOC:.*]] = alloc()
+// CHECK: linalg.generic
+// CHECK: linalg.copy(%[[ALLOC]], %[[CALLEE_RESULT]])
+// CHECK: return
+
+// CHECK-LABEL: func @caller
+func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
+  %x = call @callee(%arg0) : (tensor<5xf32>) -> tensor<5xf32>
+  %y = call @callee(%x) : (tensor<5xf32>) -> tensor<5xf32>
+  return %y : tensor<5xf32>
+}
+// CHECK: (%[[CALLER_ARG:.*]]: memref<5xf32>, %[[CALLER_RESULT:.*]]: memref<5xf32>)
+// CHECK: %[[FIRST_ALLOC:.*]] = alloc()
+// CHECK: call @callee(%[[CALLER_ARG]], %[[FIRST_ALLOC]])
+// CHECK: %[[SECOND_ALLOC:.*]] = alloc()
+// CHECK: call @callee(%[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
+// CHECK: linalg.copy(%[[SECOND_ALLOC]], %[[CALLER_RESULT]])
+// CHECK: return
+
+// -----
+
+// Test case: Checking BufferAssignmentCallOpConverter,
+// FunctionAndBlockSignatureConverter, and BufferAssignmentReturnOpConverter
+// together on functions that also have memref-typed results. The signature
+// of `callee` after signature conversion would be:
+
+// func @callee(%arg0: memref<5xf32>, %arg1: memref<5xf32>) -> memref<2xf32>
+
+// where %arg0 is the input and %arg1 is the output buffer, while the original
+// memref-typed result remains a function result. The rewriter then matches
+// the caller's signature with the callee's: two buffers are allocated in
+// place of %x0 and %y0 and passed to the caller's operand list as output
+// buffers, while %x1 and %y1 remain caller results.
+
+// CHECK-LABEL: func @callee
+func @callee(%arg1: tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>) {
+  %buff = alloc() : memref<2xf32>
+  return %arg1, %buff : tensor<5xf32>, memref<2xf32>
+}
+// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>, %[[CALLEE_RESULT:.*]]: memref<5xf32>)
+// CHECK-SAME: memref<2xf32>
+// CHECK: %[[ALLOC:.*]] = alloc()
+// CHECK: linalg.copy(%[[CALLEE_ARG]], %[[CALLEE_RESULT]])
+// CHECK: return %[[ALLOC]]
+
+// CHECK-LABEL: func @caller
+func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
+  %x0, %x1 = call @callee(%arg0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)
+  %y0, %y1 = call @callee(%x0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)
+  return %y0 : tensor<5xf32>
+}
+// CHECK: (%[[CALLER_ARG:.*]]: memref<5xf32>, %[[CALLER_RESULT:.*]]: memref<5xf32>)
+// CHECK: %[[X0:.*]] = alloc()
+// CHECK: %[[X1:.*]] = call @callee(%[[CALLER_ARG]], %[[X0]])
+// CHECK: %[[Y0:.*]] = alloc()
+// CHECK: %[[Y1:.*]] = call @callee(%[[X0]], %[[Y0]])
+// CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
+// CHECK: return
@@ -106,6 +106,7 @@ struct TestBufferPlacementPreparationPass
       TypeConverter *converter, OwningRewritePatternList *patterns) {
     // clang-format off
     patterns->insert<
+                     BufferAssignmentCallOpConverter,
                      FunctionAndBlockSignatureConverter,
                      GenericOpConverter,
                      BufferAssignmentReturnOpConverter<

@@ -137,6 +138,12 @@ struct TestBufferPlacementPreparationPass
       return llvm::none_of(returnOp.getOperandTypes(), isIllegalType);
     });
+
+    // Mark the standard CallOp illegal as long as it operates on tensors.
+    target.addDynamicallyLegalOp<mlir::CallOp>([&](mlir::CallOp callOp) {
+      return llvm::none_of(callOp.getOperandTypes(), isIllegalType) &&
+             llvm::none_of(callOp.getResultTypes(), isIllegalType);
+    });
 
     // Mark functions whose arguments are tensor-typed as illegal.
     target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
       return converter.isSignatureLegal(funcOp.getType());