forked from OSchip/llvm-project
[mlir][linalg][bufferize][NFC] Collect equivalent FuncOp BBArgs in PostAnalysisStep
Collect equivalent BBArgs right after the equivalence analysis of the FuncOp and before bufferizing. This is in preparation of decoupling bufferization from aliasInfo. Also gather equivalence info for CallOps, which was missing in the previous commit. Differential Revision: https://reviews.llvm.org/D114847
This commit is contained in:
parent
b15d77928e
commit
cb4d0bf997
|
@ -71,6 +71,8 @@ struct PostAnalysisStep {
|
|||
SmallVector<Operation *> &newOps) = 0;
|
||||
};
|
||||
|
||||
using PostAnalysisStepList = std::vector<std::unique_ptr<PostAnalysisStep>>;
|
||||
|
||||
/// Options for ComprehensiveBufferize.
|
||||
struct BufferizationOptions {
|
||||
BufferizationOptions();
|
||||
|
@ -107,7 +109,7 @@ struct BufferizationOptions {
|
|||
bool testAnalysisOnly = false;
|
||||
|
||||
/// Registered post analysis steps.
|
||||
std::vector<std::unique_ptr<PostAnalysisStep>> postAnalysisSteps;
|
||||
PostAnalysisStepList postAnalysisSteps;
|
||||
};
|
||||
|
||||
/// Specify fine-grain relationship between buffers to enable more analysis.
|
||||
|
|
|
@ -18,13 +18,16 @@ namespace comprehensive_bufferize {
|
|||
|
||||
struct BufferizationOptions;
|
||||
struct BufferizationState;
|
||||
struct PostAnalysisStep;
|
||||
|
||||
/// Bufferize the given function. Does not bufferize the function boundary.
|
||||
/// Reuses an existing BufferizationState object.
|
||||
// TODO: This function is meant to be called from ModuleBufferize and not can
|
||||
// not yet be called standalone.
|
||||
LogicalResult runComprehensiveBufferize(FuncOp funcOp,
|
||||
const BufferizationOptions &options,
|
||||
BufferizationState &state);
|
||||
LogicalResult runComprehensiveBufferize(
|
||||
FuncOp funcOp, const BufferizationOptions &options,
|
||||
BufferizationState &state,
|
||||
const std::vector<std::unique_ptr<PostAnalysisStep>> &extraSteps);
|
||||
|
||||
} // namespace comprehensive_bufferize
|
||||
} // namespace linalg
|
||||
|
|
|
@ -726,7 +726,7 @@ annotateOpsWithBufferizationMarkers(Operation *op,
|
|||
|
||||
LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
||||
FuncOp funcOp, const BufferizationOptions &options,
|
||||
BufferizationState &state) {
|
||||
BufferizationState &state, const PostAnalysisStepList &extraSteps) {
|
||||
|
||||
DominanceInfo domInfo(funcOp);
|
||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
||||
|
@ -744,16 +744,23 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
|||
return failure();
|
||||
equivalenceAnalysis(op, aliasInfo);
|
||||
|
||||
for (const std::unique_ptr<PostAnalysisStep> &step :
|
||||
options.postAnalysisSteps) {
|
||||
SmallVector<Operation *> newOps;
|
||||
if (failed(step->run(funcOp, state, newOps)))
|
||||
return failure();
|
||||
// Analyze ops that were created by the PostAnalysisStep.
|
||||
if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
|
||||
return failure();
|
||||
equivalenceAnalysis(newOps, aliasInfo);
|
||||
}
|
||||
auto runPostAnalysisSteps = [&](const PostAnalysisStepList &steps) {
|
||||
for (const std::unique_ptr<PostAnalysisStep> &step : steps) {
|
||||
SmallVector<Operation *> newOps;
|
||||
if (failed(step->run(funcOp, state, newOps)))
|
||||
return failure();
|
||||
// Analyze ops that were created by the PostAnalysisStep.
|
||||
if (failed(inPlaceAnalysis(newOps, aliasInfo, domInfo)))
|
||||
return failure();
|
||||
equivalenceAnalysis(newOps, aliasInfo);
|
||||
}
|
||||
return success();
|
||||
};
|
||||
|
||||
if (failed(runPostAnalysisSteps(extraSteps)))
|
||||
return failure();
|
||||
if (failed(runPostAnalysisSteps(options.postAnalysisSteps)))
|
||||
return failure();
|
||||
|
||||
// Annotate operations if we only want to report the analysis.
|
||||
if (options.testAnalysisOnly) {
|
||||
|
|
|
@ -33,8 +33,9 @@ struct ModuleBufferizationState : public DialectBufferizationState {
|
|||
/// A map for looking up bufferized function types.
|
||||
DenseMap<FuncOp, FunctionType> bufferizedFunctionTypes;
|
||||
|
||||
/// A mapping of return values to equivalent BlockArguments.
|
||||
DenseMap<Value, BlockArgument> equivalentReturnValToBBArg;
|
||||
/// A mapping of ReturnOp OpOperand indices to equivalent FuncOp BBArg
|
||||
/// indices.
|
||||
DenseMap<FuncOp, DenseMap<int64_t, int64_t>> equivalentFuncArgs;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
|
@ -44,6 +45,70 @@ getModuleBufferizationState(BufferizationState &state) {
|
|||
StandardOpsDialect::getDialectNamespace());
|
||||
}
|
||||
|
||||
/// Return the unique ReturnOp that terminates `funcOp`.
|
||||
/// Return nullptr if there is no such unique ReturnOp.
|
||||
static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
|
||||
ReturnOp returnOp;
|
||||
for (Block &b : funcOp.body()) {
|
||||
if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
|
||||
if (returnOp)
|
||||
return nullptr;
|
||||
returnOp = candidateOp;
|
||||
}
|
||||
}
|
||||
return returnOp;
|
||||
}
|
||||
|
||||
namespace {
|
||||
/// Store function BlockArguments that are equivalent to a returned value in
|
||||
/// ModuleBufferizationState.
|
||||
struct EquivalentFuncOpBBArgsAnalysis : public PostAnalysisStep {
|
||||
/// Annotate IR with the results of the analysis. For testing purposes only.
|
||||
static void annotateReturnOp(OpOperand &returnVal, BlockArgument bbArg) {
|
||||
const char *kEquivalentArgsAttr = "__equivalent_func_args__";
|
||||
Operation *op = returnVal.getOwner();
|
||||
|
||||
SmallVector<int64_t> equivBbArgs;
|
||||
if (op->hasAttr(kEquivalentArgsAttr)) {
|
||||
auto attr = op->getAttr(kEquivalentArgsAttr).cast<ArrayAttr>();
|
||||
equivBbArgs = llvm::to_vector<4>(llvm::map_range(attr, [](Attribute a) {
|
||||
return a.cast<IntegerAttr>().getValue().getSExtValue();
|
||||
}));
|
||||
} else {
|
||||
equivBbArgs.append(op->getNumOperands(), -1);
|
||||
}
|
||||
equivBbArgs[returnVal.getOperandNumber()] = bbArg.getArgNumber();
|
||||
|
||||
OpBuilder b(op->getContext());
|
||||
op->setAttr(kEquivalentArgsAttr, b.getI64ArrayAttr(equivBbArgs));
|
||||
}
|
||||
|
||||
LogicalResult run(FuncOp funcOp, BufferizationState &state,
|
||||
SmallVector<Operation *> &newOps) override {
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
|
||||
// Support only single return-terminated block in the function.
|
||||
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
|
||||
assert(returnOp && "expected func with single return op");
|
||||
|
||||
for (OpOperand &returnVal : returnOp->getOpOperands())
|
||||
if (returnVal.get().getType().isa<RankedTensorType>())
|
||||
for (BlockArgument bbArg : funcOp.getArguments())
|
||||
if (bbArg.getType().isa<RankedTensorType>())
|
||||
if (state.aliasInfo.areEquivalentBufferizedValues(returnVal.get(),
|
||||
bbArg)) {
|
||||
moduleState
|
||||
.equivalentFuncArgs[funcOp][returnVal.getOperandNumber()] =
|
||||
bbArg.getArgNumber();
|
||||
if (state.options.testAnalysisOnly)
|
||||
annotateReturnOp(returnVal, bbArg);
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
static bool isaTensor(Type t) { return t.isa<TensorType>(); }
|
||||
|
||||
/// If `value` is a memref::CastOp, return its source. Otherwise, return
|
||||
|
@ -73,20 +138,6 @@ static FuncOp getCalledFunction(CallOpInterface callOp) {
|
|||
SymbolTable::lookupNearestSymbolFrom(callOp, sym));
|
||||
}
|
||||
|
||||
/// Return the unique ReturnOp that terminates `funcOp`.
|
||||
/// Return nullptr if there is no such unique ReturnOp.
|
||||
static ReturnOp getAssumedUniqueReturnOp(FuncOp funcOp) {
|
||||
ReturnOp returnOp;
|
||||
for (Block &b : funcOp.body()) {
|
||||
if (auto candidateOp = dyn_cast<ReturnOp>(b.getTerminator())) {
|
||||
if (returnOp)
|
||||
return nullptr;
|
||||
returnOp = candidateOp;
|
||||
}
|
||||
}
|
||||
return returnOp;
|
||||
}
|
||||
|
||||
/// Return the FunctionType with `argumentTypes` and `resultTypes` where each
|
||||
/// tensor is replaced by the corresponding buffer type.
|
||||
/// In order for all the callers to agree, this *must* bufferize to the most
|
||||
|
@ -128,22 +179,30 @@ static FunctionType getOrCreateBufferizedFunctionType(
|
|||
return it2.first->second;
|
||||
}
|
||||
|
||||
/// Store function BlockArguments that are equivalent to a returned value in
|
||||
/// the given ModuleBufferizationState.
|
||||
static void populateEquivalentFuncOpBBArgs(FuncOp funcOp,
|
||||
BufferizationState &state) {
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
/// Gather equivalence info of CallOps.
|
||||
/// Note: This only adds new equivalence info if `funcOp` was already analyzed.
|
||||
// TODO: This does not handle cyclic function call graphs etc.
|
||||
static void equivalenceAnalysis(FuncOp funcOp,
|
||||
BufferizationAliasInfo &aliasInfo,
|
||||
ModuleBufferizationState &moduleState) {
|
||||
funcOp->walk([&](CallOp callOp) {
|
||||
FuncOp calledFunction = getCalledFunction(callOp);
|
||||
assert(calledFunction && "could not retrieved called FuncOp");
|
||||
|
||||
// Support only single return-terminated block in the function.
|
||||
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
|
||||
assert(returnOp && "expected func with single return op");
|
||||
// No equivalence info available for the called function.
|
||||
if (!moduleState.equivalentFuncArgs.count(calledFunction))
|
||||
return WalkResult::skip();
|
||||
|
||||
for (Value returnVal : returnOp.operands())
|
||||
if (returnVal.getType().isa<RankedTensorType>())
|
||||
for (BlockArgument bbArg : funcOp.getArguments())
|
||||
if (bbArg.getType().isa<RankedTensorType>())
|
||||
if (state.aliasInfo.areEquivalentBufferizedValues(returnVal, bbArg))
|
||||
moduleState.equivalentReturnValToBBArg[returnVal] = bbArg;
|
||||
for (auto it : moduleState.equivalentFuncArgs[calledFunction]) {
|
||||
int64_t returnIdx = it.first;
|
||||
int64_t bbargIdx = it.second;
|
||||
Value returnVal = callOp.getResult(returnIdx);
|
||||
Value argVal = callOp->getOperand(bbargIdx);
|
||||
aliasInfo.unionEquivalenceClasses(returnVal, argVal);
|
||||
}
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
}
|
||||
|
||||
/// Rewrite the `funcOp` arguments analysis return values and terminator into
|
||||
|
@ -217,7 +276,8 @@ static LogicalResult bufferizeFuncOpBoundary(FuncOp funcOp,
|
|||
}
|
||||
|
||||
// If return operand is equivalent to some bbArg, no need to return it.
|
||||
if (moduleState.equivalentReturnValToBBArg.count(returnVal))
|
||||
if (moduleState.equivalentFuncArgs[funcOp].count(
|
||||
returnOperand.getOperandNumber()))
|
||||
continue;
|
||||
|
||||
// Cast values at the call site if necessary.
|
||||
|
@ -493,12 +553,12 @@ struct CallOpInterface
|
|||
}
|
||||
|
||||
// If return operand is equivalent to some bbArg, no need to return it.
|
||||
Value returnVal = returnOperand.get();
|
||||
if (moduleState.equivalentReturnValToBBArg.count(returnVal)) {
|
||||
BlockArgument bbArg =
|
||||
moduleState.equivalentReturnValToBBArg[returnVal];
|
||||
if (moduleState.equivalentFuncArgs[funcOp].count(
|
||||
returnOperand.getOperandNumber())) {
|
||||
int64_t idx =
|
||||
moduleState
|
||||
.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
|
||||
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
|
||||
int64_t idx = bbArg.getArgNumber();
|
||||
Value buffer = state.lookupBuffer(callOp->getOperand(idx));
|
||||
// Add CallOp operand/result equivalence: this is interprocedural
|
||||
// info.
|
||||
|
@ -661,6 +721,7 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
|||
return failure();
|
||||
|
||||
BufferizationState state(moduleOp, options);
|
||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||
BufferizationAliasInfo &aliasInfo = state.aliasInfo;
|
||||
|
||||
// Interestingly, all function args that are not visible outside of a module
|
||||
|
@ -692,11 +753,17 @@ LogicalResult mlir::linalg::comprehensive_bufferize::runComprehensiveBufferize(
|
|||
aliasInfo.setBufferizesToWritableMemory(bbArg);
|
||||
}
|
||||
|
||||
// Analyze and bufferize funcOp.
|
||||
if (failed(runComprehensiveBufferize(funcOp, options, state)))
|
||||
return failure();
|
||||
// Register extra post analysis steps. These cannot be stored in `options`
|
||||
// because `options` is immutable.
|
||||
PostAnalysisStepList extraSteps;
|
||||
extraSteps.emplace_back(std::make_unique<EquivalentFuncOpBBArgsAnalysis>());
|
||||
|
||||
populateEquivalentFuncOpBBArgs(funcOp, state);
|
||||
// Gather equivalence info for CallOps.
|
||||
equivalenceAnalysis(funcOp, aliasInfo, moduleState);
|
||||
|
||||
// Analyze and bufferize funcOp.
|
||||
if (failed(runComprehensiveBufferize(funcOp, options, state, extraSteps)))
|
||||
return failure();
|
||||
}
|
||||
|
||||
if (options.testAnalysisOnly)
|
||||
|
|
|
@ -40,15 +40,17 @@ func @insert_slice_fun(
|
|||
-> (tensor<?xf32>, tensor<?xf32>)
|
||||
{
|
||||
// must bufferize out of place.
|
||||
// CHECK: tensor.insert_slice
|
||||
// CHECK: tensor.insert_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]}
|
||||
%r0 = tensor.insert_slice %C into %A[0][4][1] : tensor<4xf32> into tensor<?xf32>
|
||||
|
||||
// bufferizes inplace.
|
||||
// CHECK: tensor.insert_slice
|
||||
// CHECK: tensor.insert_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
|
||||
%r1 = tensor.insert_slice %C into %B[0][4][1] : tensor<4xf32> into tensor<?xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [-1, 1]}
|
||||
return %r0, %r1: tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -81,6 +83,8 @@ func @conflict_on_B(
|
|||
outs(%B: tensor<4x4xf32>)
|
||||
-> tensor<4x4xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [-1, -1, 1]}
|
||||
return %C, %D, %E: tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>
|
||||
}
|
||||
|
||||
|
@ -136,6 +140,8 @@ func @insert_slice_insert_slice(
|
|||
// CHECK: {__inplace_results_attr__ = ["false"]}
|
||||
%r3 = tensor.insert_slice %r2 into %B[0][4][1] : tensor<4xf32> into tensor<?xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
|
||||
return %r1, %r3: tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -172,6 +178,8 @@ func @extract_slice_nonmatching_insert_slice(
|
|||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]}
|
||||
%r3 = tensor.insert_slice %r2 into %B[%idx][4][1] : tensor<4xf32> into tensor<?xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
|
||||
return %r1, %r3: tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -208,6 +216,8 @@ func @extract_slice_matching_insert_slice(
|
|||
// CHECK-SAME: {__inplace_results_attr__ = ["false"]}
|
||||
%r3 = tensor.insert_slice %r2 into %B[0][4][1] : tensor<4xf32> into tensor<?xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
|
||||
return %r1, %r3: tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -234,6 +244,9 @@ func @read_of_matching_insert_slice_source(
|
|||
%2 = tensor.insert_slice %1 into %A[%idx][%idx][1] : tensor<?xf32> into tensor<?xf32>
|
||||
|
||||
%3 = vector.transfer_read %1[%idx2], %cst2 : tensor<?xf32>, vector<5xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
|
||||
return %2, %3 : tensor<?xf32>, vector<5xf32>
|
||||
}
|
||||
|
||||
|
@ -274,6 +287,8 @@ func @read_of_matching_insert_slice_source_interleaved(
|
|||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
|
||||
%6 = tensor.insert_slice %5 into %2[%idx3][%idx3][1] : tensor<?xf32> into tensor<?xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
|
||||
return %6, %3 : tensor<?xf32>, vector<5xf32>
|
||||
}
|
||||
|
||||
|
@ -306,6 +321,8 @@ func @extract_slice_linalg_readonly_use(
|
|||
outs(%C: tensor<4x4xf32>)
|
||||
-> tensor<4x4xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [-1, 2]}
|
||||
return %D, %E: tensor<4x4xf32>, tensor<4x4xf32>
|
||||
}
|
||||
|
||||
|
@ -372,6 +389,8 @@ func @insert_slice_double_extract_slice(
|
|||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]}
|
||||
%20 = tensor.insert_slice %19 into %C[%s3, %s4] [%s1, %s2] [1, 1] : tensor<?x?xf32> into tensor<30x20xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [6]}
|
||||
return %20 : tensor<30x20xf32>
|
||||
}
|
||||
|
||||
|
@ -504,6 +523,8 @@ func @nested_extract_slice_and_insert(
|
|||
%rsC = tensor.insert_slice %FC into %sC[0, 0][%sz2, 4][1, 1] : tensor<?x4xf32> into tensor<?x?xf32>
|
||||
%rC = tensor.insert_slice %rsC into %C[0, 0][%idx, %idx][1, 1] : tensor<?x?xf32> into tensor<?x?xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [-1, 1, 2]}
|
||||
return %rA, %rB, %rC: tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>
|
||||
}
|
||||
|
||||
|
@ -533,6 +554,8 @@ func @scf_for_yield_only(%A : tensor<?xf32>,
|
|||
scf.yield %t : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [-1, 1]}
|
||||
return %r0, %r1: tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -564,6 +587,8 @@ func @scf_for_with_tensor.insert_slice(%A : tensor<?xf32>,
|
|||
scf.yield %ttA, %ttB : tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [-1, 1]}
|
||||
return %r0#0, %r0#1: tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -623,6 +648,8 @@ func @scf_for_deps(%A : tensor<?xf32> {linalg.inplaceable = true},
|
|||
linalg.yield %t : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0, 1]}
|
||||
return %r1, %r3: tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -768,6 +795,8 @@ builtin.func @matmul_on_tensors(
|
|||
ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
|
||||
outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [2]}
|
||||
return %r : tensor<256x256xf32>
|
||||
}
|
||||
|
||||
|
@ -813,6 +842,8 @@ builtin.func @matmul_on_tensors(
|
|||
ins(%sA, %sB : tensor<256x16xf32>, tensor<16x256xf32>)
|
||||
outs(%arg2 : tensor<256x256xf32>) -> tensor<256x256xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [2]}
|
||||
return %r : tensor<256x256xf32>
|
||||
}
|
||||
|
||||
|
@ -858,6 +889,8 @@ func @insert_slice_chain(
|
|||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%15 = tensor.insert_slice %14 into %8[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [4]}
|
||||
return %15 : tensor<62x90xf32>
|
||||
}
|
||||
|
||||
|
@ -883,6 +916,9 @@ func @ip(%t: tensor<10x20xf32> {linalg.inplaceable = true},
|
|||
%t3 = tensor.insert_slice %t2 into %arg1[%x, 0] [5, %y] [1, 1] : tensor<5x?xf32> into tensor<10x20xf32>
|
||||
scf.yield %t3 : tensor<10x20xf32>
|
||||
}
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0]}
|
||||
return %r : tensor<10x20xf32>
|
||||
}
|
||||
|
||||
|
@ -910,6 +946,9 @@ func @linalg_op_same_out_tensors(
|
|||
^bb(%0: f32, %1: f32, %2 : f32) :
|
||||
linalg.yield %0, %0 : f32, f32
|
||||
} -> (tensor<?xf32>, tensor<?xf32>)
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [1, -1]}
|
||||
return %o#0, %o#1 : tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -951,6 +990,8 @@ func @double_insert_slice_into_alias(
|
|||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%15 = tensor.insert_slice %14 into %e[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<?x?xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [2, -1]}
|
||||
return %8, %15 : tensor<62x90xf32>, tensor<?x?xf32>
|
||||
}
|
||||
|
||||
|
@ -980,6 +1021,8 @@ func @interleaved_extract_insert_slice_chain_1(
|
|||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%15 = tensor.insert_slice %10 into %8[32, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0]}
|
||||
return %15 : tensor<62x90xf32>
|
||||
}
|
||||
|
||||
|
@ -1009,6 +1052,8 @@ func @interleaved_extract_insert_slice_chain_2(
|
|||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%15 = tensor.insert_slice %10 into %8[31, 0] [30, 90] [1, 1] : tensor<30x90xf32> into tensor<62x90xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0]}
|
||||
return %15 : tensor<62x90xf32>
|
||||
}
|
||||
|
||||
|
@ -1031,6 +1076,8 @@ func @extract_once_insert_twice(
|
|||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%15 = tensor.insert_slice %2 into %8[15, 0] [32, 90] [1, 1] : tensor<32x90xf32> into tensor<62x90xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0]}
|
||||
return %15 : tensor<62x90xf32>
|
||||
}
|
||||
|
||||
|
@ -1132,6 +1179,8 @@ func @non_reading_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
|||
linalg.yield %cst : f32
|
||||
} -> (tensor<?xf32>)
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
|
||||
return %o, %v3 : tensor<?xf32>, vector<5xf32>
|
||||
}
|
||||
|
||||
|
@ -1160,6 +1209,9 @@ func @buffer_forwarding_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = true
|
|||
// CHECK: tensor.insert_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%3 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [-1, 0]}
|
||||
return %2, %3 : tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -1180,6 +1232,9 @@ func @buffer_forwarding_no_conflict(%arg0: tensor<?xf32> {linalg.inplaceable = t
|
|||
// CHECK: tensor.insert_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%2 = tensor.insert_slice %1 into %arg0[42] [%arg1] [1] : tensor<?xf32> into tensor<?xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0, 0]}
|
||||
return %2, %2 : tensor<?xf32>, tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -1214,6 +1269,8 @@ func @scf_if_inplace2(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
|||
%t2 = vector.transfer_write %v, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t2 : tensor<?xf32>
|
||||
}
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0]}
|
||||
return %r : tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -1263,6 +1320,9 @@ func @scf_if_in_place4(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
|||
scf.yield %r : tensor<?xf32>
|
||||
}
|
||||
%v2 = vector.transfer_read %r_alias[%idx], %cst : tensor<?xf32>, vector<10xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0, -1]}
|
||||
return %r_alias, %v2 : tensor<?xf32>, vector<10xf32>
|
||||
}
|
||||
|
||||
|
@ -1288,6 +1348,9 @@ func @scf_if_inplace5(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
|||
// CHECK: tensor.insert_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%r2 = tensor.insert_slice %r into %t1[%idx][%idx][1] : tensor<?xf32> into tensor<?xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0]}
|
||||
return %r2 : tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -1318,6 +1381,9 @@ func @scf_if_inplace6(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
|||
%t3 = vector.transfer_write %v2, %t1[%idx] : vector<5xf32>, tensor<?xf32>
|
||||
scf.yield %t3 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0]}
|
||||
return %r : tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -1396,6 +1462,9 @@ func @scf_if_out_of_place1b(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
|||
// CHECK: tensor.insert_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%r2 = tensor.insert_slice %r into %t1[%idx3][%idx3][1] : tensor<?xf32> into tensor<?xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0]}
|
||||
return %r2 : tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -1420,6 +1489,9 @@ func @scf_if_out_of_place1c(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
|||
// CHECK: tensor.insert_slice
|
||||
// CHECK-SAME: {__inplace_results_attr__ = ["true"]
|
||||
%r2 = tensor.insert_slice %r into %t1[%idx2][%idx2][1] : tensor<?xf32> into tensor<?xf32>
|
||||
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0]}
|
||||
return %r2 : tensor<?xf32>
|
||||
}
|
||||
|
||||
|
@ -1533,3 +1605,44 @@ func @to_memref_op_is_reading(%t1: tensor<?xf32> {linalg.inplaceable = true},
|
|||
|
||||
return %r1, %r2 : vector<5xf32>, vector<5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @inner_func
|
||||
func @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0]}
|
||||
return %t : tensor<?xf32>
|
||||
}
|
||||
|
||||
func @equivalent_func_arg(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// This test does not check IR. It just asserts there is no failure due to
|
||||
// non-equivalent scf.for yield values.
|
||||
%1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
|
||||
%3 = call @inner_func(%t1) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
scf.yield %3 : tensor<?xf32>
|
||||
}
|
||||
return %1: tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @inner_func_2
|
||||
func @inner_func_2(%t: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%f = arith.constant 1.0 : f32
|
||||
%c0 = arith.constant 0 : index
|
||||
%0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
|
||||
// CHECK: return
|
||||
// CHECK-SAME: {__equivalent_func_args__ = [0]}
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
func @equivalent_func_arg_2(%c0: index, %c10: index, %c1: index, %t0: tensor<?xf32>) -> tensor<?xf32> {
|
||||
// This test does not check IR. It just asserts there is no failure due to
|
||||
// non-equivalent scf.for yield values.
|
||||
%1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
|
||||
%3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
scf.yield %3 : tensor<?xf32>
|
||||
}
|
||||
return %1: tensor<?xf32>
|
||||
}
|
||||
|
|
|
@ -928,3 +928,54 @@ func @insert_op(%t1 : tensor<?xf32> {linalg.inplaceable = true},
|
|||
// CHECK: return
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @inner_func(
|
||||
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
|
||||
func @inner_func(%t: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%f = arith.constant 1.0 : f32
|
||||
%c0 = arith.constant 0 : index
|
||||
// CHECK: memref.store %{{.*}}, %[[arg0]]
|
||||
%0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @equivalent_func_arg(
|
||||
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
|
||||
func @equivalent_func_arg(%t0: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%c0: index, %c10: index, %c1: index) -> tensor<?xf32> {
|
||||
// CHECK-NOT: copy
|
||||
%1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
|
||||
// CHECK: call @inner_func(%[[arg0]])
|
||||
%3 = call @inner_func(%t1) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
scf.yield %3 : tensor<?xf32>
|
||||
}
|
||||
return %1: tensor<?xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func @inner_func_2(
|
||||
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
|
||||
func @inner_func_2(%t: tensor<?xf32>) -> tensor<?xf32> {
|
||||
%f = arith.constant 1.0 : f32
|
||||
%c0 = arith.constant 0 : index
|
||||
// CHECK: memref.store %{{.*}}, %[[arg0]]
|
||||
%0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
|
||||
return %0 : tensor<?xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @equivalent_func_arg_2(
|
||||
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
|
||||
func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
|
||||
%c0: index, %c10: index, %c1: index) -> tensor<?xf32> {
|
||||
%1 = scf.for %iv = %c0 to %c10 step %c1 iter_args(%t1 = %t0) -> (tensor<?xf32>) {
|
||||
// TODO: There should be a memory copy here. This is a bug in CallOp
|
||||
// bufferization.
|
||||
// CHECK: call @inner_func_2(%[[arg0]])
|
||||
%3 = call @inner_func_2(%t1) : (tensor<?xf32>) -> tensor<?xf32>
|
||||
scf.yield %t1 : tensor<?xf32>
|
||||
}
|
||||
return %1: tensor<?xf32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue