[mlir][linalg][bufferize] Reimplementation of scf.if bufferization

Instead of modifying the existing scf.if op, create a new op with memref OpOperands/OpResults and delete the old op.

New allocations and other memrefs can now be yielded from the op. This functionality is deactivated by default and guarded by the AssertDestinationPassingStyle post-analysis step.
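
As a rough illustration (a sketch with hypothetical values %t1/%t2 and layout #map, not taken from the test suite), an scf.if on tensors

  %r = scf.if %c -> (tensor<?xf32>) {
    scf.yield %t1 : tensor<?xf32>
  } else {
    scf.yield %t2 : tensor<?xf32>
  }

now bufferizes to a fresh scf.if on memrefs whose tensor yields are wrapped in bufferization.to_memref:

  %r = scf.if %c -> (memref<?xf32, #map>) {
    %m1 = bufferization.to_memref %t1 : memref<?xf32, #map>
    scf.yield %m1 : memref<?xf32, #map>
  } else {
    %m2 = bufferization.to_memref %t2 : memref<?xf32, #map>
    scf.yield %m2 : memref<?xf32, #map>
  }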

Differential Revision: https://reviews.llvm.org/D115491
Matthias Springer 2021-12-15 18:32:13 +09:00
parent a4830d14ed
commit a5927737da
8 changed files with 167 additions and 46 deletions


@@ -461,8 +461,6 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
   // Certain buffers are not writeable:
   //   1. A function bbArg that is not inplaceable or
   //   2. A constant op.
-  assert(!aliasesNonWritableBuffer(opResult, aliasInfo, state) &&
-         "expected that opResult does not alias non-writable buffer");
   bool nonWritable =
       aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
   if (!nonWritable)


@@ -131,27 +131,74 @@ struct IfOpInterface
                           BufferizationState &state) const {
     auto ifOp = cast<scf::IfOp>(op);

-    // Bufferize then/else blocks.
-    if (failed(comprehensive_bufferize::bufferize(ifOp.thenBlock(), state)))
-      return failure();
-    if (failed(comprehensive_bufferize::bufferize(ifOp.elseBlock(), state)))
-      return failure();
+    // Use IRRewriter instead of OpBuilder because it has additional helper
+    // functions.
+    IRRewriter rewriter(op->getContext());
+    rewriter.setInsertionPoint(ifOp);

-    for (OpResult opResult : ifOp->getResults()) {
-      if (!opResult.getType().isa<TensorType>())
-        continue;
-      // TODO: Atm we bail on unranked TensorType because we don't know how to
-      // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
-      assert(opResult.getType().isa<RankedTensorType>() &&
-             "unsupported unranked tensor");
-      Value resultBuffer = state.getResultBuffer(opResult);
-      if (!resultBuffer)
-        return failure();
-      state.mapBuffer(opResult, resultBuffer);
-    }
+    // Compute new types of the bufferized scf.if op.
+    SmallVector<Type> newTypes;
+    for (Type returnType : ifOp->getResultTypes()) {
+      if (returnType.isa<TensorType>()) {
+        assert(returnType.isa<RankedTensorType>() &&
+               "unsupported unranked tensor");
+        newTypes.push_back(
+            getDynamicMemRefType(returnType.cast<RankedTensorType>()));
+      } else {
+        newTypes.push_back(returnType);
+      }
+    }
+
+    // Create new op.
+    auto newIfOp =
+        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.condition(),
+                                   /*withElseRegion=*/true);
+
+    // Remove terminators.
+    if (!newIfOp.thenBlock()->empty()) {
+      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
+      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
+    }
+
+    // Move over then/else blocks.
+    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
+    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
+
+    // Update scf.yield of new then-block.
+    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
+    rewriter.setInsertionPoint(thenYieldOp);
+    SmallVector<Value> thenYieldValues;
+    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
+      if (operand.get().getType().isa<TensorType>()) {
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
+            operand.get());
+        operand.set(toMemrefOp);
+      }
+    }
+
+    // Update scf.yield of new else-block.
+    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
+    rewriter.setInsertionPoint(elseYieldOp);
+    SmallVector<Value> elseYieldValues;
+    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
+      if (operand.get().getType().isa<TensorType>()) {
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
+            operand.get());
+        operand.set(toMemrefOp);
+      }
+    }
+
+    // Replace op results.
+    state.replaceOp(op, newIfOp->getResults());
+
+    // Bufferize then/else blocks.
+    if (failed(comprehensive_bufferize::bufferize(newIfOp.thenBlock(), state)))
+      return failure();
+    if (failed(comprehensive_bufferize::bufferize(newIfOp.elseBlock(), state)))
+      return failure();
+
     return success();
   }
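
To make the above concrete: for a result type tensor<?xf32>, getDynamicMemRefType produces a memref type with fully dynamic layout, and each tensor yield is routed through bufferization.to_memref. A minimal sketch of one rewritten yield (value names and #map are illustrative only):

  // before: scf.yield %t, %i : tensor<?xf32>, index
  %m = bufferization.to_memref %t : memref<?xf32, #map>
  scf.yield %m, %i : memref<?xf32, #map>, index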
@@ -293,33 +340,65 @@ struct ForOpInterface
   }
 };

 // TODO: Evolve toward matching ReturnLike ops. Check for aliasing values that
 // do not bufferize inplace. (Requires a few more changes for ConstantOp,
 // InitTensorOp, CallOp.)
 LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
     AssertDestinationPassingStyle::run(Operation *op, BufferizationState &state,
                                        BufferizationAliasInfo &aliasInfo,
                                        SmallVector<Operation *> &newOps) {
   LogicalResult status = success();
   op->walk([&](scf::YieldOp yieldOp) {
-    auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
-    if (!forOp)
-      return WalkResult::advance();
-
-    for (OpOperand &operand : yieldOp->getOpOperands()) {
-      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
-      if (!tensorType)
-        continue;
-
-      OpOperand &forOperand = forOp.getOpOperandForResult(
-          forOp->getResult(operand.getOperandNumber()));
-      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
-        // TODO: this could get resolved with copies but it can also turn into
-        // swaps so we need to be careful about order of copies.
-        status =
-            yieldOp->emitError()
-            << "Yield operand #" << operand.getOperandNumber()
-            << " does not bufferize to an equivalent buffer to the matching"
-            << " enclosing scf::for operand";
-        return WalkResult::interrupt();
+    if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
+      for (OpOperand &operand : yieldOp->getOpOperands()) {
+        auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+        if (!tensorType)
+          continue;
+
+        OpOperand &forOperand = forOp.getOpOperandForResult(
+            forOp->getResult(operand.getOperandNumber()));
+        auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+        if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
+          // TODO: this could get resolved with copies but it can also turn into
+          // swaps so we need to be careful about order of copies.
+          status =
+              yieldOp->emitError()
+              << "Yield operand #" << operand.getOperandNumber()
+              << " does not bufferize to an equivalent buffer to the matching"
+              << " enclosing scf::for operand";
+          return WalkResult::interrupt();
+        }
+      }
+    }
+
+    if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
+      // IfOps are in destination passing style if all yielded tensors are
+      // a value or equivalent to a value that is defined outside of the IfOp.
+      for (OpOperand &operand : yieldOp->getOpOperands()) {
+        auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+        if (!tensorType)
+          continue;
+
+        bool foundOutsideEquivalent = false;
+        aliasInfo.applyOnEquivalenceClass(operand.get(), [&](Value value) {
+          Operation *valueOp = value.getDefiningOp();
+          if (value.isa<BlockArgument>())
+            valueOp = value.cast<BlockArgument>().getOwner()->getParentOp();
+          bool inThenBlock = ifOp.thenBlock()->findAncestorOpInBlock(*valueOp);
+          bool inElseBlock = ifOp.elseBlock()->findAncestorOpInBlock(*valueOp);
+          if (!inThenBlock && !inElseBlock)
+            foundOutsideEquivalent = true;
+        });
+
+        if (!foundOutsideEquivalent) {
+          status = yieldOp->emitError()
+                   << "Yield operand #" << operand.getOperandNumber()
+                   << " does not bufferize to a buffer that is equivalent to a"
+                   << " buffer defined outside of the scf::if op";
+          return WalkResult::interrupt();
+        }
+      }
+    }
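
For contrast with the failing tests further down, a sketch of an scf.if that satisfies this check (hypothetical values; %t1, %f, and %idx are defined outside the op, and the tensor.insert may bufferize in-place into %t1's buffer, making both yields equivalent to an outside value):

  %r = scf.if %cond -> (tensor<?xf32>) {
    scf.yield %t1 : tensor<?xf32>
  } else {
    %t3 = tensor.insert %f into %t1[%idx] : tensor<?xf32>
    scf.yield %t3 : tensor<?xf32>
  }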


@@ -97,7 +97,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
   // TODO: Find a way to enable this step automatically when bufferizing tensor
   // dialect ops.
   options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
-  options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
+  if (!allowReturnMemref)
+    options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();

   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);


@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=test-analysis-only -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref" -split-input-file | FileCheck %s

 // Run fuzzer with different seeds.
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref analysis-fuzzer-seed=23" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref analysis-fuzzer-seed=59" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref analysis-fuzzer-seed=91" -split-input-file -o /dev/null

 //===----------------------------------------------------------------------===//
 // Simple cases


@@ -38,12 +38,12 @@ func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
 func @scf_if_not_equivalent(
     %cond: i1, %t1: tensor<?xf32> {linalg.inplaceable = true},
     %idx: index) -> tensor<?xf32> {
-  // expected-error @+1 {{result buffer is ambiguous}}
   %r = scf.if %cond -> (tensor<?xf32>) {
     scf.yield %t1 : tensor<?xf32>
   } else {
     // This buffer aliases, but is not equivalent.
     %t2 = tensor.extract_slice %t1 [%idx] [%idx] [1] : tensor<?xf32> to tensor<?xf32>
+    // expected-error @+1 {{Yield operand #0 does not bufferize to a buffer that is equivalent to a buffer defined outside of the scf::if op}}
     scf.yield %t2 : tensor<?xf32>
   }
   return %r : tensor<?xf32>

@@ -127,9 +127,9 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})

 // -----

+// expected-error @+1 {{memref return type is unsupported}}
 func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
 {
-  // expected-error @+1 {{result buffer is ambiguous}}
   %r = scf.if %b -> (tensor<4xf32>) {
     scf.yield %A : tensor<4xf32>
   } else {


@@ -194,3 +194,28 @@ func @simple_scf_for(
 //       CHECK-SCF: return %[[scf_for_tensor]]
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-SCF-LABEL: func @simple_scf_if(
+//  CHECK-SCF-SAME:     %[[t1:.*]]: tensor<?xf32> {linalg.inplaceable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index
+func @simple_scf_if(%t1: tensor<?xf32> {linalg.inplaceable = true}, %c: i1, %pos: index, %f: f32)
+  -> (tensor<?xf32>, index) {
+  // CHECK-SCF: %[[r:.*]] = scf.if %[[c]] -> (memref<?xf32, #{{.*}}>) {
+  %r1, %r2 = scf.if %c -> (tensor<?xf32>, index) {
+    // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
+    // CHECK-SCF: scf.yield %[[t1_memref]]
+    scf.yield %t1, %pos : tensor<?xf32>, index
+  // CHECK-SCF: } else {
+  } else {
+    // CHECK-SCF: %[[insert:.*]] = tensor.insert %{{.*}} into %[[t1]][{{.*}}]
+    // CHECK-SCF: %[[insert_memref:.*]] = bufferization.to_memref %[[insert]]
+    %1 = tensor.insert %f into %t1[%pos] : tensor<?xf32>
+    // CHECK-SCF: scf.yield %[[insert_memref]]
+    scf.yield %1, %pos : tensor<?xf32>, index
+  }
+  // CHECK-SCF: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+  // CHECK-SCF: return %[[r_tensor]], %[[pos]]
+  return %r1, %r2 : tensor<?xf32>, index
+}


@@ -921,6 +921,22 @@ func @scf_if_inside_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},

 // -----

+// CHECK-LABEL: func @scf_if_non_equiv_yields(
+//  CHECK-SAME:     %[[cond:.*]]: i1, %[[A:.*]]: memref<{{.*}}>, %[[B:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
+func @scf_if_non_equiv_yields(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
+{
+  // CHECK: %[[r:.*]] = select %[[cond]], %[[A]], %[[B]]
+  %r = scf.if %b -> (tensor<4xf32>) {
+    scf.yield %A : tensor<4xf32>
+  } else {
+    scf.yield %B : tensor<4xf32>
+  }
+  // CHECK: return %[[r]]
+  return %r: tensor<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_op
 //  CHECK-SAME:     %[[t1:.*]]: memref<?xf32, {{.*}}>, %[[s:.*]]: f32, %[[i:.*]]: index
 func @insert_op(%t1 : tensor<?xf32> {linalg.inplaceable = true},


@@ -101,6 +101,8 @@ void TestComprehensiveFunctionBufferize::runOnFunction() {
   // TODO: Find a way to enable this step automatically when bufferizing
   // tensor dialect ops.
   options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+  if (!allowReturnMemref)
+    options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();

   options.allowReturnMemref = allowReturnMemref;
   options.allowUnknownOps = allowUnknownOps;
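
With the step now conditional, a test could opt out of destination-passing-style enforcement through the pass option; a sketch of a RUN line (assuming the test pass is registered as -test-comprehensive-function-bufferize, which this diff does not show):

  // RUN: mlir-opt %s -test-comprehensive-function-bufferize="allow-return-memref" -split-input-file | FileCheck %s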