[mlir][linalg][bufferize][NFC] Collect equivalent FuncOp BBArgs in PostAnalysisStep

Collect equivalent BBArgs right after the equivalence analysis of the FuncOp and before bufferizing. This is in preparation of decoupling bufferization from aliasInfo.

Also gather equivalence info for CallOps, which was missing in the
previous commit.

Differential Revision: https://reviews.llvm.org/D114847
This commit is contained in:
Matthias Springer 2021-12-06 17:25:13 +09:00
parent b15d77928e
commit cb4d0bf997
6 changed files with 300 additions and 57 deletions

View File

@ -71,6 +71,8 @@ struct PostAnalysisStep {
SmallVector<Operation *> &newOps) = 0;
};
using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
/// Options for ComprehensiveBufferize.
struct BufferizationOptions {
BufferizationOptions();
@ -107,7 +109,7 @@ struct BufferizationOptions {
bool testAnalysisOnly = false;
/// Registered post analysis steps.
std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
PostAnalysisStepList postAnalysisSteps;
};
/// Specify fine-grain relationship between buffers to enable more analysis.

View File

@ -18,13 +18,16 @@ namespace comprehensive_bufferize {
struct BufferizationOptions;
struct BufferizationState;
struct PostAnalysisStep;
/// Bufferize the given function. Does not bufferize the function boundary.
/// Reuses an existing BufferizationState object.
// TODO: This function is meant to be called from ModuleBufferize and not can
// not yet be called standalone.
LogicalResult runComprehensiveBufferize(FuncOp funcOp,
const BufferizationOptions &options,
BufferizationState &state);
LogicalResult runComprehensiveBufferize(
FuncOp funcOp, const BufferizationOptions &options,
BufferizationState &state,
const std::vector<std::unique_ptr<PostAnalysisStep>> &extraSteps);
} // namespace comprehensive_bufferize
} // namespace linalg

View File

@ -726,7 +726,7 @@ annotateOpsWithBufferizationMarkers(Operation *op,
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
FuncOp funcOp, const BufferizationOptions &options,
BufferizationState &state) {
BufferizationState &state, const PostAnalysisStepList &extraSteps) {
DominanceInfo domInfo(funcOp);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
@ -744,16 +744,23 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
return failure();
equivalenceAnalysis(op, aliasInfo);
for (const std::unique_ptr<PostAnalysisStep> &step :
options.postAnalysisSteps) {
SmallVector<Operation *> newOps;
if (failed(step->run(funcOp, state, newOps)))
return failure();
// Analyze ops that were created by the PostAnalysisStep.
if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
return failure();
equivalenceAnalysis(newOps, aliasInfo);
}
auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
SmallVector<Operation *> newOps;
if (failed(step->run(funcOp, state, newOps)))
return failure();
// Analyze ops that were created by the PostAnalysisStep.
if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
return failure();
equivalenceAnalysis(newOps, aliasInfo);
}
return success();
};
if (failed(runPostAnalysisSteps(extraSteps)))
return failure();
if (failed(runPostAnalysisSteps(options.postAnalysisSteps)))
return failure();
// Annotate operations if we only want to report the analysis.
if (options.testAnalysisOnly) {

View File

@ -33,8 +33,9 @@ struct ModuleBufferizationState : public DialectBufferizationState {
/// A map for looking up bufferized function types.
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
/// A mapping of return values to equivalent BlockArguments.
DenseMap<Value, BlockArgument> equivalentReturnValToBBArg;
/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
/// indices.
DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
};
} // namespace
@ -44,6 +45,70 @@ getModuleBufferizationState(BufferizationState &state) {
StandardOpsDialect::getDialectNamespace());
}
/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
ReturnOp returnOp;
for (Block &b : funcOp.body()) {
if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
if (returnOp)
return nullptr;
returnOp = candidateOp;
}
}
return returnOp;
}
namespace {
/// Store function BlockArguments that are equivalent to a returned value in
/// ModuleBufferizationState.
struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
/// Annotate IR with the results of the analysis. For testing purposes only.
static void annotateReturnOp(OpOperand &returnVal, BlockArgument bbArg) {
const char *kEquivalentArgsAttr = "__equivalent_func_args__";
Operation *op = returnVal.getOwner();
SmallVector<int64_t> equivBbArgs;
if (op->hasAttr(kEquivalentArgsAttr)) {
auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
return a.cast<IntegerAttr>().getValue().getSExtValue();
}));
} else {
equivBbArgs.append(op->getNumOperands(), -1);
}
equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
OpBuilder b(op->getContext());
op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
}
LogicalResult run(FuncOp funcOp, BufferizationState &state,
SmallVector<Operation *> &newOps) override {
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
// Support only single return-terminated block in the function.
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
for (OpOperand &returnVal : returnOp->getOpOperands())
if (returnVal.get().getType().isa<RankedTensorType>())
for (BlockArgument bbArg : funcOp.getArguments())
if (bbArg.getType().isa<RankedTensorType>())
if (state.aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
bbArg)) {
moduleState
.equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
bbArg.getArgNumber();
if (state.options.testAnalysisOnly)
annotateReturnOp(returnVal, bbArg);
}
return success();
}
};
} // namespace
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
/// If `value` is a memref::CastOp, return its source. Otherwise, return
@ -73,20 +138,6 @@ static FuncOp getCalledFunction(CallOpInterface callOp) {
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
}
/// Return the unique ReturnOp that terminates `funcOp`.
/// Return nullptr if there is no such unique ReturnOp.
static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
ReturnOp returnOp;
for (Block &b : funcOp.body()) {
if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
if (returnOp)
return nullptr;
returnOp = candidateOp;
}
}
return returnOp;
}
/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
/// tensor is replaced by the corresponding buffer type.
/// In order for all the callers to agree, this *must* bufferize to the most
@ -128,22 +179,30 @@ static FunctionType getOrCreateBufferizedFunctionType(
return it2.first->second;
}
/// Store function BlockArguments that are equivalent to a returned value in
/// the given ModuleBufferizationState.
static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
BufferizationState &state) {
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
/// Gather equivalence info of CallOps.
/// Note: This only adds new equivalence info if `funcOp` was already analyzed.
// TODO: This does not handle cyclic function call graphs etc.
static void equivalenceAnalysis(FuncOp funcOp,
BufferizationAliasInfo &aliasInfo,
ModuleBufferizationState &moduleState) {
funcOp->walk([&](CallOp callOp) {
FuncOp calledFunction = getCalledFunction(callOp);
assert(calledFunction && "could not retrieved called FuncOp");
// Support only single return-terminated block in the function.
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
// No equivalence info available for the called function.
if (!moduleState.equivalentFuncArgs.count(calledFunction))
return WalkResult::skip();
for (Value returnVal : returnOp.operands())
if (returnVal.getType().isa<RankedTensorType>())
for (BlockArgument bbArg : funcOp.getArguments())
if (bbArg.getType().isa<RankedTensorType>())
if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
moduleState.equivalentReturnValToBBArg[returnVal] = bbArg;
for (auto it : moduleState.equivalentFuncArgs[calledFunction]) {
int64_t returnIdx = it.first;
int64_t bbargIdx = it.second;
Value returnVal = callOp.getResult(returnIdx);
Value argVal = callOp->getOperand(bbargIdx);
aliasInfo.unionEquivalenceClasses(returnVal, argVal);
}
return WalkResult::advance();
});
}
/// Rewrite the `funcOp` arguments analysis return values and terminator into
@ -217,7 +276,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
}
// If return operand is equivalent to some bbArg, no need to return it.
if (moduleState.equivalentReturnValToBBArg.count(returnVal))
if (moduleState.equivalentFuncArgs[funcOp].count(
returnOperand.getOperandNumber()))
continue;
// Cast values at the call site if necessary.
@ -493,12 +553,12 @@ struct CallOpInterface
}
// If return operand is equivalent to some bbArg, no need to return it.
Value returnVal = returnOperand.get();
if (moduleState.equivalentReturnValToBBArg.count(returnVal)) {
BlockArgument bbArg =
moduleState.equivalentReturnValToBBArg[returnVal];
if (moduleState.equivalentFuncArgs[funcOp].count(
returnOperand.getOperandNumber())) {
int64_t idx =
moduleState
.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
int64_t idx = bbArg.getArgNumber();
Value buffer = state.lookupBuffer(callOp->getOperand(idx));
// Add CallOp operand/result equivalence: this is interprocedural
// info.
@ -661,6 +721,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
return failure();
BufferizationState state(moduleOp, options);
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
// Interestingly, all function args that are not visible outside of a module
@ -692,11 +753,17 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
aliasInfo.setBufferizesToWritableMemory(bbArg);
}
// Analyze and bufferize funcOp.
if (failed(runComprehensiveBufferize(funcOp, options, state)))
return failure();
// Register extra post analysis steps. These cannot be stored in `options`
// because `options` is immutable.
PostAnalysisStepList extraSteps;
extraSteps.emplace_back(std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
populateEquivalentFuncOpBBArgs(funcOp, state);
// Gather equivalence info for CallOps.
equivalenceAnalysis(funcOp, aliasInfo, moduleState);
// Analyze and bufferize funcOp.
if (failed(runComprehensiveBufferize(funcOp, options, state, extraSteps)))
return failure();
}
if (options.testAnalysisOnly)

View File

@ -40,15 +40,17 @@ func @insert_slice_fun(
-> (tensor<?xf32>, tensor<?xf32>)
{
// must bufferize out of place.
// CHECK: tensor.insert_slice
// CHECK: tensor.insert_slice
// CHECK-SAME: {__inplace_results_attr__ = ["false"]}
%r0 = tensor.insert_slice %C into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
// bufferizes inplace.
// CHECK: tensor.insert_slice
// CHECK: tensor.insert_slice
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
%r1 = tensor.insert_slice %C into %B[0][4][1] : tensor<4xf32> into tensor<?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [-1, 1]}
return %r0, %r1: tensor<?xf32>, tensor<?xf32>
}
@ -81,6 +83,8 @@ func @conflict_on_B(
outs(%B: tensor<4x4xf32>)
-> tensor<4x4xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [-1, -1, 1]}
return %C, %D, %E: tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>
}
@ -136,6 +140,8 @@ func @insert_slice_insert_slice(
// CHECK: {__inplace_results_attr__ = ["false"]}
%r3 = tensor.insert_slice %r2 into %B[0][4][1] : tensor<4xf32> into tensor<?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
return %r1, %r3: tensor<?xf32>, tensor<?xf32>
}
@ -172,6 +178,8 @@ func @extract_slice_nonmatching_insert_slice(
// CHECK-SAME: {__inplace_results_attr__ = ["false"]}
%r3 = tensor.insert_slice %r2 into %B[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
return %r1, %r3: tensor<?xf32>, tensor<?xf32>
}
@ -208,6 +216,8 @@ func @extract_slice_matching_insert_slice(
// CHECK-SAME: {__inplace_results_attr__ = ["false"]}
%r3 = tensor.insert_slice %r2 into %B[0][4][1] : tensor<4xf32> into tensor<?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
return %r1, %r3: tensor<?xf32>, tensor<?xf32>
}
@ -234,6 +244,9 @@ func @read_of_matching_insert_slice_source(
%2 = tensor.insert_slice %1 into %A[%idx][%idx][1] : tensor<?xf32> into tensor<?xf32>
%3 = vector.transfer_read %1[%idx2], %cst2 : tensor<?xf32>, vector<5xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
return %2, %3 : tensor<?xf32>, vector<5xf32>
}
@ -274,6 +287,8 @@ func @read_of_matching_insert_slice_source_interleaved(
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
%6 = tensor.insert_slice %5 into %2[%idx3][%idx3][1] : tensor<?xf32> into tensor<?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
return %6, %3 : tensor<?xf32>, vector<5xf32>
}
@ -306,6 +321,8 @@ func @extract_slice_linalg_readonly_use(
outs(%C: tensor<4x4xf32>)
-> tensor<4x4xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [-1, 2]}
return %D, %E: tensor<4x4xf32>, tensor<4x4xf32>
}
@ -372,6 +389,8 @@ func @insert_slice_double_extract_slice(
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
%20 = tensor.insert_slice %19 into %C[%s3, %s4] [%s1, %s2] [1, 1] : tensor<?x?xf32> into tensor<30x20xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [6]}
return %20 : tensor<30x20xf32>
}
@ -504,6 +523,8 @@ func @nested_extract_slice_and_insert(
%rsC = tensor.insert_slice %FC into %sC[0, 0][%sz2, 4][1, 1] : tensor<?x4xf32> into tensor<?x?xf32>
%rC = tensor.insert_slice %rsC into %C[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [-1, 1, 2]}
return %rA, %rB, %rC: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
}
@ -533,6 +554,8 @@ func @scf_for_yield_only(%A : tensor<?xf32>,
scf.yield %t : tensor<?xf32>
}
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [-1, 1]}
return %r0, %r1: tensor<?xf32>, tensor<?xf32>
}
@ -564,6 +587,8 @@ func @scf_for_with_tensor.insert_slice(%A : tensor<?xf32>,
scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
}
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [-1, 1]}
return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
}
@ -623,6 +648,8 @@ func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
linalg.yield %t : tensor<?xf32>
}
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0, 1]}
return %r1, %r3: tensor<?xf32>, tensor<?xf32>
}
@ -768,6 +795,8 @@ builtin.func @matmul_on_tensors(
ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [2]}
return %r : tensor<256x256xf32>
}
@ -813,6 +842,8 @@ builtin.func @matmul_on_tensors(
ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [2]}
return %r : tensor<256x256xf32>
}
@ -858,6 +889,8 @@ func @insert_slice_chain(
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
%15 = tensor.insert_slice %14 into %8[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [4]}
return %15 : tensor<62x90xf32>
}
@ -883,6 +916,9 @@ func @ip(%t: tensor<10x20xf32> {linalg.inplaceable = true},
%t3 = tensor.insert_slice %t2 into %arg1[%x, 0] [5, %y] [1, 1] : tensor<5x?xf32> into tensor<10x20xf32>
scf.yield %t3 : tensor<10x20xf32>
}
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0]}
return %r : tensor<10x20xf32>
}
@ -910,6 +946,9 @@ func @linalg_op_same_out_tensors(
^bb(%0: f32, %1: f32, %2 : f32) :
linalg.yield %0, %0 : f32, f32
} -> (tensor<?xf32>, tensor<?xf32>)
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [1, -1]}
return %o#0, %o#1 : tensor<?xf32>, tensor<?xf32>
}
@ -951,6 +990,8 @@ func @double_insert_slice_into_alias(
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
%15 = tensor.insert_slice %14 into %e[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<?x?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [2, -1]}
return %8, %15 : tensor<62x90xf32>, tensor<?x?xf32>
}
@ -980,6 +1021,8 @@ func @interleaved_extract_insert_slice_chain_1(
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
%15 = tensor.insert_slice %10 into %8[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0]}
return %15 : tensor<62x90xf32>
}
@ -1009,6 +1052,8 @@ func @interleaved_extract_insert_slice_chain_2(
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
%15 = tensor.insert_slice %10 into %8[31, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0]}
return %15 : tensor<62x90xf32>
}
@ -1031,6 +1076,8 @@ func @extract_once_insert_twice(
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
%15 = tensor.insert_slice %2 into %8[15, 0] [32, 90] [1, 1] : tensor<32x90xf32> into tensor<62x90xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0]}
return %15 : tensor<62x90xf32>
}
@ -1132,6 +1179,8 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
linalg.yield %cst : f32
} -> (tensor<?xf32>)
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
return %o, %v3 : tensor<?xf32>, vector<5xf32>
}
@ -1160,6 +1209,9 @@ func @buffer_forwarding_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = true
// CHECK: tensor.insert_slice
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
%3 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [-1, 0]}
return %2, %3 : tensor<?xf32>, tensor<?xf32>
}
@ -1180,6 +1232,9 @@ func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = t
// CHECK: tensor.insert_slice
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
%2 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0, 0]}
return %2, %2 : tensor<?xf32>, tensor<?xf32>
}
@ -1214,6 +1269,8 @@ func @scf_if_inplace2(%t1: tensor<?xf32> {linalg.inplaceable = true},
%t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
scf.yield %t2 : tensor<?xf32>
}
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0]}
return %r : tensor<?xf32>
}
@ -1263,6 +1320,9 @@ func @scf_if_in_place4(%t1: tensor<?xf32> {linalg.inplaceable = true},
scf.yield %r : tensor<?xf32>
}
%v2 = vector.transfer_read %r_alias[%idx], %cst : tensor<?xf32>, vector<10xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
return %r_alias, %v2 : tensor<?xf32>, vector<10xf32>
}
@ -1288,6 +1348,9 @@ func @scf_if_inplace5(%t1: tensor<?xf32> {linalg.inplaceable = true},
// CHECK: tensor.insert_slice
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
%r2 = tensor.insert_slice %r into %t1[%idx][%idx][1] : tensor<?xf32> into tensor<?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0]}
return %r2 : tensor<?xf32>
}
@ -1318,6 +1381,9 @@ func @scf_if_inplace6(%t1: tensor<?xf32> {linalg.inplaceable = true},
%t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
scf.yield %t3 : tensor<?xf32>
}
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0]}
return %r : tensor<?xf32>
}
@ -1396,6 +1462,9 @@ func @scf_if_out_of_place1b(%t1: tensor<?xf32> {linalg.inplaceable = true},
// CHECK: tensor.insert_slice
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
%r2 = tensor.insert_slice %r into %t1[%idx3][%idx3][1] : tensor<?xf32> into tensor<?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0]}
return %r2 : tensor<?xf32>
}
@ -1420,6 +1489,9 @@ func @scf_if_out_of_place1c(%t1: tensor<?xf32> {linalg.inplaceable = true},
// CHECK: tensor.insert_slice
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
%r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor<?xf32> into tensor<?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0]}
return %r2 : tensor<?xf32>
}
@ -1533,3 +1605,44 @@ func @to_memref_op_is_reading(%t1: tensor<?xf32> {linalg.inplaceable = true},
return %r1, %r2 : vector<5xf32>, vector<5xf32>
}
// -----
// CHECK-LABEL: func @inner_func
func @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0]}
return %t : tensor<?xf32>
}
func @equivalent_func_arg(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf32>) -> tensor<?xf32> {
// This test does not check IR. It just asserts there is no failure due to
// non-equivalent scf.for yield values.
%1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
%3 = call @inner_func(%t1) : (tensor<?xf32>) -> tensor<?xf32>
scf.yield %3 : tensor<?xf32>
}
return %1: tensor<?xf32>
}
// -----
// CHECK-LABEL: func @inner_func_2
func @inner_func_2(%t: tensor<?xf32>) -> tensor<?xf32> {
%f = arith.constant 1.0 : f32
%c0 = arith.constant 0 : index
%0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
// CHECK: return
// CHECK-SAME: {__equivalent_func_args__ = [0]}
return %0 : tensor<?xf32>
}
func @equivalent_func_arg_2(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf32>) -> tensor<?xf32> {
// This test does not check IR. It just asserts there is no failure due to
// non-equivalent scf.for yield values.
%1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
%3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
scf.yield %3 : tensor<?xf32>
}
return %1: tensor<?xf32>
}

View File

@ -928,3 +928,54 @@ func @insert_op(%t1 : tensor<?xf32> {linalg.inplaceable = true},
// CHECK: return
return %0 : tensor<?xf32>
}
// -----
// CHECK-LABEL: func @inner_func(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
func @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
%f = arith.constant 1.0 : f32
%c0 = arith.constant 0 : index
// CHECK: memref.store %{{.*}}, %[[arg0]]
%0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @equivalent_func_arg(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
func @equivalent_func_arg(%t0: tensor<?xf32> {linalg.inplaceable = true},
%c0: index, %c10: index, %c1: index) -> tensor<?xf32> {
// CHECK-NOT: copy
%1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
// CHECK: call @inner_func(%[[arg0]])
%3 = call @inner_func(%t1) : (tensor<?xf32>) -> tensor<?xf32>
scf.yield %3 : tensor<?xf32>
}
return %1: tensor<?xf32>
}
// -----
// CHECK-LABEL: func @inner_func_2(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
func @inner_func_2(%t: tensor<?xf32>) -> tensor<?xf32> {
%f = arith.constant 1.0 : f32
%c0 = arith.constant 0 : index
// CHECK: memref.store %{{.*}}, %[[arg0]]
%0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
return %0 : tensor<?xf32>
}
// CHECK-LABEL: func @equivalent_func_arg_2(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
%c0: index, %c10: index, %c1: index) -> tensor<?xf32> {
%1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
// TODO: There should be a memory copy here. This is a bug in CallOp
// bufferization.
// CHECK: call @inner_func_2(%[[arg0]])
%3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
scf.yield %t1 : tensor<?xf32>
}
return %1: tensor<?xf32>
}