[mlir][linalg][bufferize][NFC] Do not cache bufferized function types

This does not work if BufferizationState is passed around as a const reference in most places.

Differential Revision: https://reviews.llvm.org/D116741
This commit is contained in:
Matthias Springer 2022-01-06 23:58:20 +09:00
parent fb9bfb2c59
commit cd84cf90e9
1 changed files with 6 additions and 29 deletions

View File

@ -24,9 +24,6 @@ namespace {
/// Extra bufferization state that is required for bufferization of function
/// boundaries.
struct ModuleBufferizationState : public DialectBufferizationState {
/// A map for looking up bufferized function types.
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
/// indices.
DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
@ -161,23 +158,6 @@ static FunctionType getBufferizedFunctionType(MLIRContext *ctx,
return FunctionType::get(ctx, argTypes, retTypes);
}
/// If an entry for `funcOp` is available in `bufferizedFunctionTypes`, return
/// it. Otherwise, construct a new entry based on `argumentTypes` and
/// `resultTypes`.
// TODO: improve the layering.
static FunctionType getOrCreateBufferizedFunctionType(
FuncOp funcOp, TypeRange argumentTypes, TypeRange resultTypes,
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
auto it = bufferizedFunctionTypes.find(funcOp);
if (it != bufferizedFunctionTypes.end())
return it->second;
auto it2 = bufferizedFunctionTypes.try_emplace(
funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes,
resultTypes));
return it2.first->second;
}
/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if `funcOp` was already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
@ -250,9 +230,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
return funcOp->emitError() << "cannot bufferize bodiless function that "
<< "returns a tensor";
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
funcOp, funcOp.getType().getInputs(), TypeRange{},
moduleState.bufferizedFunctionTypes);
FunctionType bufferizedFuncType = getBufferizedFunctionType(
funcOp.getContext(), funcOp.getType().getInputs(), TypeRange{});
funcOp.setType(bufferizedFuncType);
return success();
}
@ -284,9 +263,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
// 2. Rewrite the terminator without the inPlace bufferizable values.
ValueRange retValues{returnValues};
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
moduleState.bufferizedFunctionTypes);
FunctionType bufferizedFuncType = getBufferizedFunctionType(
funcOp.getContext(), funcOp.getType().getInputs(), retValues.getTypes());
OpBuilder b(returnOp);
b.create<ReturnOp>(returnOp.getLoc(), returnValues);
returnOp->erase();
@ -590,9 +568,8 @@ struct CallOpInterface
SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
// Get the bufferized FunctionType for funcOp or construct it if not yet
// available.
FunctionType bufferizedFuncType =
getOrCreateBufferizedFunctionType(funcOp, argumentTypes, resultTypes,
moduleState.bufferizedFunctionTypes);
FunctionType bufferizedFuncType = getBufferizedFunctionType(
funcOp.getContext(), argumentTypes, resultTypes);
// 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
for (OpOperand &opOperand : callOp->getOpOperands()) {