[spirv] AccessChainOp canonicalization.

Combine chained `spirv::AccessChainOp` operations into one
`spirv::AccessChainOp` operation.

Closes tensorflow/mlir#198

COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/198 from denis0x0D:sandbox/canon_access_chain 0cb87955a85511071143d62637ff939d0dabc2bd
PiperOrigin-RevId: 276609345
This commit is contained in:
Denis Khalikov 2019-10-24 18:40:38 -07:00 committed by A. Unique TensorFlower
parent 2b61b7979e
commit dd2e444325
3 changed files with 99 additions and 3 deletions

View File

@ -127,6 +127,8 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
let builders = [OpBuilder<[{Builder *builder, OperationState &state,
Value *basePtr, ArrayRef<Value *> indices}]>];
let hasCanonicalizer = 1;
}
// -----

View File

@ -544,6 +544,41 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
return success();
}
namespace {
// Combine chained `spirv::AccessChainOp` operations into one
// `spirv::AccessChainOp` operation.
struct CombineChainedAccessChain
: public OpRewritePattern<spirv::AccessChainOp> {
using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
PatternRewriter &rewriter) const override {
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
accessChainOp.base_ptr()->getDefiningOp());
if (!parentAccessChainOp) {
return matchFailure();
}
// Combine indices.
SmallVector<Value *, 4> indices(parentAccessChainOp.indices());
indices.append(accessChainOp.indices().begin(),
accessChainOp.indices().end());
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
accessChainOp, parentAccessChainOp.base_ptr(), indices);
return matchSuccess();
}
};
} // namespace
void spirv::AccessChainOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<CombineChainedAccessChain>(context);
}
//===----------------------------------------------------------------------===//
// spv._address_of
//===----------------------------------------------------------------------===//
@ -1976,7 +2011,8 @@ namespace {
// | merge block |
// +-------------+
//
struct SelectionOpCanonicalizer : public OpRewritePattern<spirv::SelectionOp> {
struct ConvertSelectionOpToSelect
: public OpRewritePattern<spirv::SelectionOp> {
using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
@ -2071,7 +2107,7 @@ private:
}
};
PatternMatchResult SelectionOpCanonicalizer::canCanonicalizeSelection(
PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
// Each block must consists of 2 operations.
if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
@ -2110,7 +2146,7 @@ PatternMatchResult SelectionOpCanonicalizer::canCanonicalizeSelection(
void spirv::SelectionOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<SelectionOpCanonicalizer>(context);
results.insert<ConvertSelectionOpToSelect>(context);
}
//===----------------------------------------------------------------------===//

View File

@ -1,5 +1,63 @@
// RUN: mlir-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
//===----------------------------------------------------------------------===//
// spv.AccsessChain
//===----------------------------------------------------------------------===//
func @combine_full_access_chain() -> f32 {
// CHECK: %[[INDEX:.*]] = spv.constant 0
// CHECK-NEXT: %[[VAR:.*]] = spv.Variable
// CHECK-NEXT: %[[PTR:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]]
// CHECK-NEXT: spv.Load "Function" %[[PTR]]
%c0 = spv.constant 0: i32
%0 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
%1 = spv.AccessChain %0[%c0] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
%2 = spv.AccessChain %1[%c0, %c0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
%3 = spv.Load "Function" %2 : f32
spv.ReturnValue %3 : f32
}
// -----
func @combine_access_chain_multi_use() -> !spv.array<4xf32> {
// CHECK: %[[INDEX:.*]] = spv.constant 0
// CHECK-NEXT: %[[VAR:.*]] = spv.Variable
// CHECK-NEXT: %[[PTR_0:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]]]
// CHECK-NEXT: %[[PTR_1:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]]
// CHECK-NEXT: spv.Load "Function" %[[PTR_0]]
// CHECK-NEXT: spv.Load "Function" %[[PTR_1]]
%c0 = spv.constant 0: i32
%0 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
%1 = spv.AccessChain %0[%c0] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
%2 = spv.AccessChain %1[%c0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
%3 = spv.AccessChain %2[%c0] : !spv.ptr<!spv.array<4xf32>, Function>
%4 = spv.Load "Function" %2 : !spv.array<4xf32>
%5 = spv.Load "Function" %3 : f32
spv.ReturnValue %4: !spv.array<4xf32>
}
// -----
func @dont_combine_access_chain_without_common_base() -> !spv.array<4xi32> {
// CHECK: %[[INDEX:.*]] = spv.constant 1
// CHECK-NEXT: %[[VAR_0:.*]] = spv.Variable
// CHECK-NEXT: %[[VAR_1:.*]] = spv.Variable
// CHECK-NEXT: %[[VAR_0_PTR:.*]] = spv.AccessChain %[[VAR_0]][%[[INDEX]]]
// CHECK-NEXT: %[[VAR_1_PTR:.*]] = spv.AccessChain %[[VAR_1]][%[[INDEX]]]
// CHECK-NEXT: spv.Load "Function" %[[VAR_0_PTR]]
// CHECK-NEXT: spv.Load "Function" %[[VAR_1_PTR]]
%c1 = spv.constant 1: i32
%0 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
%1 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
%2 = spv.AccessChain %0[%c1] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
%3 = spv.AccessChain %1[%c1] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
%4 = spv.Load "Function" %2 : !spv.array<4xi32>
%5 = spv.Load "Function" %3 : !spv.array<4xi32>
spv.ReturnValue %4 : !spv.array<4xi32>
}
// -----
//===----------------------------------------------------------------------===//
// spv.CompositeExtract
//===----------------------------------------------------------------------===//