[mlir][linalg][bufferize][NFC] Simplify bufferization of CallOps

There is no need to inspect the ReturnOp of the called function.

This change also refactors the bufferization of CallOps so that `lookupBuffer` is called only a single time per operand. This is important for a later change that fixes CallOp bufferization. (There is currently a TODO among the test cases.)

Note: This change modifies a test case but is marked as NFC. There is no change of functionality, but FuncOps with empty bodies are now reported with a different error message.

Differential Revision: https://reviews.llvm.org/D116446
Author: Matthias Springer
Date:   2022-01-06 00:22:38 +09:00
Parent: 66d4090d9b
Commit: b15b0156ca
2 changed files with 81 additions and 60 deletions
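For orientation, here is a minimal sketch of the IR pattern this bufferization exploits. It is not part of the commit; the function names, element types, and memref layouts are hypothetical and simplified. When a call's tensor result is known to be equivalent to one of its tensor operands (a FuncOp bbArg), the bufferized call no longer returns that value, and the buffer looked up for the operand serves as the replacement for the old result:

  // Before bufferization: the callee's tensor result is equivalent to its
  // bbArg %t, so the call result %0 is equivalent to the call operand %t.
  func @callee(%t : tensor<?xf32>) -> tensor<?xf32> {
    return %t : tensor<?xf32>
  }
  func @caller(%t : tensor<?xf32>) -> tensor<?xf32> {
    %0 = call @callee(%t) : (tensor<?xf32>) -> tensor<?xf32>
    return %0 : tensor<?xf32>
  }

  // After bufferization (conceptual; actual output differs in layout maps and
  // in how the enclosing function signatures are rewritten): the equivalent
  // tensor result is dropped from the call, and the operand's buffer stands
  // in for the old result.
  func @callee(%m : memref<?xf32>) {
    return
  }
  func @caller(%m : memref<?xf32>) -> memref<?xf32> {
    call @callee(%m) : (memref<?xf32>) -> ()
    return %m : memref<?xf32>
  }

In the refactored code below, Step 1 records that operand buffer in `newOperands`, so Step 3 does not need to call `lookupBuffer` again for the same operand.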


@@ -490,17 +490,16 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace std_ext {
 
-/// Return the index of the parent function's bbArg that is equivalent to the
-/// given ReturnOp operand (if any).
+/// Return the index of the bbArg in the given FuncOp that is equivalent to the
+/// specified return value (if any).
 static Optional<int64_t>
-getEquivalentFuncArgIdx(ModuleBufferizationState &state,
-                        OpOperand &returnOperand) {
-  FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
-  if (!state.equivalentFuncArgs[funcOp].count(returnOperand.getOperandNumber()))
+getEquivalentFuncArgIdx(FuncOp funcOp, ModuleBufferizationState &state,
+                        int64_t returnValIdx) {
+  if (!state.equivalentFuncArgs[funcOp].count(returnValIdx))
     // Return value has no equivalent bbArg.
     return None;
-  return state.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
+  return state.equivalentFuncArgs[funcOp][returnValIdx];
 }
 
 struct CallOpInterface
@@ -529,6 +528,7 @@ struct CallOpInterface
                           BufferizationState &state) const {
     CallOp callOp = cast<CallOp>(op);
     unsigned numResults = callOp.getNumResults();
+    unsigned numOperands = callOp->getNumOperands();
     FuncOp funcOp = getCalledFunction(callOp);
     assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
            "expected CallOp to a FuncOp");
@@ -542,54 +542,48 @@ struct CallOpInterface
     // For non-tensor results: A mapping from return val indices of the old
     // CallOp to return val indices of the bufferized CallOp.
     SmallVector<Optional<unsigned>> retValMapping(numResults, None);
+    // Operands of the bufferized CallOp.
+    SmallVector<Value> newOperands(numOperands, Value());
 
-    if (funcOp.body().empty()) {
-      // The callee is bodiless / external, so we cannot inspect it and we
-      // cannot assume anything. We can just assert that it does not return a
-      // tensor as this would have to bufferize to "return a memref", whose
-      // semantics is ill-defined.
-      for (int i = 0; i < numResults; ++i) {
-        Type returnType = callOp.getResult(i).getType();
-        if (isaTensor(returnType))
-          return callOp->emitError()
-                 << "cannot bufferize bodiless function that returns a tensor";
-        resultTypes.push_back(returnType);
-        retValMapping[i] = i;
-      }
-    } else {
-      // The callee has a body. Based on previously gathered equivalence
-      // information, we know if a tensor result folds onto an operand. These
-      // are the only tensor value returns that are supported at the moment.
-      //
-      // For tensors return values that do not fold onto an operand, additional
-      // work is needed (TODO) to either:
-      // * hoist a result into an inplaceable operand or
-      // * devise a better representation to truly return a buffer.
-      ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      assert(returnOp && "expected func with single return op");
-
-      // For each FuncOp result, keep track of which inplace argument it reuses.
-      for (OpOperand &returnOperand : returnOp->getOpOperands()) {
-        unsigned returnIdx = returnOperand.getOperandNumber();
-        Type returnType = returnOperand.get().getType();
-        if (!isaTensor(returnType)) {
-          // Non-tensor values are returned.
-          retValMapping[returnIdx] = resultTypes.size();
-          resultTypes.push_back(returnType);
-          continue;
-        }
-
-        if (Optional<int64_t> bbArgIdx =
-                getEquivalentFuncArgIdx(moduleState, returnOperand)) {
-          // Return operands that are equivalent to some bbArg, are not
-          // returned.
-          replacementValues[returnIdx] =
-              state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
-          continue;
-        }
-
-        llvm_unreachable("returning non-equivalent tensors not supported");
-      }
-    }
+    // Based on previously gathered equivalence information, we know if a
+    // tensor result folds onto an operand. These are the only tensor value
+    // results that are supported at the moment.
+    //
+    // For tensors return values that do not fold onto an operand, additional
+    // work is needed (TODO) to either:
+    // * hoist a result into an inplaceable operand or
+    // * devise a better representation to truly return a buffer.
+    //
+    // Note: If a function has no body, no equivalence information is
+    // available. Consequently, a tensor return value cannot be proven to fold
+    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
+    // the moment.
+
+    // 1. Compute the result types of the new CallOp. Tensor results that are
+    // equivalent to a FuncOp bbArg are no longer returned.
+    for (auto it : llvm::enumerate(callOp.getResultTypes())) {
+      unsigned returnValIdx = it.index();
+      Type returnType = it.value();
+      if (!isaTensor(returnType)) {
+        // Non-tensor values are returned.
+        retValMapping[returnValIdx] = resultTypes.size();
+        resultTypes.push_back(returnType);
+        continue;
+      }
+
+      if (Optional<int64_t> bbArgIdx =
+              getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
+        // Return operands that are equivalent to some bbArg, are not
+        // returned.
+        Value buffer =
+            state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
+        replacementValues[returnValIdx] = buffer;
+        newOperands[*bbArgIdx] = buffer;
+        continue;
+      }
+
+      return callOp->emitError(
+          "call to FuncOp that returns non-equivalent tensors not supported");
+    }
 
     // 2. Compute bufferized FunctionType.
@@ -601,23 +595,26 @@ struct CallOpInterface
         moduleState.bufferizedFunctionTypes);
 
     // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
-    SmallVector<Value> newOperands;
-    newOperands.reserve(callOp->getNumOperands());
     for (OpOperand &opOperand : callOp->getOpOperands()) {
+      unsigned idx = opOperand.getOperandNumber();
       Value tensorOperand = opOperand.get();
       // Non-tensor operands are just copied.
       if (!tensorOperand.getType().isa<TensorType>()) {
-        newOperands.push_back(tensorOperand);
+        newOperands[idx] = tensorOperand;
         continue;
       }
 
-      // Tensor operands are guaranteed to have been buferized.
-      int64_t idx = opOperand.getOperandNumber();
-      Value buffer = state.lookupBuffer(rewriter, tensorOperand);
+      // Retrieve buffers for tensor operands. Tensor operand buffers, who's
+      // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
+      // already stored in `newOperands` during Step 1.
+      Value buffer = newOperands[idx]
+                         ? newOperands[idx]
+                         : state.lookupBuffer(rewriter, tensorOperand);
 
       // Caller / callee type mistmatch is handled with a CastOp.
       auto memRefType = bufferizedFuncType.getInput(idx);
-      // Since we don't yet have a clear layout story, buffer_cast may
+      // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
      // If the memref type of the callee fails, introduce an extra memref.cast
      // that will either canonicalize away or fail compilation until we can do
@@ -627,20 +624,21 @@ struct CallOpInterface
                                     memRefType, buffer);
         buffer = castBuffer;
       }
-      newOperands.push_back(buffer);
+      newOperands[idx] = buffer;
     }
 
     // 4. Create the new CallOp.
     Operation *newCallOp = rewriter.create<CallOp>(
         callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
     newCallOp->setAttrs(callOp->getAttrs());
 
-    // 5. Replace the old op with the new op.
+    // Get replacement values for non-tensor / non-equivalent results.
     for (int i = 0; i < replacementValues.size(); ++i) {
       if (replacementValues[i])
        continue;
       replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
     }
 
+    // 5. Replace the old op with the new op.
     state.replaceOp(rewriter, callOp, replacementValues);
 
     return success();


@@ -187,3 +187,26 @@ func @to_memref_op_is_writing(
   return %r1, %r2 : vector<5xf32>, vector<5xf32>
 }
+
+// -----
+
+func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+
+func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
+  // expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
+  call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+  return
+}
+
+// -----
+
+func @foo(%t : tensor<5xf32>) -> (tensor<5xf32>) {
+  %0 = linalg.init_tensor [5] : tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+func @call_to_func_returning_non_equiv_tensor(%t : tensor<5xf32>) {
+  // expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
+  call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return
+}