forked from OSchip/llvm-project
[mlir][bufferization][NFC] Move more unknown type conversion logic into BufferizationOptions
The `unknownTypeConversion` bufferization option (enum) is now a type converter function option. Some logic of `getMemRefType` is now handled by that function. This change makes type conversion more controllable. Previously, there were only two options when generating memref types for non-bufferizable ops: Static identity layout or fully dynamic layout. With this change, users of One-Shot Bufferize can provide a function with custom logic. Differential Revision: https://reviews.llvm.org/D129273
This commit is contained in:
parent
8d9dc83f35
commit
606f7c8f7a
|
@ -179,6 +179,10 @@ struct BufferizationOptions {
|
||||||
/// Initializer function for dialect-specific analysis state.
|
/// Initializer function for dialect-specific analysis state.
|
||||||
using DialectStateInitFn =
|
using DialectStateInitFn =
|
||||||
std::function<std::unique_ptr<DialectAnalysisState>()>;
|
std::function<std::unique_ptr<DialectAnalysisState>()>;
|
||||||
|
/// Tensor -> MemRef type converter.
|
||||||
|
/// Parameters: Value, memory space, bufferization options
|
||||||
|
using UnknownTypeConverterFn = std::function<BaseMemRefType(
|
||||||
|
Value, unsigned, const BufferizationOptions &)>;
|
||||||
|
|
||||||
enum class LayoutMapOption : int8_t {
|
enum class LayoutMapOption : int8_t {
|
||||||
InferLayoutMap = 0,
|
InferLayoutMap = 0,
|
||||||
|
@ -266,21 +270,11 @@ struct BufferizationOptions {
|
||||||
LayoutMapOption functionBoundaryTypeConversion =
|
LayoutMapOption functionBoundaryTypeConversion =
|
||||||
LayoutMapOption::InferLayoutMap;
|
LayoutMapOption::InferLayoutMap;
|
||||||
|
|
||||||
/// This flag controls buffer types on unknown ops (to_memref wrappers) and in
|
/// Type converter from tensors to memrefs. This type converter is used if no
|
||||||
/// other cases where a precise memref type cannot be inferred (e.g., the
|
/// memref type could be inferred during bufferization. By default, a type
|
||||||
/// bufferization of "tensor.cast").
|
/// converter that returns a memref type with a fully dynamic layout map is
|
||||||
///
|
/// used.
|
||||||
/// * InferLayoutMap: This option is invalid and cannot be used.
|
UnknownTypeConverterFn unknownTypeConverterFn = nullptr;
|
||||||
/// * FullyDynamicLayoutMap: Assume that unknown ops have results with fully
|
|
||||||
/// dynamic layout maps after bufferization. This option is most efficient
|
|
||||||
/// because any layout map can be casted to a fully dynamic one.
|
|
||||||
/// * IdentityLayoutMap: Assume that unknown ops have results with static
|
|
||||||
/// identity layout (i.e., no layout map) after bufferization. This option
|
|
||||||
/// introduces additional buffer allocs and copies if the unknown op is
|
|
||||||
/// eventually bufferized to an op that returns a buffer with non-identity
|
|
||||||
/// layout.
|
|
||||||
LayoutMapOption unknownTypeConversion =
|
|
||||||
LayoutMapOption::FullyDynamicLayoutMap;
|
|
||||||
|
|
||||||
/// Specifies whether dealloc ops should be generated along with alloc ops. If
|
/// Specifies whether dealloc ops should be generated along with alloc ops. If
|
||||||
/// not, new memory allocations will leak.
|
/// not, new memory allocations will leak.
|
||||||
|
@ -505,20 +499,19 @@ OpTy replaceOpWithNewBufferizedOp(RewriterBase &rewriter, Operation *op,
|
||||||
return newOp;
|
return newOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Return a MemRefType to which the `tensorType` can be bufferized.
|
/// Return a MemRefType to which the type of the given value can be bufferized.
|
||||||
///
|
///
|
||||||
/// If possible, op bufferization implementations should not use this function
|
/// If possible, op bufferization implementations should not use this function
|
||||||
/// and instead infer precise memref types for tensor results by themselves.
|
/// and instead infer precise memref types for tensor results by themselves.
|
||||||
///
|
///
|
||||||
/// Unless a layout map was specified, `options.unknownTypeConverter` determines
|
/// Unless a layout map was specified, `options.unknownTypeConverterFn`
|
||||||
/// what kind of layout map will be used. For best composability (without
|
/// determines what kind of layout map will be used. For best composability
|
||||||
/// copies), the fully dynamic layout map is used by default.
|
/// (without copies), the fully dynamic layout map is used by default.
|
||||||
///
|
///
|
||||||
/// Note: Canonicalization patterns could clean up layout maps and infer more
|
/// Note: Canonicalization patterns could clean up layout maps and infer more
|
||||||
/// precise layout maps after bufferization. However, many possible
|
/// precise layout maps after bufferization. However, many possible
|
||||||
/// canonicalizations are currently not implemented.
|
/// canonicalizations are currently not implemented.
|
||||||
BaseMemRefType getMemRefType(TensorType tensorType,
|
BaseMemRefType getMemRefType(Value value, const BufferizationOptions &options,
|
||||||
const BufferizationOptions &options,
|
|
||||||
MemRefLayoutAttrInterface layout = {},
|
MemRefLayoutAttrInterface layout = {},
|
||||||
unsigned memorySpace = 0);
|
unsigned memorySpace = 0);
|
||||||
|
|
||||||
|
|
|
@ -351,8 +351,9 @@ def BufferizableOpInterface : OpInterface<"BufferizableOpInterface"> {
|
||||||
/*defaultImplementation=*/[{
|
/*defaultImplementation=*/[{
|
||||||
assert(bbArg.getOwner()->getParentOp() == $_op &&
|
assert(bbArg.getOwner()->getParentOp() == $_op &&
|
||||||
"bbArg must belong to this op");
|
"bbArg must belong to this op");
|
||||||
auto tensorType = bbArg.getType().cast<TensorType>();
|
assert(bbArg.getType().isa<TensorType>() &&
|
||||||
return bufferization::getMemRefType(tensorType, options);
|
"expected tensor type");
|
||||||
|
return bufferization::getMemRefType(bbArg, options);
|
||||||
}]
|
}]
|
||||||
>,
|
>,
|
||||||
InterfaceMethod<
|
InterfaceMethod<
|
||||||
|
|
|
@ -222,8 +222,17 @@ bool OpFilter::isOpAllowed(Operation *op) const {
|
||||||
// BufferizationOptions
|
// BufferizationOptions
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
/// Default unknown type converter: Use a fully dynamic layout map.
|
||||||
|
static BaseMemRefType
|
||||||
|
defaultUnknownTypeConverter(Value value, unsigned memorySpace,
|
||||||
|
const BufferizationOptions &options) {
|
||||||
|
return getMemRefTypeWithFullyDynamicLayout(value.getType().cast<TensorType>(),
|
||||||
|
memorySpace);
|
||||||
|
}
|
||||||
|
|
||||||
// Default constructor for BufferizationOptions.
|
// Default constructor for BufferizationOptions.
|
||||||
BufferizationOptions::BufferizationOptions() = default;
|
BufferizationOptions::BufferizationOptions()
|
||||||
|
: unknownTypeConverterFn(defaultUnknownTypeConverter) {}
|
||||||
|
|
||||||
bool BufferizationOptions::isOpAllowed(Operation *op) const {
|
bool BufferizationOptions::isOpAllowed(Operation *op) const {
|
||||||
// Special case: If function boundary bufferization is deactivated, do not
|
// Special case: If function boundary bufferization is deactivated, do not
|
||||||
|
@ -528,8 +537,7 @@ FailureOr<Value> bufferization::getBuffer(RewriterBase &rewriter, Value value,
|
||||||
/// Return the buffer type for a given Value (tensor) after bufferization.
|
/// Return the buffer type for a given Value (tensor) after bufferization.
|
||||||
FailureOr<BaseMemRefType>
|
FailureOr<BaseMemRefType>
|
||||||
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
|
bufferization::getBufferType(Value value, const BufferizationOptions &options) {
|
||||||
auto tensorType = value.getType().dyn_cast<TensorType>();
|
assert(value.getType().isa<TensorType>() && "unexpected non-tensor type");
|
||||||
assert(tensorType && "unexpected non-tensor type");
|
|
||||||
Operation *op = getOwnerOfValue(value);
|
Operation *op = getOwnerOfValue(value);
|
||||||
|
|
||||||
// ToTensorOp: Take buffer type directly from the op.
|
// ToTensorOp: Take buffer type directly from the op.
|
||||||
|
@ -566,7 +574,7 @@ bufferization::getBufferType(Value value, const BufferizationOptions &options) {
|
||||||
if (!memorySpace.hasValue())
|
if (!memorySpace.hasValue())
|
||||||
return op->emitError("could not infer memory space");
|
return op->emitError("could not infer memory space");
|
||||||
|
|
||||||
return getMemRefType(tensorType, options, /*layout=*/{}, *memorySpace);
|
return getMemRefType(value, options, /*layout=*/{}, *memorySpace);
|
||||||
}
|
}
|
||||||
|
|
||||||
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
|
void bufferization::replaceOpWithBufferizedValues(RewriterBase &rewriter,
|
||||||
|
@ -652,10 +660,11 @@ bool bufferization::isFunctionArgument(Value value) {
|
||||||
return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
|
return isa<func::FuncOp>(bbArg.getOwner()->getParentOp());
|
||||||
}
|
}
|
||||||
|
|
||||||
BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
|
BaseMemRefType bufferization::getMemRefType(Value value,
|
||||||
const BufferizationOptions &options,
|
const BufferizationOptions &options,
|
||||||
MemRefLayoutAttrInterface layout,
|
MemRefLayoutAttrInterface layout,
|
||||||
unsigned memorySpace) {
|
unsigned memorySpace) {
|
||||||
|
auto tensorType = value.getType().cast<TensorType>();
|
||||||
auto memorySpaceAttr = IntegerAttr::get(
|
auto memorySpaceAttr = IntegerAttr::get(
|
||||||
IntegerType::get(tensorType.getContext(), 64), memorySpace);
|
IntegerType::get(tensorType.getContext(), 64), memorySpace);
|
||||||
|
|
||||||
|
@ -674,17 +683,7 @@ BaseMemRefType bufferization::getMemRefType(TensorType tensorType,
|
||||||
memorySpaceAttr);
|
memorySpaceAttr);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Case 3: Configured with "fully dynamic layout maps".
|
return options.unknownTypeConverterFn(value, memorySpace, options);
|
||||||
if (options.unknownTypeConversion ==
|
|
||||||
BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap)
|
|
||||||
return getMemRefTypeWithFullyDynamicLayout(tensorType, memorySpace);
|
|
||||||
|
|
||||||
// Case 4: Configured with "static identity layout maps".
|
|
||||||
if (options.unknownTypeConversion ==
|
|
||||||
BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
|
|
||||||
return getMemRefTypeWithStaticIdentityLayout(tensorType, memorySpace);
|
|
||||||
|
|
||||||
llvm_unreachable("InferLayoutMap is an invalid option");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
BaseMemRefType
|
BaseMemRefType
|
||||||
|
|
|
@ -192,8 +192,26 @@ struct OneShotBufferizePass
|
||||||
opt.printConflicts = printConflicts;
|
opt.printConflicts = printConflicts;
|
||||||
opt.testAnalysisOnly = testAnalysisOnly;
|
opt.testAnalysisOnly = testAnalysisOnly;
|
||||||
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
|
opt.bufferizeFunctionBoundaries = bufferizeFunctionBoundaries;
|
||||||
opt.unknownTypeConversion = parseLayoutMapOption(unknownTypeConversion);
|
|
||||||
|
|
||||||
|
// Configure type converter.
|
||||||
|
BufferizationOptions::LayoutMapOption unknownTypeConversionOption =
|
||||||
|
parseLayoutMapOption(unknownTypeConversion);
|
||||||
|
opt.unknownTypeConverterFn = [=](Value value, unsigned memorySpace,
|
||||||
|
const BufferizationOptions &options) {
|
||||||
|
auto tensorType = value.getType().cast<TensorType>();
|
||||||
|
if (unknownTypeConversionOption ==
|
||||||
|
BufferizationOptions::LayoutMapOption::IdentityLayoutMap)
|
||||||
|
return bufferization::getMemRefTypeWithStaticIdentityLayout(
|
||||||
|
tensorType, memorySpace);
|
||||||
|
assert(
|
||||||
|
unknownTypeConversionOption ==
|
||||||
|
BufferizationOptions::LayoutMapOption::FullyDynamicLayoutMap &&
|
||||||
|
"invalid layout map option");
|
||||||
|
return bufferization::getMemRefTypeWithFullyDynamicLayout(tensorType,
|
||||||
|
memorySpace);
|
||||||
|
};
|
||||||
|
|
||||||
|
// Configure op filter.
|
||||||
OpFilter::Entry::FilterFn filterFn =
|
OpFilter::Entry::FilterFn filterFn =
|
||||||
[&](Operation *op) {
|
[&](Operation *op) {
|
||||||
// Filter may be specified via options.
|
// Filter may be specified via options.
|
||||||
|
@ -372,10 +390,6 @@ LogicalResult bufferization::bufferizeOp(Operation *op,
|
||||||
const BufferizationOptions &options,
|
const BufferizationOptions &options,
|
||||||
bool copyBeforeWrite,
|
bool copyBeforeWrite,
|
||||||
const OpFilter *opFilter) {
|
const OpFilter *opFilter) {
|
||||||
assert(options.unknownTypeConversion !=
|
|
||||||
BufferizationOptions::LayoutMapOption::InferLayoutMap &&
|
|
||||||
"invalid layout map option");
|
|
||||||
|
|
||||||
if (copyBeforeWrite) {
|
if (copyBeforeWrite) {
|
||||||
AnalysisState state(options);
|
AnalysisState state(options);
|
||||||
if (failed(insertTensorCopies(op, state)))
|
if (failed(insertTensorCopies(op, state)))
|
||||||
|
@ -474,8 +488,11 @@ BufferizationOptions bufferization::getPartialBufferizationOptions() {
|
||||||
options.allowUnknownOps = true;
|
options.allowUnknownOps = true;
|
||||||
options.createDeallocs = false;
|
options.createDeallocs = false;
|
||||||
options.enforceAliasingInvariants = false;
|
options.enforceAliasingInvariants = false;
|
||||||
options.unknownTypeConversion =
|
options.unknownTypeConverterFn = [](Value value, unsigned memorySpace,
|
||||||
BufferizationOptions::LayoutMapOption::IdentityLayoutMap;
|
const BufferizationOptions &options) {
|
||||||
|
return getMemRefTypeWithStaticIdentityLayout(
|
||||||
|
value.getType().cast<TensorType>(), memorySpace);
|
||||||
|
};
|
||||||
options.opFilter.allowDialect<BufferizationDialect>();
|
options.opFilter.allowDialect<BufferizationDialect>();
|
||||||
return options;
|
return options;
|
||||||
}
|
}
|
||||||
|
|
|
@ -67,7 +67,7 @@ struct CastOpInterface
|
||||||
|
|
||||||
// Compute the new memref type.
|
// Compute the new memref type.
|
||||||
Type resultMemRefType =
|
Type resultMemRefType =
|
||||||
getMemRefType(resultTensorType, options, layout,
|
getMemRefType(castOp.getResult(), options, layout,
|
||||||
sourceMemRefType.getMemorySpaceAsInt());
|
sourceMemRefType.getMemorySpaceAsInt());
|
||||||
|
|
||||||
// Replace the op with a memref.cast.
|
// Replace the op with a memref.cast.
|
||||||
|
@ -780,9 +780,8 @@ struct ReshapeOpInterface
|
||||||
getBuffer(rewriter, reshapeOp.getShape(), options);
|
getBuffer(rewriter, reshapeOp.getShape(), options);
|
||||||
if (failed(srcBuffer) || failed(shapeBuffer))
|
if (failed(srcBuffer) || failed(shapeBuffer))
|
||||||
return failure();
|
return failure();
|
||||||
auto resultTensorType = reshapeOp.getResult().getType().cast<TensorType>();
|
|
||||||
auto resultMemRefType = getMemRefType(
|
auto resultMemRefType = getMemRefType(
|
||||||
resultTensorType, options, /*layout=*/{},
|
reshapeOp.getResult(), options, /*layout=*/{},
|
||||||
srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
|
srcBuffer->getType().cast<BaseMemRefType>().getMemorySpaceAsInt());
|
||||||
replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
|
replaceOpWithNewBufferizedOp<memref::ReshapeOp>(
|
||||||
rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
|
rewriter, op, resultMemRefType, *srcBuffer, *shapeBuffer);
|
||||||
|
|
Loading…
Reference in New Issue