forked from OSchip/llvm-project
[mlir] Add a flag to allow equivalent results.
Differential Revision: https://reviews.llvm.org/D124931
This commit is contained in:
parent
62b2a47a9f
commit
e8f7d019fc
|
@ -47,6 +47,10 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
|
|||
/// Specifies whether returning newly allocated memrefs should be allowed.
|
||||
/// Otherwise, a pass failure is triggered.
|
||||
bool allowReturnAllocs = false;
|
||||
|
||||
/// Specifies whether buffer return values that are equivalent to a FuncOp
|
||||
/// bbArg should be dropped.
|
||||
bool dropEquivalentFuncResults = true;
|
||||
};
|
||||
|
||||
/// The BufferizationAliasInfo class maintains a list of buffer aliases and
|
||||
|
|
|
@ -230,6 +230,9 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
|
|||
`test-analysis-only`.
|
||||
}];
|
||||
let options = [
|
||||
Option<"dropEquivalentFuncResults", "drop-equivalent-func-results", "bool",
|
||||
/*default=*/"true",
|
||||
"Drop buffer return values that are equivalent to a FuncOp arg.">,
|
||||
Option<"allowReturnAllocs", "allow-return-allocs", "bool",
|
||||
/*default=*/"false",
|
||||
"Allows returning/yielding new allocations from a block.">,
|
||||
|
|
|
@ -169,6 +169,7 @@ struct OneShotBufferizePass
|
|||
if (!options) {
|
||||
// Make new bufferization options if none were provided when creating the
|
||||
// pass.
|
||||
opt.dropEquivalentFuncResults = dropEquivalentFuncResults;
|
||||
opt.allowReturnAllocs = allowReturnAllocs;
|
||||
opt.allowUnknownOps = allowUnknownOps;
|
||||
opt.alwaysAliasingWithDest = alwaysAliasingWithDest;
|
||||
|
|
|
@ -269,17 +269,19 @@ struct CallOpInterface
|
|||
continue;
|
||||
}
|
||||
|
||||
if (Optional<int64_t> bbArgIdx =
|
||||
getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
|
||||
// Return operands that are equivalent to some bbArg, are not
|
||||
// returned.
|
||||
FailureOr<Value> bufferOrFailure =
|
||||
state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
|
||||
if (failed(bufferOrFailure))
|
||||
return failure();
|
||||
replacementValues[returnValIdx] = *bufferOrFailure;
|
||||
newOperands[*bbArgIdx] = *bufferOrFailure;
|
||||
continue;
|
||||
if (options.dropEquivalentFuncResults) {
|
||||
if (Optional<int64_t> bbArgIdx =
|
||||
getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
|
||||
// Return operands that are equivalent to some bbArg, are not
|
||||
// returned.
|
||||
FailureOr<Value> bufferOrFailure =
|
||||
state.getBuffer(rewriter, callOp->getOpOperand(*bbArgIdx));
|
||||
if (failed(bufferOrFailure))
|
||||
return failure();
|
||||
replacementValues[returnValIdx] = *bufferOrFailure;
|
||||
newOperands[*bbArgIdx] = *bufferOrFailure;
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
if (!options.allowReturnAllocs)
|
||||
|
@ -404,7 +406,8 @@ struct FuncOpInterface
|
|||
FunctionType funcType = funcOp.getFunctionType();
|
||||
const FuncAnalysisState &funcState =
|
||||
getFuncAnalysisState(state.getAnalysisState());
|
||||
const BufferizationOptions &options = state.getOptions();
|
||||
const OneShotBufferizationOptions &options =
|
||||
static_cast<const OneShotBufferizationOptions &>(state.getOptions());
|
||||
|
||||
// Construct the bufferized function type.
|
||||
SmallVector<Type> argTypes;
|
||||
|
@ -479,20 +482,23 @@ struct FuncOpInterface
|
|||
}
|
||||
|
||||
// If return operand is equivalent to some bbArg, no need to return it.
|
||||
if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
|
||||
funcOp, funcState, returnOperand.getOperandNumber())) {
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
Location loc = returnOp.getLoc();
|
||||
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
|
||||
loc, getMemRefType(returnVal.getType().cast<TensorType>(), options),
|
||||
returnVal);
|
||||
BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
|
||||
// Note: This copy will fold away. It must be inserted here to ensure
|
||||
// that `returnVal` still has at least one use and does not fold away.
|
||||
if (failed(
|
||||
createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
|
||||
return funcOp->emitError("could not generate copy for bbArg");
|
||||
continue;
|
||||
if (options.dropEquivalentFuncResults) {
|
||||
if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
|
||||
funcOp, funcState, returnOperand.getOperandNumber())) {
|
||||
rewriter.setInsertionPoint(returnOp);
|
||||
Location loc = returnOp.getLoc();
|
||||
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
|
||||
loc,
|
||||
getMemRefType(returnVal.getType().cast<TensorType>(), options),
|
||||
returnVal);
|
||||
BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
|
||||
// Note: This copy will fold away. It must be inserted here to ensure
|
||||
// that `returnVal` still has at least one use and does not fold away.
|
||||
if (failed(
|
||||
createMemCpy(rewriter, loc, toMemrefOp, equivBbArg, options)))
|
||||
return funcOp->emitError("could not generate copy for bbArg");
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
returnValues.push_back(*state.getBuffer(rewriter, returnOperand));
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs" -split-input-file | FileCheck %s
|
||||
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs drop-equivalent-func-results=false" -split-input-file | FileCheck %s --check-prefix=EQUIV
|
||||
|
||||
// Run fuzzer with different seeds.
|
||||
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
|
||||
|
@ -62,3 +63,16 @@ func.func @main(%t: tensor<?xf32>, %sz: index, %idx: index) -> (f32, f32) {
|
|||
%r2 = tensor.extract %filled[%idx] : tensor<?xf32>
|
||||
return %r1, %r2 : f32, f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func.func @return_arg(%A: tensor<?xf32>) -> tensor<?xf32> {
|
||||
func.return %A : tensor<?xf32>
|
||||
}
|
||||
// CHECK-LABEL: func @return_arg
|
||||
// CHECK-SAME: %[[A:.*]]: memref<?xf32
|
||||
// CHECK-NOT: return %[[A]]
|
||||
|
||||
// EQUIV-LABEL: func @return_arg
|
||||
// EQUIV-SAME: %[[A:.*]]: memref<?xf32
|
||||
// EQUIV: return %[[A]]
|
||||
|
|
Loading…
Reference in New Issue