forked from OSchip/llvm-project
[spirv] AccessChainOp canonicalization.
Combine chained `spirv::AccessChainOp` operations into one `spirv::AccessChainOp` operation. Closes tensorflow/mlir#198 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/198 from denis0x0D:sandbox/canon_access_chain 0cb87955a85511071143d62637ff939d0dabc2bd PiperOrigin-RevId: 276609345
This commit is contained in:
parent
2b61b7979e
commit
dd2e444325
|
@ -127,6 +127,8 @@ def SPV_AccessChainOp : SPV_Op<"AccessChain", [NoSideEffect]> {
|
|||
|
||||
let builders = [OpBuilder<[{Builder *builder, OperationState &state,
|
||||
Value *basePtr, ArrayRef<Value *> indices}]>];
|
||||
|
||||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
|
|
@ -544,6 +544,41 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
|
|||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
// Combine chained `spirv::AccessChainOp` operations into one
|
||||
// `spirv::AccessChainOp` operation.
|
||||
struct CombineChainedAccessChain
|
||||
: public OpRewritePattern<spirv::AccessChainOp> {
|
||||
using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
|
||||
PatternRewriter &rewriter) const override {
|
||||
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
|
||||
accessChainOp.base_ptr()->getDefiningOp());
|
||||
|
||||
if (!parentAccessChainOp) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
// Combine indices.
|
||||
SmallVector<Value *, 4> indices(parentAccessChainOp.indices());
|
||||
indices.append(accessChainOp.indices().begin(),
|
||||
accessChainOp.indices().end());
|
||||
|
||||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||
accessChainOp, parentAccessChainOp.base_ptr(), indices);
|
||||
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
void spirv::AccessChainOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<CombineChainedAccessChain>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv._address_of
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1976,7 +2011,8 @@ namespace {
|
|||
// | merge block |
|
||||
// +-------------+
|
||||
//
|
||||
struct SelectionOpCanonicalizer : public OpRewritePattern<spirv::SelectionOp> {
|
||||
struct ConvertSelectionOpToSelect
|
||||
: public OpRewritePattern<spirv::SelectionOp> {
|
||||
using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
|
||||
|
||||
PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
|
||||
|
@ -2071,7 +2107,7 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
PatternMatchResult SelectionOpCanonicalizer::canCanonicalizeSelection(
|
||||
PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
|
||||
Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
|
||||
// Each block must consists of 2 operations.
|
||||
if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
|
||||
|
@ -2110,7 +2146,7 @@ PatternMatchResult SelectionOpCanonicalizer::canCanonicalizeSelection(
|
|||
|
||||
void spirv::SelectionOp::getCanonicalizationPatterns(
|
||||
OwningRewritePatternList &results, MLIRContext *context) {
|
||||
results.insert<SelectionOpCanonicalizer>(context);
|
||||
results.insert<ConvertSelectionOpToSelect>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1,5 +1,63 @@
|
|||
// RUN: mlir-opt %s -split-input-file -pass-pipeline='func(canonicalize)' | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.AccsessChain
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @combine_full_access_chain() -> f32 {
|
||||
// CHECK: %[[INDEX:.*]] = spv.constant 0
|
||||
// CHECK-NEXT: %[[VAR:.*]] = spv.Variable
|
||||
// CHECK-NEXT: %[[PTR:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]]
|
||||
// CHECK-NEXT: spv.Load "Function" %[[PTR]]
|
||||
%c0 = spv.constant 0: i32
|
||||
%0 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
|
||||
%1 = spv.AccessChain %0[%c0] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
|
||||
%2 = spv.AccessChain %1[%c0, %c0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
|
||||
%3 = spv.Load "Function" %2 : f32
|
||||
spv.ReturnValue %3 : f32
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @combine_access_chain_multi_use() -> !spv.array<4xf32> {
|
||||
// CHECK: %[[INDEX:.*]] = spv.constant 0
|
||||
// CHECK-NEXT: %[[VAR:.*]] = spv.Variable
|
||||
// CHECK-NEXT: %[[PTR_0:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]]]
|
||||
// CHECK-NEXT: %[[PTR_1:.*]] = spv.AccessChain %[[VAR]][%[[INDEX]], %[[INDEX]], %[[INDEX]]]
|
||||
// CHECK-NEXT: spv.Load "Function" %[[PTR_0]]
|
||||
// CHECK-NEXT: spv.Load "Function" %[[PTR_1]]
|
||||
%c0 = spv.constant 0: i32
|
||||
%0 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
|
||||
%1 = spv.AccessChain %0[%c0] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
|
||||
%2 = spv.AccessChain %1[%c0] : !spv.ptr<!spv.array<4x!spv.array<4xf32>>, Function>
|
||||
%3 = spv.AccessChain %2[%c0] : !spv.ptr<!spv.array<4xf32>, Function>
|
||||
%4 = spv.Load "Function" %2 : !spv.array<4xf32>
|
||||
%5 = spv.Load "Function" %3 : f32
|
||||
spv.ReturnValue %4: !spv.array<4xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @dont_combine_access_chain_without_common_base() -> !spv.array<4xi32> {
|
||||
// CHECK: %[[INDEX:.*]] = spv.constant 1
|
||||
// CHECK-NEXT: %[[VAR_0:.*]] = spv.Variable
|
||||
// CHECK-NEXT: %[[VAR_1:.*]] = spv.Variable
|
||||
// CHECK-NEXT: %[[VAR_0_PTR:.*]] = spv.AccessChain %[[VAR_0]][%[[INDEX]]]
|
||||
// CHECK-NEXT: %[[VAR_1_PTR:.*]] = spv.AccessChain %[[VAR_1]][%[[INDEX]]]
|
||||
// CHECK-NEXT: spv.Load "Function" %[[VAR_0_PTR]]
|
||||
// CHECK-NEXT: spv.Load "Function" %[[VAR_1_PTR]]
|
||||
%c1 = spv.constant 1: i32
|
||||
%0 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
|
||||
%1 = spv.Variable : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
|
||||
%2 = spv.AccessChain %0[%c1] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
|
||||
%3 = spv.AccessChain %1[%c1] : !spv.ptr<!spv.struct<!spv.array<4x!spv.array<4xf32>>, !spv.array<4xi32>>, Function>
|
||||
%4 = spv.Load "Function" %2 : !spv.array<4xi32>
|
||||
%5 = spv.Load "Function" %3 : !spv.array<4xi32>
|
||||
spv.ReturnValue %4 : !spv.array<4xi32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.CompositeExtract
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue