forked from OSchip/llvm-project
[mlir][bufferize] Add argument materialization for bufferization
This enables partial bufferization that includes function signatures. To test this, this change also makes the func-bufferize partial and adds a dedicated finalizing-bufferize pass. Differential Revision: https://reviews.llvm.org/D92032
This commit is contained in:
parent
1ca174b642
commit
4dd5f79f07
|
@ -26,6 +26,13 @@ void populateCallOpTypeConversionPattern(OwningRewritePatternList &patterns,
|
|||
MLIRContext *ctx,
|
||||
TypeConverter &converter);
|
||||
|
||||
/// Add a pattern to the given pattern list to rewrite branch operations and
|
||||
/// `return` to use operands that have been legalized by the conversion
|
||||
/// framework. This can only be done if the branch operation implements the
|
||||
/// BranchOpInterface. Only needed for partial conversions.
|
||||
void populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx,
|
||||
TypeConverter &converter);
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_DIALECT_STANDARDOPS_TRANSFORMS_FUNCCONVERSIONS_H_
|
||||
|
|
|
@ -25,28 +25,26 @@ def StdExpandOps : FunctionPass<"std-expand"> {
|
|||
def FuncBufferize : Pass<"func-bufferize", "ModuleOp"> {
|
||||
let summary = "Bufferize func/call/return ops";
|
||||
let description = [{
|
||||
A finalizing bufferize pass that bufferizes std.func and std.call ops.
|
||||
A bufferize pass that bufferizes std.func and std.call ops.
|
||||
|
||||
Because this pass updates std.func ops, it must be a module pass. It is
|
||||
useful to keep this pass separate from other bufferizations so that the
|
||||
other ones can be run at function-level in parallel.
|
||||
|
||||
This pass must be done atomically for two reasons:
|
||||
1. This pass changes func op signatures, which requires atomically updating
|
||||
calls as well throughout the entire module.
|
||||
2. This pass changes the type of block arguments, which requires that all
|
||||
successor arguments of predecessors be converted. Terminators are not
|
||||
a closed universe (and need not implement BranchOpInterface), and so we
|
||||
cannot in general rewrite them.
|
||||
This pass must be done atomically because it changes func op signatures,
|
||||
which requires atomically updating calls as well throughout the entire
|
||||
module.
|
||||
|
||||
Note, because this is a "finalizing" bufferize step, it can create
|
||||
invalid IR because it will not create materializations. To avoid this
|
||||
situation, the pass must only be run when the only SSA values of
|
||||
tensor type are:
|
||||
- block arguments
|
||||
- the result of tensor_load
|
||||
Other values of tensor type should be eliminated by earlier
|
||||
bufferization passes.
|
||||
This pass also changes the type of block arguments, which requires that all
|
||||
successor arguments of predecessors be converted. This is achieved by
|
||||
rewriting terminators based on the information provided by the
|
||||
`BranchOpInterface`.
|
||||
As this pass rewrites function operations, it also rewrites the
|
||||
corresponding return operations. Other return-like operations that
|
||||
implement the `ReturnLike` trait are not rewritten in general, as they
|
||||
require that the corresponding parent operation is also rewritten.
|
||||
Finally, this pass fails for unknown terminators, as we cannot decide
|
||||
whether they need rewriting.
|
||||
}];
|
||||
let constructor = "mlir::createFuncBufferizePass()";
|
||||
}
|
||||
|
|
|
@ -46,6 +46,10 @@ std::unique_ptr<Pass>
|
|||
createPromoteBuffersToStackPass(unsigned maxAllocSizeInBytes = 1024,
|
||||
unsigned bitwidthOfIndexType = 64);
|
||||
|
||||
/// Creates a pass that finalizes a partial bufferization by removing remaining
|
||||
/// tensor_load and tensor_to_memref operations.
|
||||
std::unique_ptr<FunctionPass> createFinalizingBufferizePass();
|
||||
|
||||
/// Creates a pass that converts memref function results to out-params.
|
||||
std::unique_ptr<Pass> createBufferResultsToOutParamsPass();
|
||||
|
||||
|
|
|
@ -290,6 +290,22 @@ def Inliner : Pass<"inline"> {
|
|||
];
|
||||
}
|
||||
|
||||
def FinalizingBufferize : FunctionPass<"finalizing-bufferize"> {
|
||||
let summary = "Finalize a partial bufferization";
|
||||
let description = [{
|
||||
A bufferize pass that finalizes a partial bufferization by removing
|
||||
remaining `tensor_load` and `tensor_to_memref` operations.
|
||||
|
||||
The removal of those operations is only possible if the operations only
|
||||
exist in pairs, i.e., all uses of `tensor_load` operations are
|
||||
`tensor_to_memref` operations.
|
||||
|
||||
This pass will fail if not all operations can be removed or if any operation
|
||||
with tensor typed operands remains.
|
||||
}];
|
||||
let constructor = "mlir::createFinalizingBufferizePass()";
|
||||
}
|
||||
|
||||
def LocationSnapshot : Pass<"snapshot-op-locations"> {
|
||||
let summary = "Generate new locations from the current IR";
|
||||
let description = [{
|
||||
|
|
|
@ -21,6 +21,8 @@ using namespace mlir;
|
|||
|
||||
namespace {
|
||||
struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
|
||||
using FuncBufferizeBase<FuncBufferizePass>::FuncBufferizeBase;
|
||||
|
||||
void runOnOperation() override {
|
||||
auto module = getOperation();
|
||||
auto *context = &getContext();
|
||||
|
@ -35,14 +37,42 @@ struct FuncBufferizePass : public FuncBufferizeBase<FuncBufferizePass> {
|
|||
typeConverter.isLegal(&op.getBody());
|
||||
});
|
||||
populateCallOpTypeConversionPattern(patterns, context, typeConverter);
|
||||
populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
|
||||
patterns);
|
||||
target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
|
||||
target.addDynamicallyLegalOp<CallOp>(
|
||||
[&](CallOp op) { return typeConverter.isLegal(op); });
|
||||
|
||||
// If all result types are legal, and all block arguments are legal (ensured
|
||||
// by func conversion above), then all types in the program are legal.
|
||||
populateBranchOpInterfaceAndReturnOpTypeConversionPattern(patterns, context,
|
||||
typeConverter);
|
||||
target.addLegalOp<ModuleOp, ModuleTerminatorOp, TensorLoadOp,
|
||||
TensorToMemrefOp>();
|
||||
target.addDynamicallyLegalOp<ReturnOp>(
|
||||
[&](ReturnOp op) { return typeConverter.isLegal(op); });
|
||||
// Mark terminators as legal if they have the ReturnLike trait or
|
||||
// implement the BranchOpInterface and have valid types. If they do not
|
||||
// implement the trait or interface, mark them as illegal no matter what.
|
||||
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
|
||||
return typeConverter.isLegal(op->getResultTypes());
|
||||
// If it is not a terminator, ignore it.
|
||||
if (op->isKnownNonTerminator())
|
||||
return true;
|
||||
// If it is not the last operation in the block, also ignore it. We do
|
||||
// this to handle unknown operations, as well.
|
||||
Block *block = op->getBlock();
|
||||
if (!block || &block->back() != op)
|
||||
return true;
|
||||
// ReturnLike operations have to be legalized with their parent. For
|
||||
// return this is handled, for other ops they remain as is.
|
||||
if (op->hasTrait<OpTrait::ReturnLike>())
|
||||
return true;
|
||||
// All successor operands of branch like operations must be rewritten.
|
||||
if (auto branchOp = dyn_cast<BranchOpInterface>(op)) {
|
||||
for (int p = 0, e = op->getBlock()->getNumSuccessors(); p < e; ++p) {
|
||||
auto successorOperands = branchOp.getSuccessorOperands(p);
|
||||
if (successorOperands.hasValue() &&
|
||||
!typeConverter.isLegal(successorOperands.getValue().getTypes()))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
});
|
||||
|
||||
if (failed(applyFullConversion(module, target, std::move(patterns))))
|
||||
|
|
|
@ -13,21 +13,19 @@
|
|||
using namespace mlir;
|
||||
|
||||
namespace {
|
||||
// Converts the operand and result types of the Standard's CallOp, used together
|
||||
// with the FuncOpSignatureConversion.
|
||||
/// Converts the operand and result types of the Standard's CallOp, used
|
||||
/// together with the FuncOpSignatureConversion.
|
||||
struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
|
||||
CallOpSignatureConversion(MLIRContext *ctx, TypeConverter &converter)
|
||||
: OpConversionPattern(ctx), converter(converter) {}
|
||||
using OpConversionPattern<CallOp>::OpConversionPattern;
|
||||
|
||||
/// Hook for derived classes to implement combined matching and rewriting.
|
||||
LogicalResult
|
||||
matchAndRewrite(CallOp callOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
FunctionType type = callOp.getCalleeType();
|
||||
|
||||
// Convert the original function results.
|
||||
SmallVector<Type, 1> convertedResults;
|
||||
if (failed(converter.convertTypes(type.getResults(), convertedResults)))
|
||||
if (failed(typeConverter->convertTypes(callOp.getResultTypes(),
|
||||
convertedResults)))
|
||||
return failure();
|
||||
|
||||
// Substitute with the new result types from the corresponding FuncType
|
||||
|
@ -36,14 +34,77 @@ struct CallOpSignatureConversion : public OpConversionPattern<CallOp> {
|
|||
convertedResults, operands);
|
||||
return success();
|
||||
}
|
||||
|
||||
/// The type converter to use when rewriting the signature.
|
||||
TypeConverter &converter;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
void mlir::populateCallOpTypeConversionPattern(
|
||||
OwningRewritePatternList &patterns, MLIRContext *ctx,
|
||||
TypeConverter &converter) {
|
||||
patterns.insert<CallOpSignatureConversion>(ctx, converter);
|
||||
patterns.insert<CallOpSignatureConversion>(converter, ctx);
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Only needed to support partial conversion of functions where this pattern
|
||||
/// ensures that the branch operation arguments matches up with the succesor
|
||||
/// block arguments.
|
||||
class BranchOpInterfaceTypeConversion : public ConversionPattern {
|
||||
public:
|
||||
BranchOpInterfaceTypeConversion(TypeConverter &typeConverter,
|
||||
MLIRContext *ctx)
|
||||
: ConversionPattern(/*benefit=*/1, typeConverter, MatchAnyOpTypeTag()) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
auto branchOp = dyn_cast<BranchOpInterface>(op);
|
||||
if (!branchOp)
|
||||
return failure();
|
||||
|
||||
// For a branch operation, only some operands go to the target blocks, so
|
||||
// only rewrite those.
|
||||
SmallVector<Value, 4> newOperands(op->operand_begin(), op->operand_end());
|
||||
for (int succIdx = 0, succEnd = op->getBlock()->getNumSuccessors();
|
||||
succIdx < succEnd; ++succIdx) {
|
||||
auto successorOperands = branchOp.getSuccessorOperands(succIdx);
|
||||
if (!successorOperands)
|
||||
continue;
|
||||
for (int idx = successorOperands->getBeginOperandIndex(),
|
||||
eidx = idx + successorOperands->size();
|
||||
idx < eidx; ++idx) {
|
||||
newOperands[idx] = operands[idx];
|
||||
}
|
||||
}
|
||||
rewriter.updateRootInPlace(
|
||||
op, [newOperands, op]() { op->setOperands(newOperands); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
namespace {
|
||||
/// Only needed to support partial conversion of functions where this pattern
|
||||
/// ensures that the branch operation arguments matches up with the succesor
|
||||
/// block arguments.
|
||||
class ReturnOpTypeConversion : public OpConversionPattern<ReturnOp> {
|
||||
public:
|
||||
using OpConversionPattern<ReturnOp>::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(ReturnOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const final {
|
||||
// For a return, all operands go to the results of the parent, so
|
||||
// rewrite them all.
|
||||
Operation *operation = op.getOperation();
|
||||
rewriter.updateRootInPlace(
|
||||
op, [operands, operation]() { operation->setOperands(operands); });
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
/// Adds the BranchOpInterface and ReturnOp type conversion patterns, both
/// constructed from `typeConverter` and `ctx`, to `patterns`.
void mlir::populateBranchOpInterfaceAndReturnOpTypeConversionPattern(
    OwningRewritePatternList &patterns, MLIRContext *ctx,
    TypeConverter &typeConverter) {
  patterns.insert<BranchOpInterfaceTypeConversion>(typeConverter, ctx);
  patterns.insert<ReturnOpTypeConversion>(typeConverter, ctx);
}
|
||||
|
|
|
@ -7,7 +7,9 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/Transforms/Bufferize.h"
|
||||
#include "PassDetail.h"
|
||||
#include "mlir/IR/Operation.h"
|
||||
#include "mlir/Transforms/Passes.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -15,6 +17,13 @@ using namespace mlir;
|
|||
// BufferizeTypeConverter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Materialization callback that wraps a single memref-typed value in a
/// TensorLoadOp of the requested tensor `type`. Asserts that exactly one
/// input is given and that it has a BaseMemRefType.
static Value materializeTensorLoad(OpBuilder &builder, TensorType type,
                                   ValueRange inputs, Location loc) {
  assert(inputs.size() == 1);
  Value memref = inputs[0];
  assert(memref.getType().isa<BaseMemRefType>());
  return builder.create<TensorLoadOp>(loc, type, memref);
}
|
||||
|
||||
/// Registers conversions into BufferizeTypeConverter
|
||||
BufferizeTypeConverter::BufferizeTypeConverter() {
|
||||
// Keep all types unchanged.
|
||||
|
@ -27,12 +36,8 @@ BufferizeTypeConverter::BufferizeTypeConverter() {
|
|||
addConversion([](UnrankedTensorType type) -> Type {
|
||||
return UnrankedMemRefType::get(type.getElementType(), 0);
|
||||
});
|
||||
addSourceMaterialization([](OpBuilder &builder, TensorType type,
|
||||
ValueRange inputs, Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
assert(inputs[0].getType().isa<BaseMemRefType>());
|
||||
return builder.create<TensorLoadOp>(loc, type, inputs[0]);
|
||||
});
|
||||
addArgumentMaterialization(materializeTensorLoad);
|
||||
addSourceMaterialization(materializeTensorLoad);
|
||||
addTargetMaterialization([](OpBuilder &builder, BaseMemRefType type,
|
||||
ValueRange inputs, Location loc) -> Value {
|
||||
assert(inputs.size() == 1);
|
||||
|
@ -83,3 +88,37 @@ void mlir::populateEliminateBufferizeMaterializationsPatterns(
|
|||
patterns.insert<BufferizeTensorLoadOp, BufferizeTensorToMemrefOp>(
|
||||
typeConverter, context);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct FinalizingBufferizePass
|
||||
: public FinalizingBufferizeBase<FinalizingBufferizePass> {
|
||||
using FinalizingBufferizeBase<
|
||||
FinalizingBufferizePass>::FinalizingBufferizeBase;
|
||||
|
||||
void runOnFunction() override {
|
||||
auto func = getFunction();
|
||||
auto *context = &getContext();
|
||||
|
||||
BufferizeTypeConverter typeConverter;
|
||||
OwningRewritePatternList patterns;
|
||||
ConversionTarget target(*context);
|
||||
|
||||
populateEliminateBufferizeMaterializationsPatterns(context, typeConverter,
|
||||
patterns);
|
||||
target.addIllegalOp<TensorLoadOp, TensorToMemrefOp>();
|
||||
|
||||
// If all result types are legal, and all block arguments are legal (ensured
|
||||
// by func conversion above), then all types in the program are legal.
|
||||
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
|
||||
return typeConverter.isLegal(op->getResultTypes());
|
||||
});
|
||||
|
||||
if (failed(applyFullConversion(func, target, std::move(patterns))))
|
||||
signalPassFailure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
std::unique_ptr<FunctionPass> mlir::createFinalizingBufferizePass() {
|
||||
return std::make_unique<FinalizingBufferizePass>();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,59 @@
|
|||
// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics --debug-only=dialect-conversion | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @block_arguments(
|
||||
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
|
||||
// CHECK: %[[T1:.*]] = tensor_load %[[ARG]] : memref<f32>
|
||||
// CHECK: %[[M1:.*]] = tensor_to_memref %[[T1]] : memref<f32>
|
||||
// CHECK: br ^bb1(%[[M1]] : memref<f32>)
|
||||
// CHECK: ^bb1(%[[BBARG:.*]]: memref<f32>):
|
||||
// CHECK: %[[T2:.*]] = tensor_load %[[BBARG]] : memref<f32>
|
||||
// CHECK: %[[M2:.*]] = tensor_to_memref %[[T2]] : memref<f32>
|
||||
// CHECK: return %[[M2]] : memref<f32>
|
||||
func @block_arguments(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
br ^bb1(%arg0: tensor<f32>)
|
||||
^bb1(%bbarg: tensor<f32>):
|
||||
return %bbarg : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @partial()
|
||||
// CHECK-SAME: memref<f32>
|
||||
func @partial() -> tensor<f32> {
|
||||
// CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
|
||||
// CHECK-NEXT: %[[MEM:.*]] = tensor_to_memref %[[SRC]] : memref<f32>
|
||||
%0 = "test.source"() : () -> tensor<f32>
|
||||
// CHECK-NEXT: return %[[MEM]] : memref<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @region_op
|
||||
// CHECK-SAME: (%[[ARG0:.*]]: i1) -> memref<f32>
|
||||
func @region_op(%arg0: i1) -> tensor<f32> {
|
||||
// CHECK-NEXT: %[[IF:.*]] = scf.if %[[ARG0]] -> (tensor<f32>)
|
||||
%0 = scf.if %arg0 -> (tensor<f32>) {
|
||||
// CHECK-NEXT: %[[SRC:.*]] = "test.source"() : () -> tensor<f32>
|
||||
%1 = "test.source"() : () -> tensor<f32>
|
||||
// CHECK-NEXT: scf.yield %[[SRC]] : tensor<f32>
|
||||
scf.yield %1 : tensor<f32>
|
||||
// CHECK-NEXT: else
|
||||
} else {
|
||||
// CHECK-NEXT: %[[OSRC:.*]] = "test.other_source"() : () -> tensor<f32>
|
||||
%1 = "test.other_source"() : () -> tensor<f32>
|
||||
// CHECK-NEXT: scf.yield %[[OSRC]] : tensor<f32>
|
||||
scf.yield %1 : tensor<f32>
|
||||
}
|
||||
// CHECK: %[[MEM:.*]] = tensor_to_memref %[[IF]] : memref<f32>
|
||||
// CHECK: return %[[MEM]] : memref<f32>
|
||||
return %0 : tensor<f32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @failed_to_legalize(%arg0: tensor<f32>) -> tensor<f32> {
|
||||
%0 = constant true
|
||||
cond_br %0, ^bb1(%arg0: tensor<f32>), ^bb2(%arg0: tensor<f32>)
|
||||
^bb1(%bbarg0: tensor<f32>):
|
||||
// expected-error @+1 {{failed to legalize operation 'test.terminator'}}
|
||||
"test.terminator"() : () -> ()
|
||||
^bb2(%bbarg1: tensor<f32>):
|
||||
return %bbarg1 : tensor<f32>
|
||||
}
|
|
@ -1,4 +1,4 @@
|
|||
// RUN: mlir-opt %s -func-bufferize -split-input-file -verify-diagnostics | FileCheck %s
|
||||
// RUN: mlir-opt %s -func-bufferize -finalizing-bufferize -split-input-file -verify-diagnostics | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: func @identity(
|
||||
// CHECK-SAME: %[[ARG:.*]]: memref<f32>) -> memref<f32> {
|
||||
|
|
Loading…
Reference in New Issue