[mlir][linalg][bufferize][NFC] Simplify bufferization of CallOps

There is no need to inspect the ReturnOp of the called function.

This change also refactors the bufferization of CallOps in such a way that `lookupBuffer` is called only a single time. This is important for a later change that fixes CallOp bufferization. (There is currently a TODO among the test cases.)

Note: This change modifies a test case but is marked as NFC. There is no change of functionality, but FuncOps with empty bodies are now reported with a different error message.

Differential Revision: https://reviews.llvm.org/D116446
This commit is contained in:
Matthias Springer 2022-01-06 00:22:38 +09:00
parent 66d4090d9b
commit b15b0156ca
2 changed files with 81 additions and 60 deletions

View File

@ -490,17 +490,16 @@ namespace linalg {
namespace comprehensive_bufferize {
namespace std_ext {
/// Return the index of the parent function's bbArg that is equivalent to the
/// given ReturnOp operand (if any).
/// Return the index of the bbArg in the given FuncOp that is equivalent to the
/// specified return value (if any).
static Optional<int64_t>
getEquivalentFuncArgIdx(ModuleBufferizationState &state,
OpOperand &returnOperand) {
FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
if (!state.equivalentFuncArgs[funcOp].count(returnOperand.getOperandNumber()))
getEquivalentFuncArgIdx(FuncOp funcOp, ModuleBufferizationState &state,
int64_t returnValIdx) {
if (!state.equivalentFuncArgs[funcOp].count(returnValIdx))
// Return value has no equivalent bbArg.
return None;
return state.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
return state.equivalentFuncArgs[funcOp][returnValIdx];
}
struct CallOpInterface
@ -529,6 +528,7 @@ struct CallOpInterface
BufferizationState &state) const {
CallOp callOp = cast<CallOp>(op);
unsigned numResults = callOp.getNumResults();
unsigned numOperands = callOp->getNumOperands();
FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
"expected CallOp to a FuncOp");
@ -542,54 +542,48 @@ struct CallOpInterface
// For non-tensor results: A mapping from return val indices of the old
// CallOp to return val indices of the bufferized CallOp.
SmallVector<Optional<unsigned>> retValMapping(numResults, None);
// Operands of the bufferized CallOp.
SmallVector<Value> newOperands(numOperands, Value());
if (funcOp.body().empty()) {
// The callee is bodiless / external, so we cannot inspect it and we
// cannot assume anything. We can just assert that it does not return a
// tensor as this would have to bufferize to "return a memref", whose
// semantics is ill-defined.
for (int i = 0; i < numResults; ++i) {
Type returnType = callOp.getResult(i).getType();
if (isaTensor(returnType))
return callOp->emitError()
<< "cannot bufferize bodiless function that returns a tensor";
resultTypes.push_back(returnType);
retValMapping[i] = i;
}
} else {
// The callee has a body. Based on previously gathered equivalence
// information, we know if a tensor result folds onto an operand. These
// are the only tensor value returns that are supported at the moment.
// Based on previously gathered equivalence information, we know if a
// tensor result folds onto an operand. These are the only tensor value
// results that are supported at the moment.
//
// For tensors return values that do not fold onto an operand, additional
// work is needed (TODO) to either:
// * hoist a result into an inplaceable operand or
// * devise a better representation to truly return a buffer.
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op");
//
// Note: If a function has no body, no equivalence information is
// available. Consequently, a tensor return value cannot be proven to fold
// onto a FuncOp bbArg, so calls to such functions are not bufferizable at
// the moment.
// For each FuncOp result, keep track of which inplace argument it reuses.
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
unsigned returnIdx = returnOperand.getOperandNumber();
Type returnType = returnOperand.get().getType();
// 1. Compute the result types of the new CallOp. Tensor results that are
// equivalent to a FuncOp bbArg are no longer returned.
for (auto it : llvm::enumerate(callOp.getResultTypes())) {
unsigned returnValIdx = it.index();
Type returnType = it.value();
if (!isaTensor(returnType)) {
// Non-tensor values are returned.
retValMapping[returnIdx] = resultTypes.size();
retValMapping[returnValIdx] = resultTypes.size();
resultTypes.push_back(returnType);
continue;
}
if (Optional<int64_t> bbArgIdx =
getEquivalentFuncArgIdx(moduleState, returnOperand)) {
getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
// Return operands that are equivalent to some bbArg, are not
// returned.
replacementValues[returnIdx] =
Value buffer =
state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
replacementValues[returnValIdx] = buffer;
newOperands[*bbArgIdx] = buffer;
continue;
}
llvm_unreachable("returning non-equivalent tensors not supported");
}
return callOp->emitError(
"call to FuncOp that returns non-equivalent tensors not supported");
}
// 2. Compute bufferized FunctionType.
@ -601,23 +595,26 @@ struct CallOpInterface
moduleState.bufferizedFunctionTypes);
// 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
SmallVector<Value> newOperands;
newOperands.reserve(callOp->getNumOperands());
for (OpOperand &opOperand : callOp->getOpOperands()) {
unsigned idx = opOperand.getOperandNumber();
Value tensorOperand = opOperand.get();
// Non-tensor operands are just copied.
if (!tensorOperand.getType().isa<TensorType>()) {
newOperands.push_back(tensorOperand);
newOperands[idx] = tensorOperand;
continue;
}
// Tensor operands are guaranteed to have been buferized.
int64_t idx = opOperand.getOperandNumber();
Value buffer = state.lookupBuffer(rewriter, tensorOperand);
// Retrieve buffers for tensor operands. Tensor operand buffers, who's
// corresponding FuncOp bbArgs are equivalent to a returned tensor, were
// already stored in `newOperands` during Step 1.
Value buffer = newOperands[idx]
? newOperands[idx]
: state.lookupBuffer(rewriter, tensorOperand);
// Caller / callee type mistmatch is handled with a CastOp.
auto memRefType = bufferizedFuncType.getInput(idx);
// Since we don't yet have a clear layout story, buffer_cast may
// Since we don't yet have a clear layout story, to_memref may
// conservatively turn tensors into more dynamic memref than necessary.
// If the memref type of the callee fails, introduce an extra memref.cast
// that will either canonicalize away or fail compilation until we can do
@ -627,20 +624,21 @@ struct CallOpInterface
memRefType, buffer);
buffer = castBuffer;
}
newOperands.push_back(buffer);
newOperands[idx] = buffer;
}
// 4. Create the new CallOp.
Operation *newCallOp = rewriter.create<CallOp>(
callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
newCallOp->setAttrs(callOp->getAttrs());
// 5. Replace the old op with the new op.
// Get replacement values for non-tensor / non-equivalent results.
for (int i = 0; i < replacementValues.size(); ++i) {
if (replacementValues[i])
continue;
replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
}
// 5. Replace the old op with the new op.
state.replaceOp(rewriter, callOp, replacementValues);
return success();

View File

@ -187,3 +187,26 @@ func @to_memref_op_is_writing(
return %r1, %r2 : vector<5xf32>, vector<5xf32>
}
// -----
func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
// expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
return
}
// -----
func @foo(%t : tensor<5xf32>) -> (tensor<5xf32>) {
%0 = linalg.init_tensor [5] : tensor<5xf32>
return %0 : tensor<5xf32>
}
func @call_to_func_returning_non_equiv_tensor(%t : tensor<5xf32>) {
// expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
return
}