[mlir][SCF][bufferize] Bufferize scf.if/execute_region terminators separately
This allows for better type inference during bufferization and prepares for supporting memory spaces.

Differential Revision: https://reviews.llvm.org/D128581
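For context, a rough sketch of what bufferizing the terminators first buys (a hypothetical example, not taken from this patch; the function and value names are invented and the exact memref layouts depend on the bufferization options):

// Hypothetical input IR, illustration only.
func.func @example(%c: i1, %t: tensor<5xf32>) -> tensor<5xf32> {
  %0 = scf.if %c -> (tensor<5xf32>) {
    %a = bufferization.alloc_tensor() : tensor<5xf32>
    scf.yield %a : tensor<5xf32>
  } else {
    scf.yield %t : tensor<5xf32>
  }
  return %0 : tensor<5xf32>
}

With this change, the scf.yield of each branch is bufferized first, so the yielded values already carry memref types; the rebuilt scf.if or scf.execute_region simply adopts those types instead of computing a memref type for every tensor result up front.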
@@ -75,41 +75,17 @@ struct ExecuteRegionOpInterface
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
     auto executeRegionOp = cast<scf::ExecuteRegionOp>(op);
-
-    // Compute new result types.
-    SmallVector<Type> newResultTypes;
-    for (Type type : executeRegionOp->getResultTypes()) {
-      if (auto tensorType = type.dyn_cast<TensorType>()) {
-        // TODO: Infer the result type instead of computing it.
-        newResultTypes.push_back(getMemRefType(tensorType, options));
-      } else {
-        newResultTypes.push_back(type);
-      }
-    }
+    assert(executeRegionOp.getRegion().getBlocks().size() == 1 &&
+           "only 1 block supported");
+    auto yieldOp =
+        cast<scf::YieldOp>(executeRegionOp.getRegion().front().getTerminator());
+    TypeRange newResultTypes(yieldOp.getResults());
 
     // Create new op and move over region.
     auto newOp =
         rewriter.create<scf::ExecuteRegionOp>(op->getLoc(), newResultTypes);
     newOp.getRegion().takeBody(executeRegionOp.getRegion());
 
-    // Update terminator.
-    assert(newOp.getRegion().getBlocks().size() == 1 &&
-           "only 1 block supported");
-    Block *newBlock = &newOp.getRegion().front();
-    auto yieldOp = cast<scf::YieldOp>(newBlock->getTerminator());
-    rewriter.setInsertionPoint(yieldOp);
-    SmallVector<Value> newYieldValues;
-    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
-      Value val = it.value();
-      if (val.getType().isa<TensorType>()) {
-        newYieldValues.push_back(rewriter.create<bufferization::ToMemrefOp>(
-            yieldOp.getLoc(), newResultTypes[it.index()], val));
-      } else {
-        newYieldValues.push_back(val);
-      }
-    }
-    rewriter.replaceOpWithNewOp<scf::YieldOp>(yieldOp, newYieldValues);
-
     // Update all uses of the old op.
     rewriter.setInsertionPointAfter(newOp);
     SmallVector<Value> newResults;
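As a rough illustration of the scf.execute_region case above (hypothetical IR, not from the patch; function names invented): once the terminator has been bufferized, the yielded value already has a memref type, so TypeRange(yieldOp.getResults()) directly gives the result types of the rebuilt op.

// Before bufferization (illustration only):
func.func @before() -> tensor<5xf32> {
  %r = scf.execute_region -> tensor<5xf32> {
    %a = bufferization.alloc_tensor() : tensor<5xf32>
    scf.yield %a : tensor<5xf32>
  }
  return %r : tensor<5xf32>
}

// Sketch of the bufferized form: the yield already produces a memref, and the
// rebuilt scf.execute_region adopts that exact type.
func.func @after() -> memref<5xf32> {
  %r = scf.execute_region -> memref<5xf32> {
    %m = memref.alloc() : memref<5xf32>
    scf.yield %m : memref<5xf32>
  }
  return %r : memref<5xf32>
}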
@@ -184,64 +160,62 @@ struct IfOpInterface
 
   LogicalResult bufferize(Operation *op, RewriterBase &rewriter,
                           const BufferizationOptions &options) const {
+    OpBuilder::InsertionGuard g(rewriter);
     auto ifOp = cast<scf::IfOp>(op);
+    auto thenYieldOp = cast<scf::YieldOp>(ifOp.thenBlock()->getTerminator());
+    auto elseYieldOp = cast<scf::YieldOp>(ifOp.elseBlock()->getTerminator());
 
-    // Compute new types of the bufferized scf.if op.
-    SmallVector<Type> newTypes;
-    for (Type returnType : ifOp->getResultTypes()) {
-      if (auto tensorType = returnType.dyn_cast<TensorType>()) {
-        // TODO: Infer the result type instead of computing it.
-        newTypes.push_back(getMemRefType(tensorType, options));
-      } else {
-        newTypes.push_back(returnType);
-      }
-    }
+    // Reconcile type mismatches between then/else branches by inserting memref
+    // casts.
+    SmallVector<Value> thenResults, elseResults;
+    bool insertedCast = false;
+    for (unsigned i = 0; i < thenYieldOp.getResults().size(); ++i) {
+      Value thenValue = thenYieldOp.getResults()[i];
+      Value elseValue = elseYieldOp.getResults()[i];
+      if (thenValue.getType() == elseValue.getType()) {
+        thenResults.push_back(thenValue);
+        elseResults.push_back(elseValue);
+        continue;
+      }
+
+      // Type mismatch between then/else yield value. Cast both to a memref type
+      // with a fully dynamic layout map.
+      auto thenMemrefType = thenValue.getType().cast<BaseMemRefType>();
+      auto elseMemrefType = elseValue.getType().cast<BaseMemRefType>();
+      if (thenMemrefType.getMemorySpaceAsInt() !=
+          elseMemrefType.getMemorySpaceAsInt())
+        return op->emitError("inconsistent memory space on then/else branches");
+      rewriter.setInsertionPoint(thenYieldOp);
+      BaseMemRefType memrefType = getMemRefTypeWithFullyDynamicLayout(
+          ifOp.getResultTypes()[i].cast<TensorType>(),
+          thenMemrefType.getMemorySpaceAsInt());
+      thenResults.push_back(rewriter.create<memref::CastOp>(
+          thenYieldOp.getLoc(), memrefType, thenValue));
+      rewriter.setInsertionPoint(elseYieldOp);
+      elseResults.push_back(rewriter.create<memref::CastOp>(
+          elseYieldOp.getLoc(), memrefType, elseValue));
+      insertedCast = true;
+    }
+
+    if (insertedCast) {
+      rewriter.setInsertionPoint(thenYieldOp);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(thenYieldOp, thenResults);
+      rewriter.setInsertionPoint(elseYieldOp);
+      rewriter.replaceOpWithNewOp<scf::YieldOp>(elseYieldOp, elseResults);
+    }
 
     // Create new op.
+    rewriter.setInsertionPoint(ifOp);
+    ValueRange resultsValueRange(thenResults);
+    TypeRange newTypes(resultsValueRange);
     auto newIfOp =
         rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.getCondition(),
                                    /*withElseRegion=*/true);
 
-    // Remove terminators.
-    if (!newIfOp.thenBlock()->empty()) {
-      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
-      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
-    }
-
     // Move over then/else blocks.
     rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
     rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
 
-    // Update scf.yield of new then-block.
-    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
-    rewriter.setInsertionPoint(thenYieldOp);
-    SmallVector<Value> thenYieldValues;
-    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
-      if (operand.get().getType().isa<TensorType>()) {
-        ensureToMemrefOpIsValid(operand.get(),
-                                newTypes[operand.getOperandNumber()]);
-        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
-            operand.get());
-        operand.set(toMemrefOp);
-      }
-    }
-
-    // Update scf.yield of new else-block.
-    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
-    rewriter.setInsertionPoint(elseYieldOp);
-    SmallVector<Value> elseYieldValues;
-    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
-      if (operand.get().getType().isa<TensorType>()) {
-        ensureToMemrefOpIsValid(operand.get(),
-                                newTypes[operand.getOperandNumber()]);
-        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
-            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
-            operand.get());
-        operand.set(toMemrefOp);
-      }
-    }
-
     // Replace op results.
     replaceOpWithBufferizedValues(rewriter, op, newIfOp->getResults());
 
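A hedged sketch of the reconciliation step added above (all names and types are invented, and the actual layout comes from getMemRefTypeWithFullyDynamicLayout): when the two branches yield memrefs with different layouts but the same memory space, both are cast to a memref type with a fully dynamic layout so the branches agree, and the rebuilt scf.if uses that common type.

#dyn = affine_map<(d0)[s0, s1] -> (d0 * s1 + s0)>

// Illustration only: a bufferized scf.if whose branches originally yielded
// memref<5xf32> and memref<5xf32, #dyn>; memref.cast reconciles the layouts.
func.func @reconcile(%cond: i1, %a: memref<5xf32>, %b: memref<5xf32, #dyn>) -> memref<5xf32, #dyn> {
  %r = scf.if %cond -> (memref<5xf32, #dyn>) {
    %ca = memref.cast %a : memref<5xf32> to memref<5xf32, #dyn>
    scf.yield %ca : memref<5xf32, #dyn>
  } else {
    scf.yield %b : memref<5xf32, #dyn>
  }
  return %r : memref<5xf32, #dyn>
}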
@@ -869,6 +843,24 @@ struct YieldOpInterface
     if (!isa<scf::ExecuteRegionOp, scf::IfOp, scf::ForOp, scf::WhileOp>(
             yieldOp->getParentOp()))
       return yieldOp->emitError("unsupported scf::YieldOp parent");
+
+    // TODO: Bufferize scf.yield inside scf.while/scf.for here.
+    // (Currently bufferized together with scf.while/scf.for.)
+    if (isa<scf::ForOp, scf::WhileOp>(yieldOp->getParentOp()))
+      return success();
+
+    SmallVector<Value> newResults;
+    for (const auto &it : llvm::enumerate(yieldOp.getResults())) {
+      Value value = it.value();
+      if (value.getType().isa<TensorType>()) {
+        Value buffer = getBuffer(rewriter, value, options);
+        newResults.push_back(buffer);
+      } else {
+        newResults.push_back(value);
+      }
+    }
+
+    replaceOpWithNewBufferizedOp<scf::YieldOp>(rewriter, op, newResults);
     return success();
   }
 };
@@ -8,6 +8,7 @@
 // CHECK-LABEL: func @buffer_not_deallocated(
 // CHECK-SAME: %[[t:.*]]: tensor<?xf32>
 func.func @buffer_not_deallocated(%t : tensor<?xf32>, %c : i1) -> tensor<?xf32> {
+  // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
   // CHECK: %[[r:.*]] = scf.if %{{.*}} {
   %r = scf.if %c -> tensor<?xf32> {
     // CHECK: %[[some_op:.*]] = "test.some_op"
@@ -20,7 +21,6 @@ func.func @buffer_not_deallocated(%t : tensor<?xf32>, %c : i1) -> tensor<?xf32>
     scf.yield %0 : tensor<?xf32>
   } else {
     // CHECK: } else {
-    // CHECK: %[[m:.*]] = bufferization.to_memref %[[t]]
     // CHECK: %[[cloned:.*]] = bufferization.clone %[[m]]
     // CHECK: scf.yield %[[cloned]]
     scf.yield %t : tensor<?xf32>
@@ -40,8 +40,8 @@ func.func @write_to_alloc_tensor_or_readonly_tensor(%arg0: tensor<i32>,
                                                      %cond: i1, %val: i32)
   -> tensor<i32>
 {
+  // CHECK: %[[arg0_m:.*]] = bufferization.to_memref %[[arg0]]
   // CHECK: %[[r:.*]] = scf.if {{.*}} {
-  // CHECK: %[[arg0_m:.*]] = bufferization.to_memref %[[arg0]]
   // CHECK: %[[clone:.*]] = bufferization.clone %[[arg0_m]]
   // CHECK: scf.yield %[[clone]]
   // CHECK: } else {
@@ -206,9 +206,9 @@ func.func @simple_tensor_test(%t1 : tensor<?xf32>, %f : f32) -> tensor<?xf32> {
 // CHECK-SCF-SAME: %[[t1:.*]]: tensor<?xf32> {bufferization.writable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index
 func.func @simple_scf_if(%t1: tensor<?xf32> {bufferization.writable = true}, %c: i1, %pos: index, %f: f32)
   -> (tensor<?xf32>, index) {
+  // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
   // CHECK-SCF: %[[r:.*]] = scf.if %[[c]] -> (memref<?xf32, #{{.*}}>) {
   %r1, %r2 = scf.if %c -> (tensor<?xf32>, index) {
-    // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
     // CHECK-SCF: scf.yield %[[t1_memref]]
     scf.yield %t1, %pos : tensor<?xf32>, index
   // CHECK-SCF: } else {
@@ -124,11 +124,10 @@ func.func @execute_region_with_conflict(
     scf.yield %f1, %t2, %f1 : f32, tensor<?xf32>, f32
   }
 
-  // CHECK: %[[casted:.*]] = memref.cast %[[alloc]]
   // CHECK: %[[load:.*]] = memref.load %[[m1]]
   %3 = tensor.extract %t1[%idx] : tensor<?xf32>
 
-  // CHECK: return %{{.*}}, %[[casted]], %[[load]] : f32, memref<?xf32, #{{.*}}>, f32
+  // CHECK: return %{{.*}}, %[[alloc]], %[[load]] : f32, memref<?xf32>, f32
   return %0, %1, %3 : f32, tensor<?xf32>, f32
 }
 