[mlir][linalg][bufferize] Fix CallOps with non-tensor operands

Such CallOps were not handled properly. When computing the new result types (and replacement values) of a CallOp, non-tensor return values were not accounted for.

Differential Revision: https://reviews.llvm.org/D116445
This commit is contained in:
Matthias Springer 2022-01-06 00:13:55 +09:00
parent d716cfc4fa
commit a98c5a08b1
2 changed files with 88 additions and 32 deletions

View File

@ -490,6 +490,19 @@ namespace linalg {
namespace comprehensive_bufferize { namespace comprehensive_bufferize {
namespace std_ext { namespace std_ext {
/// Look up which block argument of the enclosing function is recorded as
/// equivalent to `returnOperand`.
///
/// The enclosing FuncOp is taken to be the parent op of the operand's owner.
/// Returns the recorded argument index, or None when `state` holds no
/// equivalence entry for this operand's position.
static Optional<int64_t>
getEquivalentFuncArgIdx(ModuleBufferizationState &state,
                        OpOperand &returnOperand) {
  FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
  // Per-function table keyed by return operand position.
  auto &equivalentArgs = state.equivalentFuncArgs[funcOp];
  const unsigned returnIdx = returnOperand.getOperandNumber();
  if (equivalentArgs.count(returnIdx))
    return equivalentArgs[returnIdx];
  // Return value has no equivalent bbArg.
  return None;
}
struct CallOpInterface struct CallOpInterface
: public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> { : public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand, bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
@ -515,57 +528,67 @@ struct CallOpInterface
LogicalResult bufferize(Operation *op, RewriterBase &rewriter, LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
BufferizationState &state) const { BufferizationState &state) const {
CallOp callOp = cast<CallOp>(op); CallOp callOp = cast<CallOp>(op);
unsigned numResults = callOp.getNumResults();
FuncOp funcOp = getCalledFunction(callOp); FuncOp funcOp = getCalledFunction(callOp);
assert(isa<CallOp>(callOp.getOperation()) && funcOp && assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
"expected Callop to a FuncOp"); "expected CallOp to a FuncOp");
ModuleBufferizationState &moduleState = getModuleBufferizationState(state); ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
// 1. Filter return types: // Result types of the bufferized CallOp.
// - if the callee is bodiless / external, we cannot inspect it and we
// cannot assume anything. We can just assert that it does not return a
// tensor as this would have to bufferize to "return a memref", whose
// semantics is ill-defined.
// - if the callee has a body, we perform inter-procedural equivalence
// analysis. When successful, a result folds onto an operand. When
// unsuccessful, additional work is needed (TODO) to either:
// * hoist a result into an inplaceable operand or
// * devise a better representation to truly return a buffer.
SmallVector<Type> resultTypes; SmallVector<Type> resultTypes;
// Replacement values for the existing CallOp. These are usually the results
// of the bufferized CallOp, unless a tensor result folds onto an operand.
SmallVector<Value> replacementValues(numResults, Value());
// For non-tensor results: A mapping from return val indices of the old
// CallOp to return val indices of the bufferized CallOp.
SmallVector<Optional<unsigned>> retValMapping(numResults, None);
if (funcOp.body().empty()) { if (funcOp.body().empty()) {
if (llvm::any_of(funcOp.getType().getResults(), isaTensor)) // The callee is bodiless / external, so we cannot inspect it and we
return callOp->emitError() // cannot assume anything. We can just assert that it does not return a
<< "cannot bufferize bodiless function that returns a tensor"; // tensor as this would have to bufferize to "return a memref", whose
// semantics is ill-defined.
for (int i = 0; i < numResults; ++i) {
Type returnType = callOp.getResult(i).getType();
if (isaTensor(returnType))
return callOp->emitError()
<< "cannot bufferize bodiless function that returns a tensor";
resultTypes.push_back(returnType);
retValMapping[i] = i;
}
} else { } else {
// The callee has a body. Based on previously gathered equivalence
// information, we know if a tensor result folds onto an operand. These
// are the only tensor value returns that are supported at the moment.
//
// For tensor return values that do not fold onto an operand, additional
// work is needed (TODO) to either:
// * hoist a result into an inplaceable operand or
// * devise a better representation to truly return a buffer.
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp); ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
assert(returnOp && "expected func with single return op"); assert(returnOp && "expected func with single return op");
// For each FuncOp result, keep track of which inplace argument it reuses. // For each FuncOp result, keep track of which inplace argument it reuses.
for (OpOperand &returnOperand : returnOp->getOpOperands()) { for (OpOperand &returnOperand : returnOp->getOpOperands()) {
unsigned returnIdx = returnOperand.getOperandNumber();
Type returnType = returnOperand.get().getType(); Type returnType = returnOperand.get().getType();
if (!isaTensor(returnType)) { if (!isaTensor(returnType)) {
// Non-tensor values are returned.
retValMapping[returnIdx] = resultTypes.size();
resultTypes.push_back(returnType); resultTypes.push_back(returnType);
continue; continue;
} }
// If return operand is equivalent to some bbArg, no need to return it. if (Optional<int64_t> bbArgIdx =
if (moduleState.equivalentFuncArgs[funcOp].count( getEquivalentFuncArgIdx(moduleState, returnOperand)) {
returnOperand.getOperandNumber())) { // Return operands that are equivalent to some bbArg, are not
int64_t idx = // returned.
moduleState replacementValues[returnIdx] =
.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()]; state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
Value buffer = state.lookupBuffer(rewriter, callOp->getOperand(idx));
// Add a ToTensorOp to kill all uses of the CallOp return.
// Replace all uses of the CallOp results so we can erase the CallOp.
// This ToTensorOp must fold/DCE away or bufferization should be
// considered failed.
Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
callOp.getLoc(), buffer);
oldRes.replaceAllUsesWith(toTensorOp);
continue; continue;
} }
resultTypes.push_back(returnType); llvm_unreachable("returning non-equivalent tensors not supported");
} }
} }
@ -612,8 +635,13 @@ struct CallOpInterface
callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands); callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
newCallOp->setAttrs(callOp->getAttrs()); newCallOp->setAttrs(callOp->getAttrs());
// 5. Delete the op at the end of bufferization. // 5. Replace the old op with the new op.
callOp->erase(); for (int i = 0; i < replacementValues.size(); ++i) {
if (replacementValues[i])
continue;
replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
}
state.replaceOp(rewriter, callOp, replacementValues);
return success(); return success();
} }

View File

@ -1000,6 +1000,34 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
// ----- // -----
// Callee used by @call_func_with_non_tensor_return. It writes a constant into
// element 0 of %t, reads element 1 of the updated tensor, and returns both the
// updated tensor and the f32 that was read.
// The expectations below require the store and the load to act directly on
// the memref argument, without any copy, and only the f32 to be returned.
// CHECK-LABEL: func @inner_func(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
// CHECK-NOT: copy
%f = arith.constant 1.0 : f32
%c0 = arith.constant 0 : index
%c1 = arith.constant 1 : index
// CHECK: memref.store %{{.*}}, %[[arg0]]
%0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
// CHECK: %[[load:.*]] = memref.load %[[arg0]]
%1 = tensor.extract %0[%c1] : tensor<?xf32>
// CHECK: return %[[load]] : f32
return %0, %1 : tensor<?xf32>, f32
}
// Caller that invokes @inner_func on an inplaceable tensor argument and
// returns the two call results in swapped order (f32 first, tensor second).
// The expectations below require the call to take the memref argument
// directly, without any copy, and the single call result to be returned as
// the only (f32) value.
// CHECK-LABEL: func @call_func_with_non_tensor_return(
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
func @call_func_with_non_tensor_return(
%t0: tensor<?xf32> {linalg.inplaceable = true}) -> (f32, tensor<?xf32>) {
// CHECK-NOT: copy
// CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
%0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
// CHECK: return %[[call]] : f32
return %1, %0 : f32, tensor<?xf32>
}
// -----
// CHECK-LABEL: func @func_without_tensor_args // CHECK-LABEL: func @func_without_tensor_args
func @func_without_tensor_args(%v : vector<10xf32>) -> () { func @func_without_tensor_args(%v : vector<10xf32>) -> () {
// CHECK: %[[alloc:.*]] = memref.alloc() // CHECK: %[[alloc:.*]] = memref.alloc()