forked from OSchip/llvm-project
[mlir][linalg][bufferize][NFC] Simplify bufferization of CallOps
There is no need to inspect the ReturnOp of the called function. This change also refactors the bufferization of CallOps in such a way that `lookupBuffer` is called only a single time. This is important for a later change that fixes CallOp bufferization. (There is currently a TODO among the test cases.) Note: This change modifies a test case but is marked as NFC. There is no change of functionality, but FuncOps with empty bodies are now reported with a different error message. Differential Revision: https://reviews.llvm.org/D116446
This commit is contained in:
parent
66d4090d9b
commit
b15b0156ca
|
@ -490,17 +490,16 @@ namespace linalg {
|
|||
namespace comprehensive_bufferize {
|
||||
namespace std_ext {
|
||||
|
||||
/// Return the index of the parent function's bbArg that is equivalent to the
|
||||
/// given ReturnOp operand (if any).
|
||||
/// Return the index of the bbArg in the given FuncOp that is equivalent to the
|
||||
/// specified return value (if any).
|
||||
static Optional<int64_t>
|
||||
getEquivalentFuncArgIdx(ModuleBufferizationState &state,
|
||||
OpOperand &returnOperand) {
|
||||
FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
|
||||
if (!state.equivalentFuncArgs[funcOp].count(returnOperand.getOperandNumber()))
|
||||
getEquivalentFuncArgIdx(FuncOp funcOp, ModuleBufferizationState &state,
|
||||
int64_t returnValIdx) {
|
||||
if (!state.equivalentFuncArgs[funcOp].count(returnValIdx))
|
||||
// Return value has no equivalent bbArg.
|
||||
return None;
|
||||
|
||||
return state.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
|
||||
return state.equivalentFuncArgs[funcOp][returnValIdx];
|
||||
}
|
||||
|
||||
struct CallOpInterface
|
||||
|
@ -529,6 +528,7 @@ struct CallOpInterface
|
|||
BufferizationState &state) const {
|
||||
CallOp callOp = cast<CallOp>(op);
|
||||
unsigned numResults = callOp.getNumResults();
|
||||
unsigned numOperands = callOp->getNumOperands();
|
||||
FuncOp funcOp = getCalledFunction(callOp);
|
||||
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
|
||||
"expected CallOp to a FuncOp");
|
||||
|
@ -542,54 +542,48 @@ struct CallOpInterface
|
|||
// For non-tensor results: A mapping from return val indices of the old
|
||||
// CallOp to return val indices of the bufferized CallOp.
|
||||
SmallVector<Optional<unsigned>> retValMapping(numResults, None);
|
||||
// Operands of the bufferized CallOp.
|
||||
SmallVector<Value> newOperands(numOperands, Value());
|
||||
|
||||
if (funcOp.body().empty()) {
|
||||
// The callee is bodiless / external, so we cannot inspect it and we
|
||||
// cannot assume anything. We can just assert that it does not return a
|
||||
// tensor as this would have to bufferize to "return a memref", whose
|
||||
// semantics is ill-defined.
|
||||
for (int i = 0; i < numResults; ++i) {
|
||||
Type returnType = callOp.getResult(i).getType();
|
||||
if (isaTensor(returnType))
|
||||
return callOp->emitError()
|
||||
<< "cannot bufferize bodiless function that returns a tensor";
|
||||
// Based on previously gathered equivalence information, we know if a
|
||||
// tensor result folds onto an operand. These are the only tensor value
|
||||
// results that are supported at the moment.
|
||||
//
|
||||
// For tensors return values that do not fold onto an operand, additional
|
||||
// work is needed (TODO) to either:
|
||||
// * hoist a result into an inplaceable operand or
|
||||
// * devise a better representation to truly return a buffer.
|
||||
//
|
||||
// Note: If a function has no body, no equivalence information is
|
||||
// available. Consequently, a tensor return value cannot be proven to fold
|
||||
// onto a FuncOp bbArg, so calls to such functions are not bufferizable at
|
||||
// the moment.
|
||||
|
||||
// 1. Compute the result types of the new CallOp. Tensor results that are
|
||||
// equivalent to a FuncOp bbArg are no longer returned.
|
||||
for (auto it : llvm::enumerate(callOp.getResultTypes())) {
|
||||
unsigned returnValIdx = it.index();
|
||||
Type returnType = it.value();
|
||||
if (!isaTensor(returnType)) {
|
||||
// Non-tensor values are returned.
|
||||
retValMapping[returnValIdx] = resultTypes.size();
|
||||
resultTypes.push_back(returnType);
|
||||
retValMapping[i] = i;
|
||||
continue;
|
||||
}
|
||||
} else {
|
||||
// The callee has a body. Based on previously gathered equivalence
|
||||
// information, we know if a tensor result folds onto an operand. These
|
||||
// are the only tensor value returns that are supported at the moment.
|
||||
//
|
||||
// For tensors return values that do not fold onto an operand, additional
|
||||
// work is needed (TODO) to either:
|
||||
// * hoist a result into an inplaceable operand or
|
||||
// * devise a better representation to truly return a buffer.
|
||||
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
|
||||
assert(returnOp && "expected func with single return op");
|
||||
|
||||
// For each FuncOp result, keep track of which inplace argument it reuses.
|
||||
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
|
||||
unsigned returnIdx = returnOperand.getOperandNumber();
|
||||
Type returnType = returnOperand.get().getType();
|
||||
if (!isaTensor(returnType)) {
|
||||
// Non-tensor values are returned.
|
||||
retValMapping[returnIdx] = resultTypes.size();
|
||||
resultTypes.push_back(returnType);
|
||||
continue;
|
||||
}
|
||||
|
||||
if (Optional<int64_t> bbArgIdx =
|
||||
getEquivalentFuncArgIdx(moduleState, returnOperand)) {
|
||||
// Return operands that are equivalent to some bbArg, are not
|
||||
// returned.
|
||||
replacementValues[returnIdx] =
|
||||
state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
|
||||
continue;
|
||||
}
|
||||
|
||||
llvm_unreachable("returning non-equivalent tensors not supported");
|
||||
if (Optional<int64_t> bbArgIdx =
|
||||
getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
|
||||
// Return operands that are equivalent to some bbArg, are not
|
||||
// returned.
|
||||
Value buffer =
|
||||
state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
|
||||
replacementValues[returnValIdx] = buffer;
|
||||
newOperands[*bbArgIdx] = buffer;
|
||||
continue;
|
||||
}
|
||||
|
||||
return callOp->emitError(
|
||||
"call to FuncOp that returns non-equivalent tensors not supported");
|
||||
}
|
||||
|
||||
// 2. Compute bufferized FunctionType.
|
||||
|
@ -601,23 +595,26 @@ struct CallOpInterface
|
|||
moduleState.bufferizedFunctionTypes);
|
||||
|
||||
// 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
|
||||
SmallVector<Value> newOperands;
|
||||
newOperands.reserve(callOp->getNumOperands());
|
||||
for (OpOperand &opOperand : callOp->getOpOperands()) {
|
||||
unsigned idx = opOperand.getOperandNumber();
|
||||
Value tensorOperand = opOperand.get();
|
||||
|
||||
// Non-tensor operands are just copied.
|
||||
if (!tensorOperand.getType().isa<TensorType>()) {
|
||||
newOperands.push_back(tensorOperand);
|
||||
newOperands[idx] = tensorOperand;
|
||||
continue;
|
||||
}
|
||||
|
||||
// Tensor operands are guaranteed to have been buferized.
|
||||
int64_t idx = opOperand.getOperandNumber();
|
||||
Value buffer = state.lookupBuffer(rewriter, tensorOperand);
|
||||
// Retrieve buffers for tensor operands. Tensor operand buffers, who's
|
||||
// corresponding FuncOp bbArgs are equivalent to a returned tensor, were
|
||||
// already stored in `newOperands` during Step 1.
|
||||
Value buffer = newOperands[idx]
|
||||
? newOperands[idx]
|
||||
: state.lookupBuffer(rewriter, tensorOperand);
|
||||
|
||||
// Caller / callee type mistmatch is handled with a CastOp.
|
||||
auto memRefType = bufferizedFuncType.getInput(idx);
|
||||
// Since we don't yet have a clear layout story, buffer_cast may
|
||||
// Since we don't yet have a clear layout story, to_memref may
|
||||
// conservatively turn tensors into more dynamic memref than necessary.
|
||||
// If the memref type of the callee fails, introduce an extra memref.cast
|
||||
// that will either canonicalize away or fail compilation until we can do
|
||||
|
@ -627,20 +624,21 @@ struct CallOpInterface
|
|||
memRefType, buffer);
|
||||
buffer = castBuffer;
|
||||
}
|
||||
newOperands.push_back(buffer);
|
||||
newOperands[idx] = buffer;
|
||||
}
|
||||
|
||||
// 4. Create the new CallOp.
|
||||
Operation *newCallOp = rewriter.create<CallOp>(
|
||||
callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
|
||||
newCallOp->setAttrs(callOp->getAttrs());
|
||||
|
||||
// 5. Replace the old op with the new op.
|
||||
// Get replacement values for non-tensor / non-equivalent results.
|
||||
for (int i = 0; i < replacementValues.size(); ++i) {
|
||||
if (replacementValues[i])
|
||||
continue;
|
||||
replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
|
||||
}
|
||||
|
||||
// 5. Replace the old op with the new op.
|
||||
state.replaceOp(rewriter, callOp, replacementValues);
|
||||
|
||||
return success();
|
||||
|
|
|
@ -187,3 +187,26 @@ func @to_memref_op_is_writing(
|
|||
|
||||
return %r1, %r2 : vector<5xf32>, vector<5xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
|
||||
|
||||
func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
|
||||
// expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
|
||||
call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @foo(%t : tensor<5xf32>) -> (tensor<5xf32>) {
|
||||
%0 = linalg.init_tensor [5] : tensor<5xf32>
|
||||
return %0 : tensor<5xf32>
|
||||
}
|
||||
|
||||
func @call_to_func_returning_non_equiv_tensor(%t : tensor<5xf32>) {
|
||||
// expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
|
||||
call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue