forked from OSchip/llvm-project
[mlir][linalg][bufferize] Fix CallOps with non-tensor operands
Such CallOps were not handled properly. When computing the new result types (and replacement values) of a CallOp, non-tensor return values were not accounted for. Differential Revision: https://reviews.llvm.org/D116445
This commit is contained in:
parent
d716cfc4fa
commit
a98c5a08b1
|
@ -490,6 +490,19 @@ namespace linalg {
|
||||||
namespace comprehensive_bufferize {
|
namespace comprehensive_bufferize {
|
||||||
namespace std_ext {
|
namespace std_ext {
|
||||||
|
|
||||||
|
/// Return the index of the parent function's bbArg that is equivalent to the
|
||||||
|
/// given ReturnOp operand (if any).
|
||||||
|
static Optional<int64_t>
|
||||||
|
getEquivalentFuncArgIdx(ModuleBufferizationState &state,
|
||||||
|
OpOperand &returnOperand) {
|
||||||
|
FuncOp funcOp = cast<FuncOp>(returnOperand.getOwner()->getParentOp());
|
||||||
|
if (!state.equivalentFuncArgs[funcOp].count(returnOperand.getOperandNumber()))
|
||||||
|
// Return value has no equivalent bbArg.
|
||||||
|
return None;
|
||||||
|
|
||||||
|
return state.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
|
||||||
|
}
|
||||||
|
|
||||||
struct CallOpInterface
|
struct CallOpInterface
|
||||||
: public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
|
: public BufferizableOpInterface::ExternalModel<CallOpInterface, CallOp> {
|
||||||
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
bool bufferizesToMemoryRead(Operation *op, OpOperand &opOperand,
|
||||||
|
@ -515,57 +528,67 @@ struct CallOpInterface
|
||||||
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
|
||||||
BufferizationState &state) const {
|
BufferizationState &state) const {
|
||||||
CallOp callOp = cast<CallOp>(op);
|
CallOp callOp = cast<CallOp>(op);
|
||||||
|
unsigned numResults = callOp.getNumResults();
|
||||||
FuncOp funcOp = getCalledFunction(callOp);
|
FuncOp funcOp = getCalledFunction(callOp);
|
||||||
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
|
assert(isa<CallOp>(callOp.getOperation()) && funcOp &&
|
||||||
"expected Callop to a FuncOp");
|
"expected CallOp to a FuncOp");
|
||||||
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
ModuleBufferizationState &moduleState = getModuleBufferizationState(state);
|
||||||
|
|
||||||
// 1. Filter return types:
|
// Result types of the bufferized CallOp.
|
||||||
// - if the callee is bodiless / external, we cannot inspect it and we
|
|
||||||
// cannot assume anything. We can just assert that it does not return a
|
|
||||||
// tensor as this would have to bufferize to "return a memref", whose
|
|
||||||
// semantics is ill-defined.
|
|
||||||
// - if the callee has a body, we perform inter-procedural equivalence
|
|
||||||
// analysis. When successful, a result folds onto an operand. When
|
|
||||||
// unsuccessful, additional work is needed (TODO) to either:
|
|
||||||
// * hoist a result into an inplaceable operand or
|
|
||||||
// * devise a better representation to truly return a buffer.
|
|
||||||
SmallVector<Type> resultTypes;
|
SmallVector<Type> resultTypes;
|
||||||
|
// Replacement values for the existing CallOp. These are usually the results
|
||||||
|
// of the bufferized CallOp, unless a tensor result folds onto an operand.
|
||||||
|
SmallVector<Value> replacementValues(numResults, Value());
|
||||||
|
// For non-tensor results: A mapping from return val indices of the old
|
||||||
|
// CallOp to return val indices of the bufferized CallOp.
|
||||||
|
SmallVector<Optional<unsigned>> retValMapping(numResults, None);
|
||||||
|
|
||||||
if (funcOp.body().empty()) {
|
if (funcOp.body().empty()) {
|
||||||
if (llvm::any_of(funcOp.getType().getResults(), isaTensor))
|
// The callee is bodiless / external, so we cannot inspect it and we
|
||||||
return callOp->emitError()
|
// cannot assume anything. We can just assert that it does not return a
|
||||||
<< "cannot bufferize bodiless function that returns a tensor";
|
// tensor as this would have to bufferize to "return a memref", whose
|
||||||
|
// semantics is ill-defined.
|
||||||
|
for (int i = 0; i < numResults; ++i) {
|
||||||
|
Type returnType = callOp.getResult(i).getType();
|
||||||
|
if (isaTensor(returnType))
|
||||||
|
return callOp->emitError()
|
||||||
|
<< "cannot bufferize bodiless function that returns a tensor";
|
||||||
|
resultTypes.push_back(returnType);
|
||||||
|
retValMapping[i] = i;
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
|
// The callee has a body. Based on previously gathered equivalence
|
||||||
|
// information, we know if a tensor result folds onto an operand. These
|
||||||
|
// are the only tensor value returns that are supported at the moment.
|
||||||
|
//
|
||||||
|
// For tensors return values that do not fold onto an operand, additional
|
||||||
|
// work is needed (TODO) to either:
|
||||||
|
// * hoist a result into an inplaceable operand or
|
||||||
|
// * devise a better representation to truly return a buffer.
|
||||||
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
|
ReturnOp returnOp = getAssumedUniqueReturnOp(funcOp);
|
||||||
assert(returnOp && "expected func with single return op");
|
assert(returnOp && "expected func with single return op");
|
||||||
|
|
||||||
// For each FuncOp result, keep track of which inplace argument it reuses.
|
// For each FuncOp result, keep track of which inplace argument it reuses.
|
||||||
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
|
for (OpOperand &returnOperand : returnOp->getOpOperands()) {
|
||||||
|
unsigned returnIdx = returnOperand.getOperandNumber();
|
||||||
Type returnType = returnOperand.get().getType();
|
Type returnType = returnOperand.get().getType();
|
||||||
if (!isaTensor(returnType)) {
|
if (!isaTensor(returnType)) {
|
||||||
|
// Non-tensor values are returned.
|
||||||
|
retValMapping[returnIdx] = resultTypes.size();
|
||||||
resultTypes.push_back(returnType);
|
resultTypes.push_back(returnType);
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
// If return operand is equivalent to some bbArg, no need to return it.
|
if (Optional<int64_t> bbArgIdx =
|
||||||
if (moduleState.equivalentFuncArgs[funcOp].count(
|
getEquivalentFuncArgIdx(moduleState, returnOperand)) {
|
||||||
returnOperand.getOperandNumber())) {
|
// Return operands that are equivalent to some bbArg, are not
|
||||||
int64_t idx =
|
// returned.
|
||||||
moduleState
|
replacementValues[returnIdx] =
|
||||||
.equivalentFuncArgs[funcOp][returnOperand.getOperandNumber()];
|
state.lookupBuffer(rewriter, callOp->getOperand(*bbArgIdx));
|
||||||
Value oldRes = callOp->getResult(returnOperand.getOperandNumber());
|
|
||||||
Value buffer = state.lookupBuffer(rewriter, callOp->getOperand(idx));
|
|
||||||
// Add a ToTensorOp to kill all uses of the CallOp return.
|
|
||||||
// Replace all uses of the CallOp results so we can erase the CallOp.
|
|
||||||
// This ToTensorOp must fold/DCE away or bufferization should be
|
|
||||||
// considered failed.
|
|
||||||
Value toTensorOp = rewriter.create<bufferization::ToTensorOp>(
|
|
||||||
callOp.getLoc(), buffer);
|
|
||||||
oldRes.replaceAllUsesWith(toTensorOp);
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
resultTypes.push_back(returnType);
|
llvm_unreachable("returning non-equivalent tensors not supported");
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -612,8 +635,13 @@ struct CallOpInterface
|
||||||
callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
|
callOp.getLoc(), funcOp.sym_name(), resultTypes, newOperands);
|
||||||
newCallOp->setAttrs(callOp->getAttrs());
|
newCallOp->setAttrs(callOp->getAttrs());
|
||||||
|
|
||||||
// 5. Delete the op at the end of bufferization.
|
// 5. Replace the old op with the new op.
|
||||||
callOp->erase();
|
for (int i = 0; i < replacementValues.size(); ++i) {
|
||||||
|
if (replacementValues[i])
|
||||||
|
continue;
|
||||||
|
replacementValues[i] = newCallOp->getResult(*retValMapping[i]);
|
||||||
|
}
|
||||||
|
state.replaceOp(rewriter, callOp, replacementValues);
|
||||||
|
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
|
@ -1000,6 +1000,34 @@ func @equivalent_func_arg_2(%t0: tensor<?xf32> {linalg.inplaceable = true},
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @inner_func(
|
||||||
|
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
|
||||||
|
func @inner_func(%t: tensor<?xf32>) -> (tensor<?xf32>, f32) {
|
||||||
|
// CHECK-NOT: copy
|
||||||
|
%f = arith.constant 1.0 : f32
|
||||||
|
%c0 = arith.constant 0 : index
|
||||||
|
%c1 = arith.constant 1 : index
|
||||||
|
// CHECK: memref.store %{{.*}}, %[[arg0]]
|
||||||
|
%0 = tensor.insert %f into %t[%c0] : tensor<?xf32>
|
||||||
|
// CHECK: %[[load:.*]] = memref.load %[[arg0]]
|
||||||
|
%1 = tensor.extract %0[%c1] : tensor<?xf32>
|
||||||
|
// CHECK: return %[[load]] : f32
|
||||||
|
return %0, %1 : tensor<?xf32>, f32
|
||||||
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: func @call_func_with_non_tensor_return(
|
||||||
|
// CHECK-SAME: %[[arg0:.*]]: memref<?xf32
|
||||||
|
func @call_func_with_non_tensor_return(
|
||||||
|
%t0: tensor<?xf32> {linalg.inplaceable = true}) -> (f32, tensor<?xf32>) {
|
||||||
|
// CHECK-NOT: copy
|
||||||
|
// CHECK: %[[call:.*]] = call @inner_func(%[[arg0]])
|
||||||
|
%0, %1 = call @inner_func(%t0) : (tensor<?xf32>) -> (tensor<?xf32>, f32)
|
||||||
|
// CHECK: return %[[call]] : f32
|
||||||
|
return %1, %0 : f32, tensor<?xf32>
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
// CHECK-LABEL: func @func_without_tensor_args
|
// CHECK-LABEL: func @func_without_tensor_args
|
||||||
func @func_without_tensor_args(%v : vector<10xf32>) -> () {
|
func @func_without_tensor_args(%v : vector<10xf32>) -> () {
|
||||||
// CHECK: %[[alloc:.*]] = memref.alloc()
|
// CHECK: %[[alloc:.*]] = memref.alloc()
|
||||||
|
|
Loading…
Reference in New Issue