forked from OSchip/llvm-project
[mlir][linalg][bufferize] Support scf.execute_region bufferization
This op is needed for unit testing in a subsequent revision. (This is the first op that has a block that yields equivalent values via the op's results.) Note: Bufferization of scf.execute_region ops with multiple blocks is not yet supported. Differential Revision: https://reviews.llvm.org/D117424
This commit is contained in:
parent
b9d85a5231
commit
b83c67d978
|
@@ -44,6 +44,7 @@ struct ExecuteRegionOpInterface
|
|||
auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
|
||||
size_t resultNum = std::distance(op->getOpResults().begin(),
|
||||
llvm::find(op->getOpResults(), opResult));
|
||||
// TODO: Support multiple blocks.
|
||||
assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
|
||||
"expected exactly 1 block");
|
||||
auto yieldOp = dyn_cast<scf::YieldOp>(
|
||||
|
@@ -66,13 +67,59 @@ struct ExecuteRegionOpInterface
|
|||
|
||||
/// Bufferize an scf.execute_region op.
///
/// Result types are rewritten so that every tensor result becomes a memref
/// result (ranked tensors -> memrefs with fully dynamic layout, unranked
/// tensors -> unranked memrefs); non-tensor results are kept as-is. The
/// region body is moved over unchanged, the terminator is rewritten to yield
/// memrefs, and to_tensor ops are inserted after the new op so existing
/// tensor users of the old results keep working.
///
/// Note: Only single-block regions are supported (asserted below).
LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                        const BufferizationState &state) const {
  auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);

  // Compute new result types: tensors bufferize to memrefs, everything else
  // is passed through unchanged.
  SmallVector<Type> newResultTypes;
  for (Type type : executeRegionOp->getResultTypes()) {
    if (auto rankedTensorType = type.dyn_cast<RankedTensorType>()) {
      // Ranked tensors become memrefs with a fully dynamic layout map.
      newResultTypes.push_back(getDynamicMemRefType(rankedTensorType));
    } else if (auto tensorType = type.dyn_cast<TensorType>()) {
      // Unranked tensors become unranked memrefs.
      newResultTypes.push_back(
          getUnrankedMemRefType(tensorType.getElementType()));
    } else {
      newResultTypes.push_back(type);
    }
  }

  // Create new op and move over region.
  auto newOp =
      rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
  newOp.getRegion().takeBody(executeRegionOp.getRegion());

  // Update terminator: wrap every yielded tensor in a to_memref op so the
  // yielded types line up with the new result types.
  // TODO: Support multiple blocks.
  assert(newOp.getRegion().getBlocks().size() == 1 &&
         "only 1 block supported");
  Block *newBlock = &newOp.getRegion().front();
  auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
  rewriter.setInsertionPoint(yieldOp);
  SmallVector<Value> newYieldValues;
  for (auto it : llvm::enumerate(yieldOp.getResults())) {
    Value val = it.value();
    if (val.getType().isa<TensorType>()) {
      newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
          yieldOp.getLoc(), newResultTypes[it.index()], val));
    } else {
      newYieldValues.push_back(val);
    }
  }
  rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);

  // Update all uses of the old op: memref results are wrapped in to_tensor
  // ops so that existing tensor-typed users remain valid.
  rewriter.setInsertionPointAfter(newOp);
  SmallVector<Value> newResults;
  for (auto it : llvm::enumerate(executeRegionOp->getResultTypes())) {
    if (it.value().isa<TensorType>()) {
      newResults.push_back(rewriter.create<bufferization::ToTensorOp>(
          executeRegionOp.getLoc(), newOp->getResult(it.index())));
    } else {
      newResults.push_back(newOp->getResult(it.index()));
    }
  }

  // Replace old op.
  rewriter.replaceOp(executeRegionOp, newResults);

  return success();
}
|
|
@@ -159,8 +159,8 @@ func @mini_test_case1() -> tensor<10x20xf32> {
|
|||
|
||||
// -----
|
||||
|
||||
// expected-error @+1 {{memref return type is unsupported}}
|
||||
func @main() -> tensor<4xi32> {
|
||||
// expected-error @+1 {{scf.execute_region with tensor result not supported}}
|
||||
%r = scf.execute_region -> tensor<4xi32> {
|
||||
%A = arith.constant dense<[1, 2, 3, 4]> : tensor<4xi32>
|
||||
scf.yield %A: tensor<4xi32>
|
||||
|
|
|
@@ -446,6 +446,59 @@ func @main() {
|
|||
|
||||
// -----
|
||||
|
||||
// Test: an scf.execute_region whose yielded tensor can be bufferized in
// place (%t1 is inplaceable), so no alloc/copy is expected.
// CHECK-LABEL: func @execute_region_test(
//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32
func @execute_region_test(%t1 : tensor<?xf32> {linalg.inplaceable = "true"})
    -> (f32, tensor<?xf32>, f32)
{
  %f1 = arith.constant 0.0 : f32
  %f2 = arith.constant 1.0 : f32
  %idx = arith.constant 7 : index

  // scf.execute_region is canonicalized away after bufferization. So just the
  // memref.store is left over.

  // CHECK: memref.store %{{.*}}, %[[m1]][%{{.*}}]
  %0, %1, %2 = scf.execute_region -> (f32, tensor<?xf32>, f32) {
    %t2 = tensor.insert %f2 into %t1[%idx] : tensor<?xf32>
    scf.yield %f1, %t2, %f2 : f32, tensor<?xf32>, f32
  }

  // CHECK: return %{{.*}}, %{{.*}} : f32, f32
  return %0, %1, %2 : f32, tensor<?xf32>, f32
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test: a RaW conflict — %t1 is read (tensor.extract) after the region's
// tensor.insert would overwrite it — so bufferization must alloc + copy.
// CHECK-LABEL: func @execute_region_with_conflict(
//  CHECK-SAME:     %[[m1:.*]]: memref<?xf32
func @execute_region_with_conflict(%t1 : tensor<?xf32> {linalg.inplaceable = "true"})
    -> (f32, tensor<?xf32>, f32)
{
  %f1 = arith.constant 0.0 : f32
  %idx = arith.constant 7 : index

  // scf.execute_region is canonicalized away after bufferization. So just the
  // memref.store is left over.

  // CHECK: %[[alloc:.*]] = memref.alloc
  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
  // CHECK: memref.copy %[[m1]], %[[alloc]]
  // CHECK: memref.store %{{.*}}, %[[alloc]][%{{.*}}]
  %0, %1, %2 = scf.execute_region -> (f32, tensor<?xf32>, f32) {
    %t2 = tensor.insert %f1 into %t1[%idx] : tensor<?xf32>
    scf.yield %f1, %t2, %f1 : f32, tensor<?xf32>, f32
  }

  // CHECK: %[[load:.*]] = memref.load %[[m1]]
  %3 = tensor.extract %t1[%idx] : tensor<?xf32>

  // CHECK: return %{{.*}}, %[[casted]], %[[load]] : f32, memref<?xf32, #{{.*}}>, f32
  return %0, %1, %3 : f32, tensor<?xf32>, f32
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: #[[$DYN_1D_MAP:.*]] = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>
|
||||
|
||||
// CHECK: func private @some_external_func(memref<?xf32, #[[$DYN_1D_MAP]]>)
|
||||
|
|
Loading…
Reference in New Issue