[mlir][linalg][bufferize][NFC] Simplify bufferization of CallOps
There is no need to inspect the ReturnOp of the called function. This change also refactors the bufferization of CallOps in such a way that `lookupBuffer` is called only a single time. This is important for a later change that fixes CallOp bufferization. (There is currently a TODO among the test cases.)

Note: This change modifies a test case but is marked as NFC. There is no change of functionality, but FuncOps with empty bodies are now reported with a different error message.

Differential Revision: https://reviews.llvm.org/D116446
commit b15b0156ca (parent 66d4090d9b)
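Before the diff itself, here is a minimal standalone sketch of the control flow the refactored code follows. This is plain C++ with simplified stand-in types (strings and vectors), not the actual MLIR bufferization API; the names lookupBuffer, newOperands and replacementValues mirror the diff below, and everything else is a hypothetical model. The point it illustrates: a tensor result that is equivalent to a callee bbArg reuses the buffer of the matching call operand, and that buffer is looked up exactly once and recorded both as the result's replacement value and as an operand of the bufferized CallOp.

// Standalone model of the new CallOp bufferization flow (not the MLIR API).
#include <cassert>
#include <map>
#include <optional>
#include <string>
#include <vector>

using Buffer = std::string; // stand-in for an MLIR Value holding a memref

static int lookupCount = 0; // counts buffer lookups, to show each happens once
static Buffer lookupBuffer(const Buffer &tensorOperand) {
  ++lookupCount;
  return "buffer(" + tensorOperand + ")";
}

int main() {
  // Tensor operands of the call and equivalence info from the analysis:
  // result i folds onto callee bbArg equivalentArg[i] (if present).
  std::vector<Buffer> operands = {"%t0", "%t1"};
  std::map<int, int> equivalentArg = {{0, 1}}; // result 0 is equivalent to operand 1

  std::vector<std::optional<Buffer>> replacementValues(1); // one tensor result
  std::vector<std::optional<Buffer>> newOperands(operands.size());

  // Step 1: results equivalent to a bbArg are not returned; look up the
  // operand's buffer once and record it in both tables.
  for (int resultIdx = 0; resultIdx < (int)replacementValues.size(); ++resultIdx) {
    auto it = equivalentArg.find(resultIdx);
    assert(it != equivalentArg.end() && "non-equivalent tensor results unsupported");
    Buffer buffer = lookupBuffer(operands[it->second]);
    replacementValues[resultIdx] = buffer; // replaces the old CallOp result
    newOperands[it->second] = buffer;      // operand of the bufferized CallOp
  }

  // Step 3: remaining operands get their buffers; slots filled in step 1 are kept.
  for (size_t idx = 0; idx < operands.size(); ++idx)
    if (!newOperands[idx])
      newOperands[idx] = lookupBuffer(operands[idx]);

  assert(replacementValues[0].has_value());
  assert(lookupCount == (int)operands.size()); // exactly one lookup per tensor operand
  return 0;
}

Step 3 skips operands whose buffer was already recorded in step 1, which is what keeps lookupBuffer at a single call per tensor value.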
@@ -490,17 +490,16 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace std_ext {
 
-/// Return the index of the parent function's bbArg that is equivalent to the
-/// given ReturnOp operand (if any).
+/// Return the index of the bbArg in the given FuncOp that is equivalent to the
+/// specified return value (if any).
 static Optional<int64_t>
-getEquivalentFuncArgIdx(ModuleBufferizationState &state,
-                        OpOperand &returnOperand) {
-  FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
-  if (!state.equivalentFuncArgs[funcOp].count(returnOperand.getOperandNumber()))
+getEquivalentFuncArgIdx(FuncOp funcOp, ModuleBufferizationState &state,
+                        int64_t returnValIdx) {
+  if (!state.equivalentFuncArgs[funcOp].count(returnValIdx))
     // Return value has no equivalent bbArg.
     return None;
 
-  return state.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
+  return state.equivalentFuncArgs[funcOp][returnValIdx];
 }
 
 struct CallOpInterface
@@ -529,6 +528,7 @@ struct CallOpInterface
                           BufferizationState &state) const {
     CallOp callOp = cast<CallOp>(op);
     unsigned numResults = callOp.getNumResults();
+    unsigned numOperands = callOp->getNumOperands();
     FuncOp funcOp = getCalledFunction(callOp);
     assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
            "expected CallOp to a FuncOp");
@@ -542,54 +542,48 @@ struct CallOpInterface
     // For non-tensor results: A mapping from return val indices of the old
     // CallOp to return val indices of the bufferized CallOp.
     SmallVector<Optional<unsigned>> retValMapping(numResults, None);
+    // Operands of the bufferized CallOp.
+    SmallVector<Value> newOperands(numOperands, Value());
 
-    if (funcOp.body().empty()) {
-      // The callee is bodiless / external, so we cannot inspect it and we
-      // cannot assume anything. We can just assert that it does not return a
-      // tensor as this would have to bufferize to "return a memref", whose
-      // semantics is ill-defined.
-      for (int i = 0; i < numResults; ++i) {
-        Type returnType = callOp.getResult(i).getType();
-        if (isaTensor(returnType))
-          return callOp->emitError()
-                 << "cannot bufferize bodiless function that returns a tensor";
-        resultTypes.push_back(returnType);
-        retValMapping[i] = i;
-      }
-    } else {
-      // The callee has a body. Based on previously gathered equivalence
-      // information, we know if a tensor result folds onto an operand. These
-      // are the only tensor value returns that are supported at the moment.
+    // Based on previously gathered equivalence information, we know if a
+    // tensor result folds onto an operand. These are the only tensor value
+    // results that are supported at the moment.
     //
     // For tensors return values that do not fold onto an operand, additional
     // work is needed (TODO) to either:
     // * hoist a result into an inplaceable operand or
     // * devise a better representation to truly return a buffer.
-      ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
-      assert(returnOp && "expected func with single return op");
+    //
+    // Note: If a function has no body, no equivalence information is
+    // available. Consequently, a tensor return value cannot be proven to fold
+    // onto a FuncOp bbArg, so calls to such functions are not bufferizable at
+    // the moment.
 
-      // For each FuncOp result, keep track of which inplace argument it reuses.
-      for (OpOperand &returnOperand : returnOp->getOpOperands()) {
-        unsigned returnIdx = returnOperand.getOperandNumber();
-        Type returnType = returnOperand.get().getType();
+    // 1. Compute the result types of the new CallOp. Tensor results that are
+    // equivalent to a FuncOp bbArg are no longer returned.
+    for (auto it : llvm::enumerate(callOp.getResultTypes())) {
+      unsigned returnValIdx = it.index();
+      Type returnType = it.value();
       if (!isaTensor(returnType)) {
         // Non-tensor values are returned.
-        retValMapping[returnIdx] = resultTypes.size();
+        retValMapping[returnValIdx] = resultTypes.size();
         resultTypes.push_back(returnType);
         continue;
       }
 
       if (Optional<int64_t> bbArgIdx =
-              getEquivalentFuncArgIdx(moduleState, returnOperand)) {
+              getEquivalentFuncArgIdx(funcOp, moduleState, returnValIdx)) {
         // Return operands that are equivalent to some bbArg, are not
         // returned.
-          replacementValues[returnIdx] =
+        Value buffer =
             state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
+        replacementValues[returnValIdx] = buffer;
+        newOperands[*bbArgIdx] = buffer;
         continue;
       }
 
-        llvm_unreachable("returning non-equivalent tensors not supported");
-      }
+      return callOp->emitError(
+          "call to FuncOp that returns non-equivalent tensors not supported");
     }
 
     // 2. Compute bufferized FunctionType.
@@ -601,23 +595,26 @@ struct CallOpInterface
         moduleState.bufferizedFunctionTypes);
 
     // 3. Rewrite tensor operands as memrefs based on `bufferizedFuncType`.
-    SmallVector<Value> newOperands;
-    newOperands.reserve(callOp->getNumOperands());
     for (OpOperand &opOperand : callOp->getOpOperands()) {
+      unsigned idx = opOperand.getOperandNumber();
       Value tensorOperand = opOperand.get();
+
       // Non-tensor operands are just copied.
       if (!tensorOperand.getType().isa<TensorType>()) {
-        newOperands.push_back(tensorOperand);
+        newOperands[idx] = tensorOperand;
        continue;
       }
 
-      // Tensor operands are guaranteed to have been buferized.
-      int64_t idx = opOperand.getOperandNumber();
-      Value buffer = state.lookupBuffer(rewriter, tensorOperand);
+      // Retrieve buffers for tensor operands. Tensor operand buffers, who's
+      // corresponding FuncOp bbArgs are equivalent to a returned tensor, were
+      // already stored in `newOperands` during Step 1.
+      Value buffer = newOperands[idx]
+                         ? newOperands[idx]
+                         : state.lookupBuffer(rewriter, tensorOperand);
 
       // Caller / callee type mistmatch is handled with a CastOp.
       auto memRefType = bufferizedFuncType.getInput(idx);
-      // Since we don't yet have a clear layout story, buffer_cast may
+      // Since we don't yet have a clear layout story, to_memref may
       // conservatively turn tensors into more dynamic memref than necessary.
       // If the memref type of the callee fails, introduce an extra memref.cast
       // that will either canonicalize away or fail compilation until we can do
@@ -627,20 +624,21 @@ struct CallOpInterface
                                                    memRefType, buffer);
         buffer = castBuffer;
       }
-      newOperands.push_back(buffer);
+      newOperands[idx] = buffer;
     }
 
     // 4. Create the new CallOp.
     Operation *newCallOp = rewriter.create<CallOp>(
         callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
     newCallOp->setAttrs(callOp->getAttrs());
-    // 5. Replace the old op with the new op.
+    // Get replacement values for non-tensor / non-equivalent results.
     for (int i = 0; i < replacementValues.size(); ++i) {
       if (replacementValues[i])
         continue;
       replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
     }
 
+    // 5. Replace the old op with the new op.
     state.replaceOp(rewriter, callOp, replacementValues);
 
     return success();
@@ -187,3 +187,26 @@ func @to_memref_op_is_writing(
 
   return %r1, %r2 : vector<5xf32>, vector<5xf32>
 }
+
+// -----
+
+func private @foo(%t : tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+
+func @call_to_unknown_tensor_returning_func(%t : tensor<?xf32>) {
+  // expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
+  call @foo(%t) : (tensor<?xf32>) -> (f32, tensor<?xf32>, f32)
+  return
+}
+
+// -----
+
+func @foo(%t : tensor<5xf32>) -> (tensor<5xf32>) {
+  %0 = linalg.init_tensor [5] : tensor<5xf32>
+  return %0 : tensor<5xf32>
+}
+
+func @call_to_func_returning_non_equiv_tensor(%t : tensor<5xf32>) {
+  // expected-error @+1 {{call to FuncOp that returns non-equivalent tensors not supported}}
+  call @foo(%t) : (tensor<5xf32>) -> (tensor<5xf32>)
+  return
+}