forked from OSchip/llvm-project
[mlir][bufferization][NFC] Move more unknown type conversion logic into BufferizationOptions
The `unknownTypeConversion` bufferization option (enum) is now a type converter function option. Some logic of `getMemRefType` is now handled by that function. This change makes type conversion more controllable. Previously, there were only two options when generating memref types for non-bufferizable ops: Static identity layout or fully dynamic layout. With this change, users of One-Shot Bufferize can provide a function with custom logic. Differential Revision: https://reviews.llvm.org/D129273
This commit is contained in:
parent
8d9dc83f35
commit
606f7c8f7a
|
@ -179,6 +179,10 @@ struct BufferizationOptions {
|
|||
/// Initializer function for dialect-specific analysis state.
|
||||
using DialectStateInitFn =
|
||||
std::function<std::unique_ptr<DialectAnalysisState>()>;
|
||||
/// Tensor -> MemRef type converter.
|
||||
/// Parameters: Value, memory space, bufferization options
|
||||
using UnknownTypeConverterFn = std::function<BaseMemRefType(
|
||||
Value, unsigned, const BufferizationOptions &)>;
|
||||
|
||||
enum class LayoutMapOption : int8_t {
|
||||
InferLayoutMap = 0,
|
||||
|
@ -266,21 +270,11 @@ struct BufferizationOptions {
|
|||
LayoutMapOption functionBoundaryTypeConversion =
|
||||
LayoutMapOption::InferLayoutMap;
|
||||
|
||||
/// This flag controls buffer types on unknown ops (to_memref wrappers) and in
|
||||
/// other cases where a precise memref type cannot be inferred (e.g., the
|
||||
/// bufferization of "tensor.cast").
|
||||
///
|
||||
/// * InferLayoutMap: This option is invalid and cannot be used.
|
||||
/// * FullyDynamicLayoutMap: Assume that unknown ops have results with fully
|
||||
/// dynamic layout maps after bufferization. This option is most efficient
|
||||
/// because any layout map can be casted to a fully dynamic one.
|
||||
/// * IdentityLayoutMap: Assume that unknown ops have results with static
|
||||
/// identity layout (i.e., no layout map) after bufferization. This option
|
||||
/// introduces additional buffer allocs and copies if the unknown op is
|
||||
/// eventually bufferized to an op that returns a buffer with non-identity
|
||||
/// layout.
|
||||
LayoutMapOption unknownTypeConversion =
|
||||
LayoutMapOption::FullyDynamicLayoutMap;
|
||||
/// Type converter from tensors to memrefs. This type converter is used if no
|
||||
/// memref type could be inferred during bufferization. By default, a type
|
||||
/// converter that returns a memref type with a fully dynamic layout map is
|
||||
/// used.
|
||||
UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
|
||||
|
||||
/// Specifies whether dealloc ops should be generated along with alloc ops. If
|
||||
/// not, new memory allocations will leak.
|
||||
|
@ -505,20 +499,19 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
|
|||
return newOp;
|
||||
}
|
||||
|
||||
/// Return a MemRefType to which the `tensorType` can be bufferized.
|
||||
/// Return a MemRefType to which the type of the given value can be bufferized.
|
||||
///
|
||||
/// If possible, op bufferization implementations should not use this function
|
||||
/// and instead infer precise memref types for tensor results by themselves.
|
||||
///
|
||||
/// Unless a layout map was specified, `options.unknownTypeConverter` determines
|
||||
/// what kind of layout map will be used. For best composability (without
|
||||
/// copies), the fully dynamic layout map is used by default.
|
||||
/// Unless a layout map was specified, `options.unknownTypeConverterFn`
|
||||
/// determines what kind of layout map will be used. For best composability
|
||||
/// (without copies), the fully dynamic layout map is used by default.
|
||||
///
|
||||
/// Note: Canonicalization patterns could clean up layout maps and infer more
|
||||
/// precise layout maps after bufferization. However, many possible
|
||||
/// canonicalizations are currently not implemented.
|
||||
BaseMemRefType getMemRefType(TensorType tensorType,
|
||||
const BufferizationOptions &options,
|
||||
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
|
||||
MemRefLayoutAttrInterface layout = {},
|
||||
unsigned memorySpace = 0);
|
||||
|
||||
|
|
|
@ -351,8 +351,9 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
|||
/*defaultImplementation=*/[{
|
||||
assert(bbArg.getOwner()->getParentOp() == $_op &&
|
||||
"bbArg must belong to this op");
|
||||
auto tensorType = bbArg.getType().cast<TensorType>();
|
||||
return bufferization::getMemRefType(tensorType, options);
|
||||
assert(bbArg.getType().isa<TensorType>() &&
|
||||
"expected tensor type");
|
||||
return bufferization::getMemRefType(bbArg, options);
|
||||
}]
|
||||
>,
|
||||
InterfaceMethod<
|
||||
|
|
|
@ -222,8 +222,17 @@ bool OpFilter::isOpAllowed(Operation *op) const {
|
|||
// BufferizationOptions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Default unknown type converter: Use a fully dynamic layout map.
|
||||
static BaseMemRefType
|
||||
defaultUnknownTypeConverter(Value value, unsigned memorySpace,
|
||||
const BufferizationOptions &options) {
|
||||
return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
|
||||
memorySpace);
|
||||
}
|
||||
|
||||
// Default constructor for BufferizationOptions.
|
||||
BufferizationOptions::BufferizationOptions() = default;
|
||||
BufferizationOptions::BufferizationOptions()
|
||||
: unknownTypeConverterFn(defaultUnknownTypeConverter) {}
|
||||
|
||||
bool BufferizationOptions::isOpAllowed(Operation *op) const {
|
||||
// Special case: If function boundary bufferization is deactivated, do not
|
||||
|
@ -528,8 +537,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
|
|||
/// Return the buffer type for a given Value (tensor) after bufferization.
|
||||
FailureOr<BaseMemRefType>
|
||||
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
|
||||
auto tensorType = value.getType().dyn_cast<TensorType>();
|
||||
assert(tensorType && "unexpected non-tensor type");
|
||||
assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
|
||||
Operation *op = getOwnerOfValue(value);
|
||||
|
||||
// ToTensorOp: Take buffer type directly from the op.
|
||||
|
@ -566,7 +574,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options) {
|
|||
if (!memorySpace.hasValue())
|
||||
return op->emitError("could not infer memory space");
|
||||
|
||||
return getMemRefType(tensorType, options, /*layout=*/{}, *memorySpace);
|
||||
return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
|
||||
}
|
||||
|
||||
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
|
||||
|
@ -652,10 +660,11 @@ bool bufferization::isFunctionArgument(Value value) {
|
|||
return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
|
||||
}
|
||||
|
||||
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
|
||||
BaseMemRefType bufferization::getMemRefType(Value value,
|
||||
const BufferizationOptions &options,
|
||||
MemRefLayoutAttrInterface layout,
|
||||
unsigned memorySpace) {
|
||||
auto tensorType = value.getType().cast<TensorType>();
|
||||
auto memorySpaceAttr = IntegerAttr::get(
|
||||
IntegerType::get(tensorType.getContext(), 64), memorySpace);
|
||||
|
||||
|
@ -674,17 +683,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
|
|||
memorySpaceAttr);
|
||||
}
|
||||
|
||||
// Case 3: Configured with "fully dynamic layout maps".
|
||||
if (options.unknownTypeConversion ==
|
||||
BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
|
||||
return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
|
||||
|
||||
// Case 4: Configured with "static identity layout maps".
|
||||
if (options.unknownTypeConversion ==
|
||||
BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
|
||||
return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
|
||||
|
||||
llvm_unreachable("InferLayoutMap is an invalid option");
|
||||
return options.unknownTypeConverterFn(value, memorySpace, options);
|
||||
}
|
||||
|
||||
BaseMemRefType
|
||||
|
|
|
@ -192,8 +192,26 @@ struct OneShotBufferizePass
|
|||
opt.printConflicts = printConflicts;
|
||||
opt.testAnalysisOnly = testAnalysisOnly;
|
||||
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
|
||||
opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);
|
||||
|
||||
// Configure type converter.
|
||||
BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
|
||||
parseLayoutMapOption(unknownTypeConversion);
|
||||
opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
|
||||
const BufferizationOptions &options) {
|
||||
auto tensorType = value.getType().cast<TensorType>();
|
||||
if (unknownTypeConversionOption ==
|
||||
BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
|
||||
return bufferization::getMemRefTypeWithStaticIdentityLayout(
|
||||
tensorType, memorySpace);
|
||||
assert(
|
||||
unknownTypeConversionOption ==
|
||||
BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
|
||||
"invalid layout map option");
|
||||
return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
|
||||
memorySpace);
|
||||
};
|
||||
|
||||
// Configure op filter.
|
||||
OpFilter::Entry::FilterFn filterFn =
|
||||
[&](Operation *op) {
|
||||
// Filter may be specified via options.
|
||||
|
@ -372,10 +390,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
|
|||
const BufferizationOptions &options,
|
||||
bool copyBeforeWrite,
|
||||
const OpFilter *opFilter) {
|
||||
assert(options.unknownTypeConversion !=
|
||||
BufferizationOptions::LayoutMapOption::InferLayoutMap &&
|
||||
"invalid layout map option");
|
||||
|
||||
if (copyBeforeWrite) {
|
||||
AnalysisState state(options);
|
||||
if (failed(insertTensorCopies(op, state)))
|
||||
|
@ -474,8 +488,11 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() {
|
|||
options.allowUnknownOps = true;
|
||||
options.createDeallocs = false;
|
||||
options.enforceAliasingInvariants = false;
|
||||
options.unknownTypeConversion =
|
||||
BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
|
||||
options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
|
||||
const BufferizationOptions &options) {
|
||||
return getMemRefTypeWithStaticIdentityLayout(
|
||||
value.getType().cast<TensorType>(), memorySpace);
|
||||
};
|
||||
options.opFilter.allowDialect<BufferizationDialect>();
|
||||
return options;
|
||||
}
|
||||
|
|
|
@ -67,7 +67,7 @@ struct CastOpInterface
|
|||
|
||||
// Compute the new memref type.
|
||||
Type resultMemRefType =
|
||||
getMemRefType(resultTensorType, options, layout,
|
||||
getMemRefType(castOp.getResult(), options, layout,
|
||||
sourceMemRefType.getMemorySpaceAsInt());
|
||||
|
||||
// Replace the op with a memref.cast.
|
||||
|
@ -780,9 +780,8 @@ struct ReshapeOpInterface
|
|||
getBuffer(rewriter, reshapeOp.getShape(), options);
|
||||
if (failed(srcBuffer) || failed(shapeBuffer))
|
||||
return failure();
|
||||
auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
|
||||
auto resultMemRefType = getMemRefType(
|
||||
resultTensorType, options, /*layout=*/{},
|
||||
reshapeOp.getResult(), options, /*layout=*/{},
|
||||
srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
|
||||
replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
|
||||
rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
|
||||
|
|
Loading…
Reference in New Issue