[mlir][bufferization][NFC] Move more unknown type conversion logic into BufferizationOptions

The `unknownTypeConversion` bufferization option (enum) is now a type converter function option. Some logic of `getMemRefType` is now handled by that function.

This change makes type conversion more controllable. Previously, there were only two options when generating memref types for non-bufferizable ops: Static identity layout or fully dynamic layout. With this change, users of One-Shot Bufferize can provide a function with custom logic.

Differential Revision: https://reviews.llvm.org/D129273
This commit is contained in:
Matthias Springer 2022-07-07 13:35:36 +02:00
parent 8d9dc83f35
commit 606f7c8f7a
5 changed files with 58 additions and 49 deletions

View File

@ -179,6 +179,10 @@ struct BufferizationOptions {
/// Initializer function for dialect-specific analysis state.
using DialectStateInitFn =
std::function<std::unique_ptr<DialectAnalysisState>()>;
/// Tensor -> MemRef type converter.
/// Parameters: Value, memory space, bufferization options
using UnknownTypeConverterFn = std::function<BaseMemRefType(
Value, unsigned, const BufferizationOptions &)>;
enum class LayoutMapOption : int8_t {
InferLayoutMap = 0,
@ -266,21 +270,11 @@ struct BufferizationOptions {
LayoutMapOption functionBoundaryTypeConversion =
LayoutMapOption::InferLayoutMap;
/// This flag controls buffer types on unknown ops (to_memref wrappers) and in
/// other cases where a precise memref type cannot be inferred (e.g., the
/// bufferization of "tensor.cast").
///
/// * InferLayoutMap: This option is invalid and cannot be used.
/// * FullyDynamicLayoutMap: Assume that unknown ops have results with fully
/// dynamic layout maps after bufferization. This option is most efficient
/// because any layout map can be casted to a fully dynamic one.
/// * IdentityLayoutMap: Assume that unknown ops have results with static
/// identity layout (i.e., no layout map) after bufferization. This option
/// introduces additional buffer allocs and copies if the unknown op is
/// eventually bufferized to an op that returns a buffer with non-identity
/// layout.
LayoutMapOption unknownTypeConversion =
LayoutMapOption::FullyDynamicLayoutMap;
/// Type converter from tensors to memrefs. This type converter is used if no
/// memref type could be inferred during bufferization. By default, a type
/// converter that returns a memref type with a fully dynamic layout map is
/// used.
UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
/// Specifies whether dealloc ops should be generated along with alloc ops. If
/// not, new memory allocations will leak.
@ -505,20 +499,19 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
return newOp;
}
/// Return a MemRefType to which the `tensorType` can be bufferized.
/// Return a MemRefType to which the type of the given value can be bufferized.
///
/// If possible, op bufferization implementations should not use this function
/// and instead infer precise memref types for tensor results by themselves.
///
/// Unless a layout map was specified, `options.unknownTypeConverter` determines
/// what kind of layout map will be used. For best composability (without
/// copies), the fully dynamic layout map is used by default.
/// Unless a layout map was specified, `options.unknownTypeConverterFn`
/// determines what kind of layout map will be used. For best composability
/// (without copies), the fully dynamic layout map is used by default.
///
/// Note: Canonicalization patterns could clean up layout maps and infer more
/// precise layout maps after bufferization. However, many possible
/// canonicalizations are currently not implemented.
BaseMemRefType getMemRefType(TensorType tensorType,
const BufferizationOptions &options,
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
MemRefLayoutAttrInterface layout = {},
unsigned memorySpace = 0);

View File

@ -351,8 +351,9 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
/*defaultImplementation=*/[{
assert(bbArg.getOwner()->getParentOp() == $_op &&
"bbArg must belong to this op");
auto tensorType = bbArg.getType().cast<TensorType>();
return bufferization::getMemRefType(tensorType, options);
assert(bbArg.getType().isa<TensorType>() &&
"expected tensor type");
return bufferization::getMemRefType(bbArg, options);
}]
>,
InterfaceMethod<

View File

@ -222,8 +222,17 @@ bool OpFilter::isOpAllowed(Operation *op) const {
// BufferizationOptions
//===----------------------------------------------------------------------===//
/// Default unknown type converter: Use a fully dynamic layout map.
static BaseMemRefType
defaultUnknownTypeConverter(Value value, unsigned memorySpace,
const BufferizationOptions &options) {
return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
memorySpace);
}
// Default constructor for BufferizationOptions.
BufferizationOptions::BufferizationOptions() = default;
BufferizationOptions::BufferizationOptions()
: unknownTypeConverterFn(defaultUnknownTypeConverter) {}
bool BufferizationOptions::isOpAllowed(Operation *op) const {
// Special case: If function boundary bufferization is deactivated, do not
@ -528,8 +537,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
/// Return the buffer type for a given Value (tensor) after bufferization.
FailureOr<BaseMemRefType>
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
auto tensorType = value.getType().dyn_cast<TensorType>();
assert(tensorType && "unexpected non-tensor type");
assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
Operation *op = getOwnerOfValue(value);
// ToTensorOp: Take buffer type directly from the op.
@ -566,7 +574,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options) {
if (!memorySpace.hasValue())
return op->emitError("could not infer memory space");
return getMemRefType(tensorType, options, /*layout=*/{}, *memorySpace);
return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
}
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
@ -652,10 +660,11 @@ bool bufferization::isFunctionArgument(Value value) {
return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
}
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
BaseMemRefType bufferization::getMemRefType(Value value,
const BufferizationOptions &options,
MemRefLayoutAttrInterface layout,
unsigned memorySpace) {
auto tensorType = value.getType().cast<TensorType>();
auto memorySpaceAttr = IntegerAttr::get(
IntegerType::get(tensorType.getContext(), 64), memorySpace);
@ -674,17 +683,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
memorySpaceAttr);
}
// Case 3: Configured with "fully dynamic layout maps".
if (options.unknownTypeConversion ==
BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
// Case 4: Configured with "static identity layout maps".
if (options.unknownTypeConversion ==
BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
llvm_unreachable("InferLayoutMap is an invalid option");
return options.unknownTypeConverterFn(value, memorySpace, options);
}
BaseMemRefType

View File

@ -192,8 +192,26 @@ struct OneShotBufferizePass
opt.printConflicts = printConflicts;
opt.testAnalysisOnly = testAnalysisOnly;
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);
// Configure type converter.
BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
parseLayoutMapOption(unknownTypeConversion);
opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
const BufferizationOptions &options) {
auto tensorType = value.getType().cast<TensorType>();
if (unknownTypeConversionOption ==
BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
return bufferization::getMemRefTypeWithStaticIdentityLayout(
tensorType, memorySpace);
assert(
unknownTypeConversionOption ==
BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
"invalid layout map option");
return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
memorySpace);
};
// Configure op filter.
OpFilter::Entry::FilterFn filterFn =
[&](Operation *op) {
// Filter may be specified via options.
@ -372,10 +390,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
const BufferizationOptions &options,
bool copyBeforeWrite,
const OpFilter *opFilter) {
assert(options.unknownTypeConversion !=
BufferizationOptions::LayoutMapOption::InferLayoutMap &&
"invalid layout map option");
if (copyBeforeWrite) {
AnalysisState state(options);
if (failed(insertTensorCopies(op, state)))
@ -474,8 +488,11 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() {
options.allowUnknownOps = true;
options.createDeallocs = false;
options.enforceAliasingInvariants = false;
options.unknownTypeConversion =
BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
const BufferizationOptions &options) {
return getMemRefTypeWithStaticIdentityLayout(
value.getType().cast<TensorType>(), memorySpace);
};
options.opFilter.allowDialect<BufferizationDialect>();
return options;
}

View File

@ -67,7 +67,7 @@ struct CastOpInterface
// Compute the new memref type.
Type resultMemRefType =
getMemRefType(resultTensorType, options, layout,
getMemRefType(castOp.getResult(), options, layout,
sourceMemRefType.getMemorySpaceAsInt());
// Replace the op with a memref.cast.
@ -780,9 +780,8 @@ struct ReshapeOpInterface
getBuffer(rewriter, reshapeOp.getShape(), options);
if (failed(srcBuffer) || failed(shapeBuffer))
return failure();
auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
auto resultMemRefType = getMemRefType(
resultTensorType, options, /*layout=*/{},
reshapeOp.getResult(), options, /*layout=*/{},
srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);