[mlir] Add a flag to allow equivalent results.

Differential Revision: https://reviews.llvm.org/D124931
This commit is contained in:
Alexander Belyaev 2022-05-04 17:46:17 +02:00
parent 62b2a47a9f
commit e8f7d019fc
5 changed files with 54 additions and 26 deletions

View File

@ -47,6 +47,10 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
/// Specifies whether returning newly allocated memrefs should be allowed.
/// Otherwise, a pass failure is triggered.
bool allowReturnAllocs = false;
/// Specifies whether buffer return values that are equivalent to a FuncOp
/// bbArg should be dropped.
bool dropEquivalentFuncResults = true;
};
/// The BufferizationAliasInfo class maintains a list of buffer aliases and

View File

@ -230,6 +230,9 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
`test-analysis-only`.
}];
let options = [
Option<"dropEquivalentFuncResults", "drop-equivalent-func-results", "bool",
/*default=*/"true",
"Drop buffer return values that are equivalent to a FuncOp arg.">,
Option<"allowReturnAllocs", "allow-return-allocs", "bool",
/*default=*/"false",
"Allows returning/yielding new allocations from a block.">,

View File

@ -169,6 +169,7 @@ struct OneShotBufferizePass
if (!options) {
// Make new bufferization options if none were provided when creating the
// pass.
opt.dropEquivalentFuncResults = dropEquivalentFuncResults;
opt.allowReturnAllocs = allowReturnAllocs;
opt.allowUnknownOps = allowUnknownOps;
opt.alwaysAliasingWithDest = alwaysAliasingWithDest;

View File

@ -269,17 +269,19 @@ struct CallOpInterface
continue;
}
if (Optional<int64_t> bbArgIdx =
getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
// Return operands that are equivalent to some bbArg, are not
// returned.
FailureOr<Value> bufferOrFailure =
state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
if (failed(bufferOrFailure))
return failure();
replacementValues[returnValIdx] = *bufferOrFailure;
newOperands[*bbArgIdx] = *bufferOrFailure;
continue;
if (options.dropEquivalentFuncResults) {
if (Optional<int64_t> bbArgIdx =
getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
// Return operands that are equivalent to some bbArg, are not
// returned.
FailureOr<Value> bufferOrFailure =
state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
if (failed(bufferOrFailure))
return failure();
replacementValues[returnValIdx] = *bufferOrFailure;
newOperands[*bbArgIdx] = *bufferOrFailure;
continue;
}
}
if (!options.allowReturnAllocs)
@ -404,7 +406,8 @@ struct FuncOpInterface
FunctionType funcType = funcOp.getFunctionType();
const FuncAnalysisState &funcState =
getFuncAnalysisState(state.getAnalysisState());
const BufferizationOptions &options = state.getOptions();
const OneShotBufferizationOptions &options =
static_cast<const OneShotBufferizationOptions &>(state.getOptions());
// Construct the bufferized function type.
SmallVector<Type> argTypes;
@ -479,20 +482,23 @@ struct FuncOpInterface
}
// If return operand is equivalent to some bbArg, no need to return it.
if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
funcOp, funcState, returnOperand.getOperandNumber())) {
rewriter.setInsertionPoint(returnOp);
Location loc = returnOp.getLoc();
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
loc, getMemRefType(returnVal.getType().cast<TensorType>(), options),
returnVal);
BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
// Note: This copy will fold away. It must be inserted here to ensure
// that `returnVal` still has at least one use and does not fold away.
if (failed(
createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
return funcOp->emitError("could not generate copy for bbArg");
continue;
if (options.dropEquivalentFuncResults) {
if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
funcOp, funcState, returnOperand.getOperandNumber())) {
rewriter.setInsertionPoint(returnOp);
Location loc = returnOp.getLoc();
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
loc,
getMemRefType(returnVal.getType().cast<TensorType>(), options),
returnVal);
BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
// Note: This copy will fold away. It must be inserted here to ensure
// that `returnVal` still has at least one use and does not fold away.
if (failed(
createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
return funcOp->emitError("could not generate copy for bbArg");
continue;
}
}
returnValues.push_back(*state.getBuffer(rewriter, returnOperand));

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs" -split-input-file | FileCheck %s
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs drop-equivalent-func-results=false" -split-input-file | FileCheck %s --check-prefix=EQUIV
// Run fuzzer with different seeds.
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
@ -62,3 +63,16 @@ func.func @main(%t: tensor<?xf32>, %sz: index, %idx: index) -> (f32, f32) {
%r2 = tensor.extract %filled[%idx] : tensor<?xf32>
return %r1, %r2 : f32, f32
}
// -----
func.func @return_arg(%A: tensor<?xf32>) -> tensor<?xf32> {
func.return %A : tensor<?xf32>
}
// CHECK-LABEL: func @return_arg
// CHECK-SAME: %[[A:.*]]: memref<?xf32
// CHECK-NOT: return %[[A]]
// EQUIV-LABEL: func @return_arg
// EQUIV-SAME: %[[A:.*]]: memref<?xf32
// EQUIV: return %[[A]]