[mlir][linalg][bufferize] Fix CallOps with non-tensor operands

Such CallOps were not handled properly. When computing the new result types
(and replacement values) of a CallOp, non-tensor return values were not
accounted for.

Differential Revision: https://reviews.llvm.org/D116445

commit a98c5a08b1 (parent: d716cfc4fa)
@@ -490,6 +490,19 @@ namespace linalg {
 namespace comprehensive_bufferize {
 namespace std_ext {
 
+/// Return the index of the parent function's bbArg that is equivalent to the
+/// given ReturnOp operand (if any).
+static Optional<int64_t>
+getEquivalentFuncArgIdx(ModuleBufferizationState &state,
+                        OpOperand &returnOperand) {
+  FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
+  if (!state.equivalentFuncArgs[funcOp].count(returnOperand.getOperandNumber()))
+    // Return value has no equivalent bbArg.
+    return None;
+
+  return state.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
+}
+
 struct CallOpInterface
     : public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
   bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
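For orientation: equivalentFuncArgs is populated by the earlier
inter-procedural equivalence analysis, and the new getEquivalentFuncArgIdx
helper is a checked lookup into it. Below is a reduced sketch of the shape
this implies, inferred from how the map is indexed above rather than copied
from the verbatim declaration (2021-era builtin FuncOp assumed):

    #include "mlir/IR/BuiltinOps.h"
    #include "llvm/ADT/DenseMap.h"

    // Sketch (assumed shape): for every FuncOp, map the index of a returned
    // value (its ReturnOp operand number) to the index of the function
    // argument that the analysis proved equivalent, if any.
    struct EquivalenceInfoSketch {
      llvm::DenseMap<mlir::FuncOp, llvm::DenseMap<int64_t, int64_t>>
          equivalentFuncArgs;
    };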
@@ -515,57 +528,67 @@ struct CallOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           BufferizationState &state) const {
     CallOp callOp = cast<CallOp>(op);
     unsigned numResults = callOp.getNumResults();
     FuncOp funcOp = getCalledFunction(callOp);
     assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
-           "expected Callop to a FuncOp");
+           "expected CallOp to a FuncOp");
     ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
 
-    // 1. Filter return types:
-    //   - if the callee is bodiless / external, we cannot inspect it and we
-    //     cannot assume anything. We can just assert that it does not return a
-    //     tensor as this would have to bufferize to "return a memref", whose
-    //     semantics is ill-defined.
-    //   - if the callee has a body, we perform inter-procedural equivalence
-    //     analysis. When successful, a result folds onto an operand. When
-    //     unsuccessful, additional work is needed (TODO) to either:
-    //     * hoist a result into an inplaceable operand or
-    //     * devise a better representation to truly return a buffer.
+    // Result types of the bufferized CallOp.
     SmallVector<Type> resultTypes;
+    // Replacement values for the existing CallOp. These are usually the results
+    // of the bufferized CallOp, unless a tensor result folds onto an operand.
+    SmallVector<Value> replacementValues(numResults, Value());
+    // For non-tensor results: A mapping from return val indices of the old
+    // CallOp to return val indices of the bufferized CallOp.
+    SmallVector<Optional<unsigned>> retValMapping(numResults, None);
 
     if (funcOp.body().empty()) {
-      if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
-        return callOp->emitError()
-               << "cannot bufferize bodiless function that returns a tensor";
+      // The callee is bodiless / external, so we cannot inspect it and we
+      // cannot assume anything. We can just assert that it does not return a
+      // tensor as this would have to bufferize to "return a memref", whose
+      // semantics is ill-defined.
+      for (int i = 0; i < numResults; ++i) {
+        Type returnType = callOp.getResult(i).getType();
+        if (isaTensor(returnType))
+          return callOp->emitError()
+                 << "cannot bufferize bodiless function that returns a tensor";
+        resultTypes.push_back(returnType);
+        retValMapping[i] = i;
+      }
     } else {
+      // The callee has a body. Based on previously gathered equivalence
+      // information, we know if a tensor result folds onto an operand. These
+      // are the only tensor value returns that are supported at the moment.
+      //
+      // For tensor return values that do not fold onto an operand, additional
+      // work is needed (TODO) to either:
+      // * hoist a result into an inplaceable operand or
+      // * devise a better representation to truly return a buffer.
       ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
       assert(returnOp && "expected func with single return op");
 
       // For each FuncOp result, keep track of which inplace argument it reuses.
       for (OpOperand &returnOperand : returnOp->getOpOperands()) {
+        unsigned returnIdx = returnOperand.getOperandNumber();
         Type returnType = returnOperand.get().getType();
+        if (!isaTensor(returnType)) {
+          // Non-tensor values are returned.
+          retValMapping[returnIdx] = resultTypes.size();
+          resultTypes.push_back(returnType);
+          continue;
+        }
 
-        // If return operand is equivalent to some bbArg, no need to return it.
-        if (moduleState.equivalentFuncArgs[funcOp].count(
-                returnOperand.getOperandNumber())) {
-          int64_t idx =
-              moduleState
-                  .equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
-          Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
-          Value buffer = state.lookupBuffer(rewriter, callOp->getOperand(idx));
-          // Add a ToTensorOp to kill all uses of the CallOp return.
-          // Replace all uses of the CallOp results so we can erase the CallOp.
-          // This ToTensorOp must fold/DCE away or bufferization should be
-          // considered failed.
-          Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
-              callOp.getLoc(), buffer);
-          oldRes.replaceAllUsesWith(toTensorOp);
+        if (Optional<int64_t> bbArgIdx =
+                getEquivalentFuncArgIdx(moduleState, returnOperand)) {
+          // Return operands that are equivalent to some bbArg are not
+          // returned.
+          replacementValues[returnIdx] =
+              state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
           continue;
         }
 
-        resultTypes.push_back(returnType);
+        llvm_unreachable("returning non-equivalent tensors not supported");
      }
    }
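To make the interplay of resultTypes, replacementValues, and retValMapping
concrete, here is a self-contained sketch in plain C++, with strings standing
in for MLIR types and values; the scenario mirrors the test added at the
bottom of this commit:

    #include <cassert>
    #include <optional>
    #include <string>
    #include <vector>

    // A callee of type (tensor<?xf32>) -> (tensor<?xf32>, f32) whose tensor
    // result is equivalent to its tensor argument: old result 0 folds onto
    // operand 0's buffer, old result 1 (f32) becomes result 0 of the
    // bufferized call.
    int main() {
      unsigned numResults = 2;
      std::vector<std::string> resultTypes;                   // new call results
      std::vector<std::string> replacementValues(numResults); // one per old result
      std::vector<std::optional<unsigned>> retValMapping(numResults);

      // Old result 0: tensor equivalent to bbArg 0 -> replaced by the buffer.
      replacementValues[0] = "lookupBuffer(operand 0)";
      // Old result 1: non-tensor -> kept as a result of the bufferized call.
      retValMapping[1] = resultTypes.size();
      resultTypes.push_back("f32");

      // Step 5 of the patch: fill remaining replacements from the new call.
      for (unsigned i = 0; i < numResults; ++i)
        if (replacementValues[i].empty())
          replacementValues[i] =
              "newCallOp.result(" + std::to_string(*retValMapping[i]) + ")";

      assert(resultTypes.size() == 1);                       // only the f32 remains
      assert(replacementValues[1] == "newCallOp.result(0)"); // f32 remapped to 0
    }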
@@ -612,8 +635,13 @@ struct CallOpInterface
         callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
     newCallOp->setAttrs(callOp->getAttrs());
 
-    // 5. Delete the op at the end of bufferization.
-    callOp->erase();
+    // 5. Replace the old op with the new op.
+    for (int i = 0; i < replacementValues.size(); ++i) {
+      if (replacementValues[i])
+        continue;
+      replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
+    }
+    state.replaceOp(rewriter, callOp, replacementValues);
 
     return success();
   }
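This replacement loop is why step 5 changed from callOp->erase() to
state.replaceOp(...): with non-tensor results in play, some results of the old
CallOp survive as results of the new one and must be rewired rather than
dropped. A reduced sketch of the contract such a helper has to provide — the
function below is hypothetical and leans on RewriterBase::replaceOp; the
actual implementation of state.replaceOp is not shown in this diff:

    #include <cassert>
    #include "mlir/IR/PatternMatch.h"

    // Sketch: replace every result of `op` with the corresponding value, then
    // erase `op`. A missing (null) replacement would be a bug here, which is
    // why the loop above first fills every gap from newCallOp's results.
    static void replaceOpSketch(mlir::RewriterBase &rewriter,
                                mlir::Operation *op,
                                mlir::ValueRange replacements) {
      assert(op->getNumResults() == replacements.size() &&
             "one replacement value per result");
      rewriter.replaceOp(op, replacements); // RAUW each result, erase op.
    }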
@@ -1000,6 +1000,34 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
 
 // -----
 
+// CHECK-LABEL: func @inner_func(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
+  // CHECK-NOT: copy
+  %f = arith.constant 1.0 : f32
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  // CHECK: memref.store %{{.*}}, %[[arg0]]
+  %0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
+  // CHECK: %[[load:.*]] = memref.load %[[arg0]]
+  %1 = tensor.extract %0[%c1] : tensor<?xf32>
+  // CHECK: return %[[load]] : f32
+  return %0, %1 : tensor<?xf32>, f32
+}
+
+// CHECK-LABEL: func @call_func_with_non_tensor_return(
+//  CHECK-SAME:     %[[arg0:.*]]: memref<?xf32
+func @call_func_with_non_tensor_return(
+    %t0: tensor<?xf32> {linalg.inplaceable = true}) -> (f32, tensor<?xf32>) {
+  // CHECK-NOT: copy
+  // CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
+  %0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
+  // CHECK: return %[[call]] : f32
+  return %1, %0 : f32, tensor<?xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @func_without_tensor_args
 func @func_without_tensor_args(%v : vector<10xf32>) -> () {
   // CHECK: %[[alloc:.*]] = memref.alloc()
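Both new tests exercise the fixed path: @inner_func returns a (tensor, f32)
pair whose tensor result is equivalent to its argument, and
@call_func_with_non_tensor_return additionally permutes the result order to
(f32, tensor). After bufferization the tensor results fold onto the memref
behind %[[arg0]], so the CHECK lines expect the bufferized call and both
returns to carry only the f32 value.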