[mlir] Add a flag to allow equivalent results.

Differential Revision: https://reviews.llvm.org/D124931
This commit is contained in:
Alexander Belyaev 2022-05-04 17:46:17 +02:00
parent 62b2a47a9f
commit e8f7d019fc
5 changed files with 54 additions and 26 deletions

View File

@ -47,6 +47,10 @@ struct OneShotBufferizationOptions : public BufferizationOptions {
/// Specifies whether returning newly allocated memrefs should be allowed. /// Specifies whether returning newly allocated memrefs should be allowed.
/// Otherwise, a pass failure is triggered. /// Otherwise, a pass failure is triggered.
bool allowReturnAllocs = false; bool allowReturnAllocs = false;
/// Specifies whether buffer return values that are equivalent to a FuncOp
/// bbArg should be dropped.
bool dropEquivalentFuncResults = true;
}; };
/// The BufferizationAliasInfo class maintains a list of buffer aliases and /// The BufferizationAliasInfo class maintains a list of buffer aliases and

View File

@ -230,6 +230,9 @@ def OneShotBufferize : Pass<"one-shot-bufferize", "ModuleOp"> {
`test-analysis-only`. `test-analysis-only`.
}]; }];
let options = [ let options = [
Option<"dropEquivalentFuncResults", "drop-equivalent-func-results", "bool",
/*default=*/"true",
"Drop buffer return values that are equivalent to a FuncOp arg.">,
Option<"allowReturnAllocs", "allow-return-allocs", "bool", Option<"allowReturnAllocs", "allow-return-allocs", "bool",
/*default=*/"false", /*default=*/"false",
"Allows returning/yielding new allocations from a block.">, "Allows returning/yielding new allocations from a block.">,

View File

@ -169,6 +169,7 @@ struct OneShotBufferizePass
if (!options) { if (!options) {
// Make new bufferization options if none were provided when creating the // Make new bufferization options if none were provided when creating the
// pass. // pass.
opt.dropEquivalentFuncResults = dropEquivalentFuncResults;
opt.allowReturnAllocs = allowReturnAllocs; opt.allowReturnAllocs = allowReturnAllocs;
opt.allowUnknownOps = allowUnknownOps; opt.allowUnknownOps = allowUnknownOps;
opt.alwaysAliasingWithDest = alwaysAliasingWithDest; opt.alwaysAliasingWithDest = alwaysAliasingWithDest;

View File

@ -269,6 +269,7 @@ struct CallOpInterface
continue; continue;
} }
if (options.dropEquivalentFuncResults) {
if (Optional<int64_t> bbArgIdx = if (Optional<int64_t> bbArgIdx =
getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) { getEquivalentFuncArgIdx(funcOp, funcState, returnValIdx)) {
// Return operands that are equivalent to some bbArg, are not // Return operands that are equivalent to some bbArg, are not
@ -281,6 +282,7 @@ struct CallOpInterface
newOperands[*bbArgIdx] = *bufferOrFailure; newOperands[*bbArgIdx] = *bufferOrFailure;
continue; continue;
} }
}
if (!options.allowReturnAllocs) if (!options.allowReturnAllocs)
return callOp->emitError( return callOp->emitError(
@ -404,7 +406,8 @@ struct FuncOpInterface
FunctionType funcType = funcOp.getFunctionType(); FunctionType funcType = funcOp.getFunctionType();
const FuncAnalysisState &funcState = const FuncAnalysisState &funcState =
getFuncAnalysisState(state.getAnalysisState()); getFuncAnalysisState(state.getAnalysisState());
const BufferizationOptions &options = state.getOptions(); const OneShotBufferizationOptions &options =
static_cast<const OneShotBufferizationOptions &>(state.getOptions());
// Construct the bufferized function type. // Construct the bufferized function type.
SmallVector<Type> argTypes; SmallVector<Type> argTypes;
@ -479,12 +482,14 @@ struct FuncOpInterface
} }
// If return operand is equivalent to some bbArg, no need to return it. // If return operand is equivalent to some bbArg, no need to return it.
if (options.dropEquivalentFuncResults) {
if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx( if (Optional<int64_t> equivBbArgIdx = getEquivalentFuncArgIdx(
funcOp, funcState, returnOperand.getOperandNumber())) { funcOp, funcState, returnOperand.getOperandNumber())) {
rewriter.setInsertionPoint(returnOp); rewriter.setInsertionPoint(returnOp);
Location loc = returnOp.getLoc(); Location loc = returnOp.getLoc();
Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>( Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
loc, getMemRefType(returnVal.getType().cast<TensorType>(), options), loc,
getMemRefType(returnVal.getType().cast<TensorType>(), options),
returnVal); returnVal);
BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx); BlockArgument equivBbArg = funcOp.getArgument(*equivBbArgIdx);
// Note: This copy will fold away. It must be inserted here to ensure // Note: This copy will fold away. It must be inserted here to ensure
@ -494,6 +499,7 @@ struct FuncOpInterface
return funcOp->emitError("could not generate copy for bbArg"); return funcOp->emitError("could not generate copy for bbArg");
continue; continue;
} }
}
returnValues.push_back(*state.getBuffer(rewriter, returnOperand)); returnValues.push_back(*state.getBuffer(rewriter, returnOperand));
} }

View File

@ -1,4 +1,5 @@
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs" -split-input-file | FileCheck %s // RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs" -split-input-file | FileCheck %s
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs drop-equivalent-func-results=false" -split-input-file | FileCheck %s --check-prefix=EQUIV
// Run fuzzer with different seeds. // Run fuzzer with different seeds.
// RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null // RUN: mlir-opt %s -one-shot-bufferize="bufferize-function-boundaries=1 allow-return-allocs test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
@ -62,3 +63,16 @@ func.func @main(%t: tensor<?xf32>, %sz: index, %idx: index) -> (f32, f32) {
%r2 = tensor.extract %filled[%idx] : tensor<?xf32> %r2 = tensor.extract %filled[%idx] : tensor<?xf32>
return %r1, %r2 : f32, f32 return %r1, %r2 : f32, f32
} }
// -----
func.func @return_arg(%A: tensor<?xf32>) -> tensor<?xf32> {
func.return %A : tensor<?xf32>
}
// CHECK-LABEL: func @return_arg
// CHECK-SAME: %[[A:.*]]: memref<?xf32
// CHECK-NOT: return %[[A]]
// EQUIV-LABEL: func @return_arg
// EQUIV-SAME: %[[A:.*]]: memref<?xf32
// EQUIV: return %[[A]]