[mlir][bufferization][NFC] Move more unknown type conversion logic into BufferizationOptions

The `unknownTypeConversion` bufferization option, previously an enum, is now a type converter function option (`unknownTypeConverterFn`). Part of the logic of `getMemRefType` is now handled by that function.

This change makes type conversion more controllable. Previously, there were only two choices when generating memref types for non-bufferizable ops: a static identity layout or a fully dynamic layout. With this change, users of One-Shot Bufferize can provide a function with custom logic.
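As a sketch of that flexibility (not part of this commit; the setup and variable names are illustrative assumptions), a user of One-Shot Bufferize could now install a custom converter, e.g. one that always produces a static identity layout:

    BufferizationOptions options;
    // Hypothetical custom converter: use a static identity layout (no layout
    // map) whenever a precise buffer type cannot be inferred.
    options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
                                        const BufferizationOptions &options) {
      return getMemRefTypeWithStaticIdentityLayout(
          value.getType().cast<TensorType>(), memorySpace);
    };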

Differential Revision: https://reviews.llvm.org/D129273
Author: Matthias Springer
Date:   2022-07-07 13:35:36 +02:00
Commit: 606f7c8f7a
Parent: 8d9dc83f35
5 changed files with 58 additions and 49 deletions


@@ -179,6 +179,10 @@ struct BufferizationOptions {
   /// Initializer function for dialect-specific analysis state.
   using DialectStateInitFn =
       std::function<std::unique_ptr<DialectAnalysisState>()>;
+  /// Tensor -> MemRef type converter.
+  /// Parameters: Value, memory space, bufferization options
+  using UnknownTypeConverterFn = std::function<BaseMemRefType(
+      Value, unsigned, const BufferizationOptions &)>;
 
   enum class LayoutMapOption : int8_t {
     InferLayoutMap = 0,
@@ -266,21 +270,11 @@ struct BufferizationOptions {
   LayoutMapOption functionBoundaryTypeConversion =
       LayoutMapOption::InferLayoutMap;
 
-  /// This flag controls buffer types on unknown ops (to_memref wrappers) and in
-  /// other cases where a precise memref type cannot be inferred (e.g., the
-  /// bufferization of "tensor.cast").
-  ///
-  /// * InferLayoutMap: This option is invalid and cannot be used.
-  /// * FullyDynamicLayoutMap: Assume that unknown ops have results with fully
-  ///   dynamic layout maps after bufferization. This option is most efficient
-  ///   because any layout map can be casted to a fully dynamic one.
-  /// * IdentityLayoutMap: Assume that unknown ops have results with static
-  ///   identity layout (i.e., no layout map) after bufferization. This option
-  ///   introduces additional buffer allocs and copies if the unknown op is
-  ///   eventually bufferized to an op that returns a buffer with non-identity
-  ///   layout.
-  LayoutMapOption unknownTypeConversion =
-      LayoutMapOption::FullyDynamicLayoutMap;
+  /// Type converter from tensors to memrefs. This type converter is used if no
+  /// memref type could be inferred during bufferization. By default, a type
+  /// converter that returns a memref type with a fully dynamic layout map is
+  /// used.
+  UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
 
   /// Specifies whether dealloc ops should be generated along with alloc ops. If
   /// not, new memory allocations will leak.
@@ -505,20 +499,19 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
   return newOp;
 }
 
-/// Return a MemRefType to which the `tensorType` can be bufferized.
+/// Return a MemRefType to which the type of the given value can be bufferized.
 ///
 /// If possible, op bufferization implementations should not use this function
 /// and instead infer precise memref types for tensor results by themselves.
 ///
-/// Unless a layout map was specified, `options.unknownTypeConverter` determines
-/// what kind of layout map will be used. For best composability (without
-/// copies), the fully dynamic layout map is used by default.
+/// Unless a layout map was specified, `options.unknownTypeConverterFn`
+/// determines what kind of layout map will be used. For best composability
+/// (without copies), the fully dynamic layout map is used by default.
 ///
 /// Note: Canonicalization patterns could clean up layout maps and infer more
 /// precise layout maps after bufferization. However, many possible
 /// canonicalizations are currently not implemented.
-BaseMemRefType getMemRefType(TensorType tensorType,
-                             const BufferizationOptions &options,
+BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
                              MemRefLayoutAttrInterface layout = {},
                              unsigned memorySpace = 0);
 


@@ -351,8 +351,9 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
       /*defaultImplementation=*/[{
         assert(bbArg.getOwner()->getParentOp() == $_op &&
                "bbArg must belong to this op");
-        auto tensorType = bbArg.getType().cast<TensorType>();
-        return bufferization::getMemRefType(tensorType, options);
+        assert(bbArg.getType().isa<TensorType>() &&
+               "expected tensor type");
+        return bufferization::getMemRefType(bbArg, options);
       }]
     >,
     InterfaceMethod<


@@ -222,8 +222,17 @@ bool OpFilter::isOpAllowed(Operation *op) const {
 // BufferizationOptions
 //===----------------------------------------------------------------------===//
 
+/// Default unknown type converter: Use a fully dynamic layout map.
+static BaseMemRefType
+defaultUnknownTypeConverter(Value value, unsigned memorySpace,
+                            const BufferizationOptions &options) {
+  return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
+                                             memorySpace);
+}
+
 // Default constructor for BufferizationOptions.
-BufferizationOptions::BufferizationOptions() = default;
+BufferizationOptions::BufferizationOptions()
+    : unknownTypeConverterFn(defaultUnknownTypeConverter) {}
 
 bool BufferizationOptions::isOpAllowed(Operation *op) const {
   // Special case: If function boundary bufferization is deactivated, do not
@@ -528,8 +537,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
 /// Return the buffer type for a given Value (tensor) after bufferization.
 FailureOr<BaseMemRefType>
 bufferization::getBufferType(Value value, const BufferizationOptions &options) {
-  auto tensorType = value.getType().dyn_cast<TensorType>();
-  assert(tensorType && "unexpected non-tensor type");
+  assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
   Operation *op = getOwnerOfValue(value);
 
   // ToTensorOp: Take buffer type directly from the op.
@@ -566,7 +574,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options) {
   if (!memorySpace.hasValue())
     return op->emitError("could not infer memory space");
 
-  return getMemRefType(tensorType, options, /*layout=*/{}, *memorySpace);
+  return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
 }
 
 void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
@@ -652,10 +660,11 @@ bool bufferization::isFunctionArgument(Value value) {
   return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
 }
 
-BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
+BaseMemRefType bufferization::getMemRefType(Value value,
                                             const BufferizationOptions &options,
                                             MemRefLayoutAttrInterface layout,
                                             unsigned memorySpace) {
+  auto tensorType = value.getType().cast<TensorType>();
   auto memorySpaceAttr = IntegerAttr::get(
       IntegerType::get(tensorType.getContext(), 64), memorySpace);
 
@@ -674,17 +683,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
         memorySpaceAttr);
   }
 
-  // Case 3: Configured with "fully dynamic layout maps".
-  if (options.unknownTypeConversion ==
-      BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
-    return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
-
-  // Case 4: Configured with "static identity layout maps".
-  if (options.unknownTypeConversion ==
-      BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
-    return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
-
-  llvm_unreachable("InferLayoutMap is an invalid option");
+  return options.unknownTypeConverterFn(value, memorySpace, options);
 }
 
 BaseMemRefType


@@ -192,8 +192,26 @@ struct OneShotBufferizePass
       opt.printConflicts = printConflicts;
       opt.testAnalysisOnly = testAnalysisOnly;
       opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
-      opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);
+
+      // Configure type converter.
+      BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
+          parseLayoutMapOption(unknownTypeConversion);
+      opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
+                                       const BufferizationOptions &options) {
+        auto tensorType = value.getType().cast<TensorType>();
+        if (unknownTypeConversionOption ==
+            BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
+          return bufferization::getMemRefTypeWithStaticIdentityLayout(
+              tensorType, memorySpace);
+        assert(
+            unknownTypeConversionOption ==
+                BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
+            "invalid layout map option");
+        return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
+                                                                  memorySpace);
+      };
 
+      // Configure op filter.
       OpFilter::Entry::FilterFn filterFn =
           [&](Operation *op) {
             // Filter may be specified via options.
@@ -372,10 +390,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
                                          const BufferizationOptions &options,
                                          bool copyBeforeWrite,
                                          const OpFilter *opFilter) {
-  assert(options.unknownTypeConversion !=
-             BufferizationOptions::LayoutMapOption::InferLayoutMap &&
-         "invalid layout map option");
-
   if (copyBeforeWrite) {
     AnalysisState state(options);
     if (failed(insertTensorCopies(op, state)))
@@ -474,8 +488,11 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() {
   options.allowUnknownOps = true;
   options.createDeallocs = false;
   options.enforceAliasingInvariants = false;
-  options.unknownTypeConversion =
-      BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
+  options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
+                                      const BufferizationOptions &options) {
+    return getMemRefTypeWithStaticIdentityLayout(
+        value.getType().cast<TensorType>(), memorySpace);
+  };
   options.opFilter.allowDialect<BufferizationDialect>();
   return options;
 }


@@ -67,7 +67,7 @@ struct CastOpInterface
 
     // Compute the new memref type.
     Type resultMemRefType =
-        getMemRefType(resultTensorType, options, layout,
+        getMemRefType(castOp.getResult(), options, layout,
                       sourceMemRefType.getMemorySpaceAsInt());
 
     // Replace the op with a memref.cast.
@@ -780,9 +780,8 @@ struct ReshapeOpInterface
         getBuffer(rewriter, reshapeOp.getShape(), options);
     if (failed(srcBuffer) || failed(shapeBuffer))
       return failure();
-    auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
     auto resultMemRefType = getMemRefType(
-        resultTensorType, options, /*layout=*/{},
+        reshapeOp.getResult(), options, /*layout=*/{},
         srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
     replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
         rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);