[mlir] Introduce CallOp converter for buffer placement

Add BufferAssignmentCallOpConverter as a pattern rewriter for Buffer
Placement. It matches the signature of the caller operation with the callee
after rewriting the callee with FunctionAndBlockSignatureConverter.

Differential Revision: https://reviews.llvm.org/D80785
This commit is contained in:
Ehsan Toosi 2020-05-26 15:03:45 +02:00
parent 3c626c714c
commit 3f6a35e3ff
4 changed files with 162 additions and 0 deletions

View File

@ -157,6 +157,21 @@ public:
return success();
}
};
/// Converts a standard `CallOp` to match its operands and results with the
/// signature of the callee after the callee has been rewritten with
/// FunctionAndBlockSignatureConverter (i.e. results converted to memrefs
/// become output-buffer operands of the call).
class BufferAssignmentCallOpConverter
    : public BufferAssignmentOpConversionPattern<CallOp> {
public:
  using BufferAssignmentOpConversionPattern<
      CallOp>::BufferAssignmentOpConversionPattern;

  /// Performs the actual `CallOp` conversion step.
  LogicalResult
  matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
                  ConversionPatternRewriter &rewriter) const final;
};
} // end namespace mlir
#endif // MLIR_TRANSFORMS_BUFFERPLACEMENT_H

View File

@ -468,6 +468,57 @@ LogicalResult FunctionAndBlockSignatureConverter::matchAndRewrite(
return success();
}
//===----------------------------------------------------------------------===//
// BufferAssignmentCallOpConverter
//===----------------------------------------------------------------------===//
// Performs `CallOp` conversion to match its operands and results with the
// signature of the callee after rewriting the callee with
// FunctionAndBlockSignatureConverter.
LogicalResult BufferAssignmentCallOpConverter::matchAndRewrite(
    CallOp callOp, ArrayRef<Value> operands,
    ConversionPatternRewriter &rewriter) const {
  Location loc = callOp.getLoc();
  SmallVector<Value, 2> newOperands, replacingValues;
  SmallVector<Type, 2> newResultTypes;
  unsigned numResults = callOp.getNumResults();
  newOperands.reserve(numResults + operands.size());
  newOperands.append(operands.begin(), operands.end());
  newResultTypes.reserve(numResults);
  replacingValues.reserve(numResults);

  // For each result of `CallOp` that becomes a memref under type conversion,
  // allocate a buffer (at the position computed by the buffer-placement
  // analysis) and append it to the operand list of the new `CallOp`; the
  // allocation also replaces the old result. Results whose type is unchanged
  // remain results of the new call; their replacement entry stays null here
  // and is filled in below once the new call exists.
  for (Value result : callOp.getResults()) {
    Type currType = result.getType();
    Type newType = converter->convertType(result.getType());
    if (BufferAssignmentTypeConverter::isConvertedMemref(newType, currType)) {
      OpBuilder::InsertionGuard guard(rewriter);
      rewriter.restoreInsertionPoint(
          bufferAssignment->computeAllocPosition(result.dyn_cast<OpResult>()));
      Value alloc =
          rewriter.create<AllocOp>(loc, newType.dyn_cast<MemRefType>());
      newOperands.push_back(alloc);
      replacingValues.push_back(alloc);
    } else {
      newResultTypes.push_back(currType);
      // Placeholder; patched with the matching new call result below.
      replacingValues.push_back(nullptr);
    }
  }

  // Create the new `CallOp` with the rewritten operands and the result types
  // that remain on the call.
  CallOp newCallOp = rewriter.create<CallOp>(loc, callOp.getCallee(),
                                             newResultTypes, newOperands);

  // Wire the results that remained on the call through to the replacement
  // list. Leaving these entries null would drop every use of an unconverted
  // result of the original `CallOp`, making such uses illegal after the
  // conversion even though the new call still produces the value.
  unsigned nextNewResult = 0;
  for (Value &value : replacingValues)
    if (!value)
      value = newCallOp.getResult(nextNewResult++);

  // Replace the results of the old `CallOp` with the gathered values.
  rewriter.replaceOp(callOp, replacingValues);
  return success();
}
//===----------------------------------------------------------------------===//
// BufferAssignmentTypeConverter
//===----------------------------------------------------------------------===//

View File

@ -195,3 +195,92 @@ func @compute_allocs_position(%cond: i1, %arg0: tensor<2xf32>) -> tensor<2xf32>{
// CHECK-NEXT: linalg.generic {{.*}} %[[ARG0]], %[[ALLOC6]]
// CHECK: %[[ALLOC7:.*]] = alloc()
// CHECK-NEXT: linalg.generic {{.*}} %[[ALLOC6]], %[[ALLOC7]]
// -----
// Test case: Checking BufferAssignmentCallOpConverter and
// FunctionAndBlockSignatureConverter and BufferAssignmentReturnOpConverter all
// together. The signature of `callee` after signature conversion would be:
// func @callee(%arg0: memref<5xf32>, %arg1: memref<5xf32>) -> ()
// The operands and results of caller and return operations must be matched
// respectively.
#map0 = affine_map<(d0) -> (d0)>
// CHECK-LABEL: func @callee
// Applies `exp` element-wise to a 5-element tensor; buffer placement should
// rewrite the tensor result into an extra output-buffer argument (checked by
// the CHECK lines following this function).
func @callee(%arg1: tensor<5xf32>) -> tensor<5xf32> {
  %0 = linalg.generic {
    args_in = 1 : i64,
    args_out = 1 : i64,
    indexing_maps = [#map0, #map0],
    iterator_types = ["parallel"]
  } %arg1 {
  ^bb0(%gen1_arg0: f32):
    %tmp1 = exp %gen1_arg0 : f32
    linalg.yield %tmp1 : f32
  }: tensor<5xf32> -> tensor<5xf32>
  return %0 : tensor<5xf32>
}
// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>, %[[CALLEE_RESULT:.*]]: memref<5xf32>)
// CHECK: %[[ALLOC:.*]] = alloc()
// CHECK: linalg.generic
// CHECK: linalg.copy(%[[ALLOC]], %[[CALLEE_RESULT]])
// CHECK: return
// CHECK-LABEL: func @caller
// Chains two calls to @callee; after conversion each call must receive a
// freshly allocated buffer as its trailing (output) operand, and the first
// call's buffer feeds the second call.
func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
  %x = call @callee(%arg0) : (tensor<5xf32>) -> tensor<5xf32>
  %y = call @callee(%x) : (tensor<5xf32>) -> tensor<5xf32>
  return %y : tensor<5xf32>
}
// CHECK: (%[[CALLER_ARG:.*]]: memref<5xf32>, %[[CALLER_RESULT:.*]]: memref<5xf32>)
// CHECK: %[[FIRST_ALLOC:.*]] = alloc()
// CHECK: call @callee(%[[CALLER_ARG]], %[[FIRST_ALLOC]])
// CHECK: %[[SECOND_ALLOC:.*]] = alloc()
// CHECK: call @callee(%[[FIRST_ALLOC]], %[[SECOND_ALLOC]])
// CHECK: linalg.copy(%[[SECOND_ALLOC]], %[[CALLER_RESULT]])
// CHECK: return
// -----
// Test case: Checking BufferAssignmentCallOpConverter and
// FunctionAndBlockSignatureConverter and BufferAssignmentReturnOpConverter all
// together on functions that also have memref typed results. The signature of
// `callee` after signature conversion would be:
// func @callee(%arg0: memref<5xf32>, %arg1: memref<5xf32>) -> memref<2xf32>
// where %arg0 is the input and %arg1 is the output buffer and the original memref
// type result remain as the function result. Then, the rewriter should match the
// caller's signature with the callee. Thus, two buffers will be allocated instead
// of %x0 and %y0 and they are passed to the callers' operands list as the output
// buffers. %x1 and %y1 remain as callers' results.
// CHECK-LABEL: func @callee
// Returns a tensor (rewritten to an output-buffer argument) alongside a value
// that is already a memref and therefore stays a plain function result.
func @callee(%arg1: tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>) {
  %buff = alloc() : memref<2xf32>
  return %arg1, %buff : tensor<5xf32>, memref<2xf32>
}
// CHECK: (%[[CALLEE_ARG:.*]]: memref<5xf32>, %[[CALLEE_RESULT:.*]]: memref<5xf32>)
// CHECK-SAME: memref<2xf32>
// CHECK: %[[ALLOC:.*]] = alloc()
// CHECK: linalg.copy(%[[CALLEE_ARG]], %[[CALLEE_RESULT]])
// CHECK: return %[[ALLOC]]
// CHECK-LABEL: func @caller
// Calls a callee with mixed results: the tensor results (%x0, %y0) should be
// rewritten into output buffers, while the memref results (%x1, %y1) must
// remain results of the converted calls.
func @caller(%arg0: tensor<5xf32>) -> tensor<5xf32> {
  %x0, %x1 = call @callee(%arg0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)
  %y0, %y1 = call @callee(%x0) : (tensor<5xf32>) -> (tensor<5xf32>, memref<2xf32>)
  return %y0 : tensor<5xf32>
}
// CHECK: (%[[CALLER_ARG:.*]]: memref<5xf32>, %[[CALLER_RESULT:.*]]: memref<5xf32>)
// CHECK: %[[X0:.*]] = alloc()
// CHECK: %[[X1:.*]] = call @callee(%[[CALLER_ARG]], %[[X0]])
// CHECK: %[[Y0:.*]] = alloc()
// CHECK: %[[Y1:.*]] = call @callee(%[[X0]], %[[Y0]])
// CHECK: linalg.copy(%[[Y0]], %[[CALLER_RESULT]])
// CHECK: return

View File

@ -106,6 +106,7 @@ struct TestBufferPlacementPreparationPass
TypeConverter *converter, OwningRewritePatternList *patterns) {
// clang-format off
patterns->insert<
BufferAssignmentCallOpConverter,
FunctionAndBlockSignatureConverter,
GenericOpConverter,
BufferAssignmentReturnOpConverter<
@ -137,6 +138,12 @@ struct TestBufferPlacementPreparationPass
return llvm::none_of(returnOp.getOperandTypes(), isIllegalType);
});
// Mark Standard Call Operation illegal as long as it operates on tensor.
target.addDynamicallyLegalOp<mlir::CallOp>([&](mlir::CallOp callOp) {
return llvm::none_of(callOp.getOperandTypes(), isIllegalType) &&
llvm::none_of(callOp.getResultTypes(), isIllegalType);
});
// Mark the function whose arguments are in tensor-type illegal.
target.addDynamicallyLegalOp<FuncOp>([&](FuncOp funcOp) {
return converter.isSignatureLegal(funcOp.getType());