[mlir][linalg][bufferize] Reimplementation of scf.if bufferization
Instead of modifying the existing scf.if op in place, create a new scf.if op with memref OpOperands/OpResults and delete the old op. New allocations and other memrefs can now be yielded from the op. This functionality is deactivated by default; when returning memrefs is not allowed, the AssertDestinationPassingStyle post-analysis step guards against it.

Differential Revision: https://reviews.llvm.org/D115491
commit a5927737da (parent a4830d14ed)
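For orientation, here is a hand-written before/after sketch of the new scheme (it mirrors the `@simple_scf_if` test added below; the value names and `#map` are placeholders, not output copied from the pass — `#map` stands for the dynamic layout map the pass produces):

  // Before: the scf.if yields tensors.
  %r = scf.if %c -> (tensor<?xf32>) {
    scf.yield %t1 : tensor<?xf32>
  } else {
    scf.yield %t2 : tensor<?xf32>
  }

  // After: a new scf.if with memref results replaces the old op. Each yielded
  // tensor is wrapped in bufferization.to_memref right before the yield, and
  // users of the old result read it back through bufferization.to_tensor.
  %r = scf.if %c -> (memref<?xf32, #map>) {
    %m1 = bufferization.to_memref %t1 : memref<?xf32, #map>
    scf.yield %m1 : memref<?xf32, #map>
  } else {
    %m2 = bufferization.to_memref %t2 : memref<?xf32, #map>
    scf.yield %m2 : memref<?xf32, #map>
  }
  %r_tensor = bufferization.to_tensor %r : memref<?xf32, #map>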
@@ -461,8 +461,6 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &opOperand, OpResult opResult,
   // Certain buffers are not writeable:
   // 1. A function bbArg that is not inplaceable or
   // 2. A constant op.
-  assert(!aliasesNonWritableBuffer(opResult, aliasInfo, state) &&
-         "expected that opResult does not alias non-writable buffer");
   bool nonWritable =
       aliasesNonWritableBuffer(opOperand.get(), aliasInfo, state);
   if (!nonWritable)
@@ -131,27 +131,74 @@ struct IfOpInterface
                           BufferizationState &state) const {
     auto ifOp = cast<scf::IfOp>(op);

-    // Bufferize then/else blocks.
-    if (failed(comprehensive_bufferize::bufferize(ifOp.thenBlock(), state)))
-      return failure();
-    if (failed(comprehensive_bufferize::bufferize(ifOp.elseBlock(), state)))
-      return failure();
+    // Use IRRewriter instead of OpBuilder because it has additional helper
+    // functions.
+    IRRewriter rewriter(op->getContext());
+    rewriter.setInsertionPoint(ifOp);

-    for (OpResult opResult : ifOp->getResults()) {
-      if (!opResult.getType().isa<TensorType>())
-        continue;
-      // TODO: Atm we bail on unranked TensorType because we don't know how to
-      // alloc an UnrankedMemRefType + its underlying ranked MemRefType.
-      assert(opResult.getType().isa<RankedTensorType>() &&
-             "unsupported unranked tensor");
-
-      Value resultBuffer = state.getResultBuffer(opResult);
-      if (!resultBuffer)
-        return failure();
-
-      state.mapBuffer(opResult, resultBuffer);
+    // Compute new types of the bufferized scf.if op.
+    SmallVector<Type> newTypes;
+    for (Type returnType : ifOp->getResultTypes()) {
+      if (returnType.isa<TensorType>()) {
+        assert(returnType.isa<RankedTensorType>() &&
+               "unsupported unranked tensor");
+        newTypes.push_back(
+            getDynamicMemRefType(returnType.cast<RankedTensorType>()));
+      } else {
+        newTypes.push_back(returnType);
+      }
     }
+
+    // Create new op.
+    auto newIfOp =
+        rewriter.create<scf::IfOp>(ifOp.getLoc(), newTypes, ifOp.condition(),
+                                   /*withElseRegion=*/true);
+
+    // Remove terminators.
+    if (!newIfOp.thenBlock()->empty()) {
+      rewriter.eraseOp(newIfOp.thenBlock()->getTerminator());
+      rewriter.eraseOp(newIfOp.elseBlock()->getTerminator());
+    }
+
+    // Move over then/else blocks.
+    rewriter.mergeBlocks(ifOp.thenBlock(), newIfOp.thenBlock());
+    rewriter.mergeBlocks(ifOp.elseBlock(), newIfOp.elseBlock());
+
+    // Update scf.yield of new then-block.
+    auto thenYieldOp = cast<scf::YieldOp>(newIfOp.thenBlock()->getTerminator());
+    rewriter.setInsertionPoint(thenYieldOp);
+    SmallVector<Value> thenYieldValues;
+    for (OpOperand &operand : thenYieldOp->getOpOperands()) {
+      if (operand.get().getType().isa<TensorType>()) {
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
+            operand.get());
+        operand.set(toMemrefOp);
+      }
+    }
+
+    // Update scf.yield of new else-block.
+    auto elseYieldOp = cast<scf::YieldOp>(newIfOp.elseBlock()->getTerminator());
+    rewriter.setInsertionPoint(elseYieldOp);
+    SmallVector<Value> elseYieldValues;
+    for (OpOperand &operand : elseYieldOp->getOpOperands()) {
+      if (operand.get().getType().isa<TensorType>()) {
+        Value toMemrefOp = rewriter.create<bufferization::ToMemrefOp>(
+            operand.get().getLoc(), newTypes[operand.getOperandNumber()],
+            operand.get());
+        operand.set(toMemrefOp);
+      }
+    }
+
+    // Replace op results.
+    state.replaceOp(op, newIfOp->getResults());
+
+    // Bufferize then/else blocks.
+    if (failed(comprehensive_bufferize::bufferize(newIfOp.thenBlock(), state)))
+      return failure();
+    if (failed(comprehensive_bufferize::bufferize(newIfOp.elseBlock(), state)))
+      return failure();

     return success();
   }
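Because the new scf.if has memref results, a branch may now yield a buffer that does not alias any tensor defined outside the op, for example a fresh allocation. A hypothetical sketch of such IR (not one of this commit's test cases; the identity layout here stands in for the fully dynamic layout the pass actually uses via getDynamicMemRefType, and %t1, %c, %sz are placeholders):

  %r = scf.if %c -> (memref<?xf32>) {
    %m = bufferization.to_memref %t1 : memref<?xf32>
    scf.yield %m : memref<?xf32>
  } else {
    // A newly allocated buffer can be yielded from the op.
    %alloc = memref.alloc(%sz) : memref<?xf32>
    scf.yield %alloc : memref<?xf32>
  }

Such a yield is only accepted when memref results may escape the op (allow-return-memref); otherwise the AssertDestinationPassingStyle check below flags it.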
@@ -293,33 +340,65 @@ struct ForOpInterface
   }
 };

 // TODO: Evolve toward matching ReturnLike ops. Check for aliasing values that
 // do not bufferize inplace. (Requires a few more changes for ConstantOp,
 // InitTensorOp, CallOp.)
 LogicalResult mlir::linalg::comprehensive_bufferize::scf_ext::
     AssertDestinationPassingStyle::run(Operation *op, BufferizationState &state,
                                        BufferizationAliasInfo &aliasInfo,
                                        SmallVector<Operation *> &newOps) {
   LogicalResult status = success();
   op->walk([&](scf::YieldOp yieldOp) {
-    auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
-    if (!forOp)
-      return WalkResult::advance();
-
-    for (OpOperand &operand : yieldOp->getOpOperands()) {
-      auto tensorType = operand.get().getType().dyn_cast<TensorType>();
-      if (!tensorType)
-        continue;
-
-      OpOperand &forOperand = forOp.getOpOperandForResult(
-          forOp->getResult(operand.getOperandNumber()));
-      auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
-      if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
-        // TODO: this could get resolved with copies but it can also turn into
-        // swaps so we need to be careful about order of copies.
-        status =
-            yieldOp->emitError()
-            << "Yield operand #" << operand.getOperandNumber()
-            << " does not bufferize to an equivalent buffer to the matching"
-            << " enclosing scf::for operand";
-        return WalkResult::interrupt();
+    if (auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp())) {
+      for (OpOperand &operand : yieldOp->getOpOperands()) {
+        auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+        if (!tensorType)
+          continue;
+
+        OpOperand &forOperand = forOp.getOpOperandForResult(
+            forOp->getResult(operand.getOperandNumber()));
+        auto bbArg = forOp.getRegionIterArgForOpOperand(forOperand);
+        if (!aliasInfo.areEquivalentBufferizedValues(operand.get(), bbArg)) {
+          // TODO: this could get resolved with copies but it can also turn into
+          // swaps so we need to be careful about order of copies.
+          status =
+              yieldOp->emitError()
+              << "Yield operand #" << operand.getOperandNumber()
+              << " does not bufferize to an equivalent buffer to the matching"
+              << " enclosing scf::for operand";
+          return WalkResult::interrupt();
+        }
+      }
+    }
+
+    if (auto ifOp = dyn_cast<scf::IfOp>(yieldOp->getParentOp())) {
+      // IfOps are in destination passing style if all yielded tensors are
+      // a value or equivalent to a value that is defined outside of the IfOp.
+      for (OpOperand &operand : yieldOp->getOpOperands()) {
+        auto tensorType = operand.get().getType().dyn_cast<TensorType>();
+        if (!tensorType)
+          continue;
+
+        bool foundOutsideEquivalent = false;
+        aliasInfo.applyOnEquivalenceClass(operand.get(), [&](Value value) {
+          Operation *valueOp = value.getDefiningOp();
+          if (value.isa<BlockArgument>())
+            valueOp = value.cast<BlockArgument>().getOwner()->getParentOp();
+
+          bool inThenBlock = ifOp.thenBlock()->findAncestorOpInBlock(*valueOp);
+          bool inElseBlock = ifOp.elseBlock()->findAncestorOpInBlock(*valueOp);
+
+          if (!inThenBlock && !inElseBlock)
+            foundOutsideEquivalent = true;
+        });
+
+        if (!foundOutsideEquivalent) {
+          status = yieldOp->emitError()
+                   << "Yield operand #" << operand.getOperandNumber()
+                   << " does not bufferize to a buffer that is equivalent to a"
+                   << " buffer defined outside of the scf::if op";
+          return WalkResult::interrupt();
+        }
       }
     }
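As a hypothetical illustration of the scf.for branch of this check (not taken from the tests in this diff; %A, %B, and the loop bounds are placeholders), a yield that is not equivalent to the matching iteration argument is rejected:

  %0 = scf.for %iv = %lb to %ub step %c1 iter_args(%arg = %A) -> (tensor<?xf32>) {
    // %B does not bufferize to a buffer equivalent to %arg, so the
    // post-analysis step emits: "Yield operand #0 does not bufferize to an
    // equivalent buffer to the matching enclosing scf::for operand".
    scf.yield %B : tensor<?xf32>
  }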
@@ -97,7 +97,8 @@ void LinalgComprehensiveModuleBufferize::runOnOperation() {
   // TODO: Find a way to enable this step automatically when bufferizing tensor
   // dialect ops.
   options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
-  options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();
+  if (!allowReturnMemref)
+    options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();

   ModuleOp moduleOp = getOperation();
   applyEnablingTransformations(moduleOp);
@@ -1,9 +1,9 @@
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize=test-analysis-only -split-input-file | FileCheck %s
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref" -split-input-file | FileCheck %s

 // Run fuzzer with different seeds.
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=23" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=59" -split-input-file -o /dev/null
-// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only analysis-fuzzer-seed=91" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref analysis-fuzzer-seed=23" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref analysis-fuzzer-seed=59" -split-input-file -o /dev/null
+// RUN: mlir-opt %s -linalg-comprehensive-module-bufferize="test-analysis-only allow-return-memref analysis-fuzzer-seed=91" -split-input-file -o /dev/null

 //===----------------------------------------------------------------------===//
 // Simple cases
@@ -38,12 +38,12 @@ func @swappy(%cond1 : i1, %cond2 : i1, %t1 : tensor<f32>, %t2 : tensor<f32>)
 func @scf_if_not_equivalent(
     %cond: i1, %t1: tensor<?xf32> {linalg.inplaceable = true},
     %idx: index) -> tensor<?xf32> {
-  // expected-error @+1 {{result buffer is ambiguous}}
   %r = scf.if %cond -> (tensor<?xf32>) {
     scf.yield %t1 : tensor<?xf32>
   } else {
     // This buffer aliases, but is not equivalent.
     %t2 = tensor.extract_slice %t1 [%idx] [%idx] [1] : tensor<?xf32> to tensor<?xf32>
+    // expected-error @+1 {{Yield operand #0 does not bufferize to a buffer that is equivalent to a buffer defined outside of the scf::if op}}
     scf.yield %t2 : tensor<?xf32>
   }
   return %r : tensor<?xf32>
@@ -127,9 +127,9 @@ func @extract_slice_fun(%A : tensor<?xf32> {linalg.inplaceable = true})

 // -----

+// expected-error @+1 {{memref return type is unsupported}}
 func @scf_yield(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
 {
-  // expected-error @+1 {{result buffer is ambiguous}}
   %r = scf.if %b -> (tensor<4xf32>) {
     scf.yield %A : tensor<4xf32>
   } else {
@@ -194,3 +194,28 @@ func @simple_scf_for(
 // CHECK-SCF: return %[[scf_for_tensor]]
   return %0 : tensor<?xf32>
 }
+
+// -----
+
+// CHECK-SCF-LABEL: func @simple_scf_if(
+//  CHECK-SCF-SAME: %[[t1:.*]]: tensor<?xf32> {linalg.inplaceable = true}, %[[c:.*]]: i1, %[[pos:.*]]: index
+func @simple_scf_if(%t1: tensor<?xf32> {linalg.inplaceable = true}, %c: i1, %pos: index, %f: f32)
+  -> (tensor<?xf32>, index) {
+  // CHECK-SCF: %[[r:.*]] = scf.if %[[c]] -> (memref<?xf32, #{{.*}}>) {
+  %r1, %r2 = scf.if %c -> (tensor<?xf32>, index) {
+    // CHECK-SCF: %[[t1_memref:.*]] = bufferization.to_memref %[[t1]]
+    // CHECK-SCF: scf.yield %[[t1_memref]]
+    scf.yield %t1, %pos : tensor<?xf32>, index
+  // CHECK-SCF: } else {
+  } else {
+    // CHECK-SCF: %[[insert:.*]] = tensor.insert %{{.*}} into %[[t1]][{{.*}}]
+    // CHECK-SCF: %[[insert_memref:.*]] = bufferization.to_memref %[[insert]]
+    %1 = tensor.insert %f into %t1[%pos] : tensor<?xf32>
+    // CHECK-SCF: scf.yield %[[insert_memref]]
+    scf.yield %1, %pos : tensor<?xf32>, index
+  }
+
+  // CHECK-SCF: %[[r_tensor:.*]] = bufferization.to_tensor %[[r]]
+  // CHECK-SCF: return %[[r_tensor]], %[[pos]]
+  return %r1, %r2 : tensor<?xf32>, index
+}
@@ -921,6 +921,22 @@ func @scf_if_inside_scf_for(%t1: tensor<?xf32> {linalg.inplaceable = true},

 // -----

+// CHECK-LABEL: func @scf_if_non_equiv_yields(
+//  CHECK-SAME: %[[cond:.*]]: i1, %[[A:.*]]: memref<{{.*}}>, %[[B:.*]]: memref<{{.*}}>) -> memref<{{.*}}>
+func @scf_if_non_equiv_yields(%b : i1, %A : tensor<4xf32>, %B : tensor<4xf32>) -> tensor<4xf32>
+{
+  // CHECK: %[[r:.*]] = select %[[cond]], %[[A]], %[[B]]
+  %r = scf.if %b -> (tensor<4xf32>) {
+    scf.yield %A : tensor<4xf32>
+  } else {
+    scf.yield %B : tensor<4xf32>
+  }
+  // CHECK: return %[[r]]
+  return %r: tensor<4xf32>
+}
+
+// -----
+
 // CHECK-LABEL: func @insert_op
 //  CHECK-SAME: %[[t1:.*]]: memref<?xf32, {{.*}}>, %[[s:.*]]: f32, %[[i:.*]]: index
 func @insert_op(%t1 : tensor<?xf32> {linalg.inplaceable = true},
@@ -101,6 +101,8 @@ void TestComprehensiveFunctionBufferize::runOnFunction() {
   // TODO: Find a way to enable this step automatically when bufferizing
   // tensor dialect ops.
   options.addPostAnalysisStep<tensor_ext::InplaceInsertSliceOpAnalysis>();
+  if (!allowReturnMemref)
+    options.addPostAnalysisStep<scf_ext::AssertDestinationPassingStyle>();

   options.allowReturnMemref = allowReturnMemref;
   options.allowUnknownOps = allowUnknownOps;