forked from OSchip/llvm-project
[mlir][linalg][bufferize][NFC] Do not cache bufferized function types
This does not work if BufferizationState is passed around as a const reference in most places. Differential Revision: https://reviews.llvm.org/D116741
This commit is contained in:
parent
fb9bfb2c59
commit
cd84cf90e9
|
@ -24,9 +24,6 @@ namespace {
|
|||
/// Extra bufferization state that is required for bufferization of function
|
||||
/// boundaries.
|
||||
struct ModuleBufferizationState : public DialectBufferizationState {
|
||||
/// A map for looking up bufferized function types.
|
||||
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
|
||||
|
||||
/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
|
||||
/// indices.
|
||||
DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
|
||||
|
@ -161,23 +158,6 @@ static FunctionType getBufferizedFunctionType(MLIRContext *ctx,
|
|||
return FunctionType::get(ctx, argTypes, retTypes);
|
||||
}
|
||||
|
||||
/// If an entry for `funcOp` is available in `bufferizedFunctionTypes`, return
|
||||
/// it. Otherwise, construct a new entry based on `argumentTypes` and
|
||||
/// `resultTypes`.
|
||||
// TODO: improve the layering.
|
||||
static FunctionType getOrCreateBufferizedFunctionType(
|
||||
FuncOp funcOp, TypeRange argumentTypes, TypeRange resultTypes,
|
||||
DenseMap<FuncOp, FunctionType> &bufferizedFunctionTypes) {
|
||||
auto it = bufferizedFunctionTypes.find(funcOp);
|
||||
if (it != bufferizedFunctionTypes.end())
|
||||
return it->second;
|
||||
|
||||
auto it2 = bufferizedFunctionTypes.try_emplace(
|
||||
funcOp, getBufferizedFunctionType(funcOp.getContext(), argumentTypes,
|
||||
resultTypes));
|
||||
return it2.first->second;
|
||||
}
|
||||
|
||||
/// Gather equivalence info of CallOps.
|
||||
/// Note: This only adds new equivalence info if `funcOp` was already analyzed.
|
||||
// TODO: This does not handle cyclic function call graphs etc.
|
||||
|
@ -250,9 +230,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
|||
if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
|
||||
return funcOp->emitError() << "cannot bufferize bodiless function that "
|
||||
<< "returns a tensor";
|
||||
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
|
||||
funcOp, funcOp.getType().getInputs(), TypeRange{},
|
||||
moduleState.bufferizedFunctionTypes);
|
||||
FunctionType bufferizedFuncType = getBufferizedFunctionType(
|
||||
funcOp.getContext(), funcOp.getType().getInputs(), TypeRange{});
|
||||
funcOp.setType(bufferizedFuncType);
|
||||
return success();
|
||||
}
|
||||
|
@ -284,9 +263,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
|||
|
||||
// 2. Rewrite the terminator without the inPlace bufferizable values.
|
||||
ValueRange retValues{returnValues};
|
||||
FunctionType bufferizedFuncType = getOrCreateBufferizedFunctionType(
|
||||
funcOp, funcOp.getType().getInputs(), retValues.getTypes(),
|
||||
moduleState.bufferizedFunctionTypes);
|
||||
FunctionType bufferizedFuncType = getBufferizedFunctionType(
|
||||
funcOp.getContext(), funcOp.getType().getInputs(), retValues.getTypes());
|
||||
OpBuilder b(returnOp);
|
||||
b.create<ReturnOp>(returnOp.getLoc(), returnValues);
|
||||
returnOp->erase();
|
||||
|
@ -590,9 +568,8 @@ struct CallOpInterface
|
|||
SmallVector<Type> argumentTypes{callOp->getOperandTypes()};
|
||||
// Get the bufferized FunctionType for funcOp or construct it if not yet
|
||||
// available.
|
||||
FunctionType bufferizedFuncType =
|
||||
getOrCreateBufferizedFunctionType(funcOp, argumentTypes, resultTypes,
|
||||
moduleState.bufferizedFunctionTypes);
|
||||
FunctionType bufferizedFuncType = getBufferizedFunctionType(
|
||||
funcOp.getContext(), argumentTypes, resultTypes);
|
||||
|
||||
// 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
|
||||
for (OpOperand &opOperand : callOp->getOpOperands()) {
|
||||
|
|
Loading…
Reference in New Issue