[mlir][bufferize] Add argument materialization for bufferization

This enables partial bufferization that includes function signatures. To test this, this
change also makes the func-bufferize pass partial and adds a dedicated finalizing-bufferize pass.

Differential Revision: https://reviews.llvm.org/D92032
Stephan Herhut 2020-11-26 13:26:08 +01:00
parent 1ca174b642
commit 4dd5f79f07
9 changed files with 254 additions and 40 deletions


@@ -26,6 +26,13 @@ void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
MLIRContext *ctx,
TypeConverter &converter);
/// Add a pattern to the given pattern list to rewrite branch operations and
/// `return` to use operands that have been legalized by the conversion
/// framework. This can only be done if the branch operation implements the
/// BranchOpInterface. Only needed for partial conversions.
void populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx,
TypeConverter &converter);
} // end namespace mlir
#endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_
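To make the intent of this new entry point concrete, here is a minimal sketch of the kind of IR it targets (an assumed example, not part of this header): during a partial bufferization the successor block arguments switch to memref types, and the pattern rewrites the branch operands to the values the converter has already legalized so that they line up again.

func @example(%arg0: tensor<f32>) -> tensor<f32> {
  br ^bb1(%arg0 : tensor<f32>)
^bb1(%0: tensor<f32>):
  return %0 : tensor<f32>
}

Once %arg0 has been legalized to a memref value, the pattern updates the terminator to pass that value, e.g. br ^bb1(%m : memref<f32>), so it matches the converted block argument of ^bb1; the new func-bufferize test further down shows the full output.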


@@ -25,28 +25,26 @@ def StdExpandOps : FunctionPass<"std-expand"> {
def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
let summary = "Bufferize func/call/return ops";
let description = [{
A finalizing bufferize pass that bufferizes std.func and std.call ops.
A bufferize pass that bufferizes std.func and std.call ops.
Because this pass updates std.func ops, it must be a module pass. It is
useful to keep this pass separate from other bufferizations so that the
other ones can be run at function-level in parallel.
This pass must be done atomically for two reasons:
1. This pass changes func op signatures, which requires atomically updating
calls as well throughout the entire module.
2. This pass changes the type of block arguments, which requires that all
successor arguments of predecessors be converted. Terminators are not
a closed universe (and need not implement BranchOpInterface), and so we
cannot in general rewrite them.
This pass must be done atomically because it changes func op signatures,
which requires atomically updating calls as well throughout the entire
module.
Note, because this is a "finalizing" bufferize step, it can create
invalid IR because it will not create materializations. To avoid this
situation, the pass must only be run when the only SSA values of
tensor type are:
- block arguments
- the result of tensor_load
Other values of tensor type should be eliminated by earlier
bufferization passes.
This pass also changes the type of block arguments, which requires that all
successor arguments of predecessors be converted. This is achieved by
rewriting terminators based on the information provided by the
`BranchOpInterface`.
As this pass rewrites function operations, it also rewrites the
corresponding return operations. Other return-like operations that
implement the `ReturnLike` trait are not rewritten in general, as they
require that the corresponding parent operation is also rewritten.
Finally, this pass fails for unknown terminators, as we cannot decide
whether they need rewriting.
}];
let constructor = "mlir::createFuncBufferizePass()";
}
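To illustrate the updated contract, here is a rough before/after sketch (assumed IR in the spirit of the new tests, not copied from them): the pass converts the signature and leaves materializations behind rather than requiring the body to be bufferized already.

func @identity(%arg0: tensor<f32>) -> tensor<f32> {
  return %arg0 : tensor<f32>
}

// After -func-bufferize alone, the signature is converted and a
// tensor_load / tensor_to_memref pair bridges back to the tensor world:
func @identity(%arg0: memref<f32>) -> memref<f32> {
  %0 = tensor_load %arg0 : memref<f32>
  %1 = tensor_to_memref %0 : memref<f32>
  return %1 : memref<f32>
}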


@@ -46,6 +46,10 @@ std::unique_ptr<Pass>
createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024,
unsigned bitwidthOfIndexType = 64);
/// Creates a pass that finalizes a partial bufferization by removing remaining
/// tensor_load and tensor_to_memref operations.
std::unique_ptr<FunctionPass> createFinalizingBufferizePass();
/// Creates a pass that converts memref function results to out-params.
std::unique_ptr<Pass> createBufferResultsToOutParamsPass();


@@ -290,6 +290,22 @@ def Inliner : Pass<"inline"> {
];
}
def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> {
let summary = "Finalize a partial bufferization";
let description = [{
A bufferize pass that finalizes a partial bufferization by removing
remaining `tensor_load` and `tensor_to_memref` operations.
The removal of those operations is only possible if the operations only
exist in pairs, i.e., all uses of `tensor_load` operations are
`tensor_to_memref` operations.
This pass will fail if not all operations can be removed or if any operation
with tensor typed operands remains.
}];
let constructor = "mlir::createFinalizingBufferizePass()";
}
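As a sketch of the pairwise removal described above (assumed IR, matching the leftovers that -func-bufferize now produces), the pass folds a tensor_load whose only uses are tensor_to_memref ops back to the underlying memref:

// Input with leftover materializations:
func @identity(%arg0: memref<f32>) -> memref<f32> {
  %0 = tensor_load %arg0 : memref<f32>
  %1 = tensor_to_memref %0 : memref<f32>
  return %1 : memref<f32>
}

// After -finalizing-bufferize:
func @identity(%arg0: memref<f32>) -> memref<f32> {
  return %arg0 : memref<f32>
}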
def LocationSnapshot : Pass<"snapshot-op-locations"> {
let summary = "Generate new locations from the current IR";
let description = [{


@@ -21,6 +21,8 @@ using namespace mlir;
namespace {
struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
using FuncBufferizeBase<FuncBufferizePass>::FuncBufferizeBase;
void runOnOperation() override {
auto module = getOperation();
auto *context = &getContext();
@@ -35,14 +37,42 @@ struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
typeConverter.isLegal(&op.getBody());
});
populateCallOpTypeConversionPattern(patterns, context, typeConverter);
populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
patterns);
target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
target.addDynamicallyLegalOp<CallOp>(
[&](CallOp op) { return typeConverter.isLegal(op); });
// If all result types are legal, and all block arguments are legal (ensured
// by func conversion above), then all types in the program are legal.
populateBranchOpInterfaceAndReturnOpTypeConversionPattern(patterns, context,
typeConverter);
target.addLegalOp<ModuleOp, ModuleTerminatorOp, TensorLoadOp,
TensorToMemrefOp>();
target.addDynamicallyLegalOp<ReturnOp>(
[&](ReturnOp op) { return typeConverter.isLegal(op); });
// Mark terminators as legal if they have the ReturnLike trait or
// implement the BranchOpInterface and have valid types. If they do not
// implement the trait or interface, mark them as illegal no matter what.
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
return typeConverter.isLegal(op->getResultTypes());
// If it is not a terminator, ignore it.
if (op->isKnownNonTerminator())
return true;
// If it is not the last operation in the block, also ignore it. We do
// this to handle unknown operations, as well.
Block *block = op->getBlock();
if (!block || &block->back() != op)
return true;
// ReturnLike operations have to be legalized with their parent. For
// return this is handled; for other ops they remain as is.
if (op->hasTrait<OpTrait::ReturnLike>())
return true;
// All successor operands of branch like operations must be rewritten.
if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
auto successorOperands = branchOp.getSuccessorOperands(p);
if (successorOperands.hasValue() &&
!typeConverter.isLegal(successorOperands.getValue().getTypes()))
return false;
}
return true;
}
return false;
});
if (failed(applyFullConversion(module, target, std::move(patterns))))
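The effect of this legality callback can be seen on a small example (assumed IR, mirroring the negative test added further down): br and return with legal operands are accepted, while an unknown terminator that is neither return-like nor a branch cannot be checked and makes the full conversion fail.

func @bad(%arg0: tensor<f32>) {
  br ^bb1(%arg0 : tensor<f32>)
^bb1(%0: tensor<f32>):
  // Neither ReturnLike nor BranchOpInterface, so the callback above marks it
  // illegal and the pass reports
  // "failed to legalize operation 'test.terminator'".
  "test.terminator"() : () -> ()
}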


@@ -13,21 +13,19 @@
using namespace mlir;
namespace {
// Converts the operand and result types of the Standard's CallOp, used together
// with the FuncOpSignatureConversion.
/// Converts the operand and result types of the Standard's CallOp, used
/// together with the FuncOpSignatureConversion.
struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
CallOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
: OpConversionPattern(ctx), converter(converter) {}
using OpConversionPattern<CallOp>::OpConversionPattern;
/// Hook for derived classes to implement combined matching and rewriting.
LogicalResult
matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
FunctionType type = callOp.getCalleeType();
// Convert the original function results.
SmallVector<Type, 1> convertedResults;
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
convertedResults)))
return failure();
// Substitute with the new result types from the corresponding FuncType
@@ -36,14 +34,77 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
convertedResults, operands);
return success();
}
/// The type converter to use when rewriting the signature.
TypeConverter &converter;
};
} // end anonymous namespace
void mlir::populateCallOpTypeConversionPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx,
TypeConverter &converter) {
patterns.insert<CallOpSignatureConversion>(ctx, converter);
patterns.insert<CallOpSignatureConversion>(converter, ctx);
}
namespace {
/// Only needed to support partial conversion of functions where this pattern
/// ensures that the branch operation arguments match up with the successor
/// block arguments.
class BranchOpInterfaceTypeConversion : public ConversionPattern {
public:
BranchOpInterfaceTypeConversion(TypeConverter &typeConverter,
MLIRContext *ctx)
: ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
LogicalResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
auto branchOp = dyn_cast<BranchOpInterface>(op);
if (!branchOp)
return failure();
// For a branch operation, only some operands go to the target blocks, so
// only rewrite those.
SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
succIdx < succEnd; ++succIdx) {
auto successorOperands = branchOp.getSuccessorOperands(succIdx);
if (!successorOperands)
continue;
for (int idx = successorOperands->getBeginOperandIndex(),
eidx = idx + successorOperands->size();
idx < eidx; ++idx) {
newOperands[idx] = operands[idx];
}
}
rewriter.updateRootInPlace(
op, [newOperands, op]() { op->setOperands(newOperands); });
return success();
}
};
} // end anonymous namespace
namespace {
/// Only needed to support partial conversion of functions where this pattern
/// ensures that the `return` operation operands match up with the converted
/// result types of the enclosing function.
class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
public:
using OpConversionPattern<ReturnOp>::OpConversionPattern;
LogicalResult
matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const final {
// For a return, all operands go to the results of the parent, so
// rewrite them all.
Operation *operation = op.getOperation();
rewriter.updateRootInPlace(
op, [operands, operation]() { operation->setOperands(operands); });
return success();
}
};
} // end anonymous namespace
void mlir::populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
OwningRewritePatternList &patterns, MLIRContext *ctx,
TypeConverter &typeConverter) {
patterns.insert<BranchOpInterfaceTypeConversion, ReturnOpTypeConversion>(
typeConverter, ctx);
}


@@ -7,7 +7,9 @@
//===----------------------------------------------------------------------===//
#include "mlir/Transforms/Bufferize.h"
#include "PassDetail.h"
#include "mlir/IR/Operation.h"
#include "mlir/Transforms/Passes.h"
using namespace mlir;
@@ -15,6 +17,13 @@ using namespace mlir;
// BufferizeTypeConverter
//===----------------------------------------------------------------------===//
static Value materializeTensorLoad(OpBuilder &builder, TensorType type,
ValueRange inputs, Location loc) {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<BaseMemRefType>());
return builder.create<TensorLoadOp>(loc, type, inputs[0]);
}
/// Registers conversions into BufferizeTypeConverter
BufferizeTypeConverter::BufferizeTypeConverter() {
// Keep all types unchanged.
@@ -27,12 +36,8 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
addConversion([](UnrankedTensorType type) -> Type {
return UnrankedMemRefType::get(type.getElementType(), 0);
});
addSourceMaterialization([](OpBuilder &builder, TensorType type,
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<BaseMemRefType>());
return builder.create<TensorLoadOp>(loc, type, inputs[0]);
});
addArgumentMaterialization(materializeTensorLoad);
addSourceMaterialization(materializeTensorLoad);
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1);
@@ -83,3 +88,37 @@ void mlir::populateEliminateBufferizeMaterializationsPatterns(
patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
typeConverter, context);
}
namespace {
struct FinalizingBufferizePass
: public FinalizingBufferizeBase<FinalizingBufferizePass> {
using FinalizingBufferizeBase<
FinalizingBufferizePass>::FinalizingBufferizeBase;
void runOnFunction() override {
auto func = getFunction();
auto *context = &getContext();
BufferizeTypeConverter typeConverter;
OwningRewritePatternList patterns;
ConversionTarget target(*context);
populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
patterns);
target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
// If all result types are legal, and all block arguments are legal (ensured
// by the earlier func bufferization), then all types in the program are legal.
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
return typeConverter.isLegal(op->getResultTypes());
});
if (failed(applyFullConversion(func, target, std::move(patterns))))
signalPassFailure();
}
};
} // namespace
std::unique_ptr<FunctionPass> mlir::createFinalizingBufferizePass() {
return std::make_unique<FinalizingBufferizePass>();
}
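For context on the addArgumentMaterialization hook registered earlier in this file: the conversion framework invokes it when a block argument has been rewritten to a memref but a not-yet-converted use still expects a tensor, and the converter then rebuilds the tensor value with a tensor_load. A minimal sketch (the "some.use" op is hypothetical):

func @sketch(%arg0: memref<f32>) {
  // %arg0 used to be tensor<f32>; the inserted argument materialization
  // reconstructs the tensor for the remaining tensor-typed use.
  %0 = tensor_load %arg0 : memref<f32>
  "some.use"(%0) : (tensor<f32>) -> ()
  return
}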


@@ -0,0 +1,59 @@
// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics --debug-only=dialect-conversion | FileCheck %s
// CHECK-LABEL: func @block_arguments(
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
// CHECK: %[[T1:.*]] = tensor_load %[[ARG]] : memref<f32>
// CHECK: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<f32>
// CHECK: br ^bb1(%[[M1]] : memref<f32>)
// CHECK: ^bb1(%[[BBARG:.*]]: memref<f32>):
// CHECK: %[[T2:.*]] = tensor_load %[[BBARG]] : memref<f32>
// CHECK: %[[M2:.*]] = tensor_to_memref %[[T2]] : memref<f32>
// CHECK: return %[[M2]] : memref<f32>
func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
br ^bb1(%arg0: tensor<f32>)
^bb1(%bbarg: tensor<f32>):
return %bbarg : tensor<f32>
}
// CHECK-LABEL: func @partial()
// CHECK-SAME: memref<f32>
func @partial() -> tensor<f32> {
// CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
// CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref<f32>
%0 = "test.source"() : () -> tensor<f32>
// CHECK-NEXT: return %[[MEM]] : memref<f32>
return %0 : tensor<f32>
}
// CHECK-LABEL: func @region_op
// CHECK-SAME: (%[[ARG0:.*]]: i1) -> memref<f32>
func @region_op(%arg0: i1) -> tensor<f32> {
// CHECK-NEXT: %[[IF:.*]] = scf.if %[[ARG0]] -> (tensor<f32>)
%0 = scf.if %arg0 -> (tensor<f32>) {
// CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
%1 = "test.source"() : () -> tensor<f32>
// CHECK-NEXT: scf.yield %[[SRC]] : tensor<f32>
scf.yield %1 : tensor<f32>
// CHECK-NEXT: else
} else {
// CHECK-NEXT: %[[OSRC:.*]] = "test.other_source"() : () -> tensor<f32>
%1 = "test.other_source"() : () -> tensor<f32>
// CHECK-NEXT: scf.yield %[[OSRC]] : tensor<f32>
scf.yield %1 : tensor<f32>
}
// CHECK: %[[MEM:.*]] = tensor_to_memref %[[IF]] : memref<f32>
// CHECK: return %[[MEM]] : memref<f32>
return %0 : tensor<f32>
}
// -----
func @failed_to_legalize(%arg0: tensor<f32>) -> tensor<f32> {
%0 = constant true
cond_br %0, ^bb1(%arg0: tensor<f32>), ^bb2(%arg0: tensor<f32>)
^bb1(%bbarg0: tensor<f32>):
// expected-error @+1 {{failed to legalize operation 'test.terminator'}}
"test.terminator"() : () -> ()
^bb2(%bbarg1: tensor<f32>):
return %bbarg1 : tensor<f32>
}


@@ -1,4 +1,4 @@
// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
// RUN: mlir-opt %s -func-bufferize -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s
// CHECK-LABEL: func @identity(
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {