From 6c77e59bfd28e7195754cbf8cc32c6cce90de6b6 Mon Sep 17 00:00:00 2001 From: Denis Khalikov Date: Mon, 18 Nov 2019 12:36:16 -0800 Subject: [PATCH] [spirv] Add a canonicalizer for BitcastOp. Convert chained `spirv::BitcastOp` operations into one `spirv::BitcastOp` operation. Closes tensorflow/mlir#238 COPYBARA_INTEGRATE_REVIEW=https://github.com/tensorflow/mlir/pull/238 from denis0x0D:sandbox/canon_bitcast 4352ed4f81b959ec92f849c599e733b62a99c010 PiperOrigin-RevId: 281129234 --- .../mlir/Dialect/SPIRV/SPIRVCastOps.td | 2 + mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 37 +++++++++++++++++-- mlir/test/Dialect/SPIRV/canonicalize.mlir | 28 ++++++++++++++ 3 files changed, 63 insertions(+), 4 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td index 245a2248948b..1798b9d5b159 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVCastOps.td @@ -98,6 +98,8 @@ def SPV_BitcastOp : SPV_Op<"Bitcast", [NoSideEffect]> { let parser = [{ return mlir::impl::parseCastOp(parser, result); }]; let printer = [{ mlir::impl::printCastOp(this->getOperation(), p); }]; + + let hasCanonicalizer = 1; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 5cda90756885..8964963cb0b9 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -652,8 +652,8 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) { namespace { -// Combine chained `spirv::AccessChainOp` operations into one -// `spirv::AccessChainOp` operation. +/// Combines chained `spirv::AccessChainOp` operations into one +/// `spirv::AccessChainOp` operation. struct CombineChainedAccessChain : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -678,7 +678,7 @@ struct CombineChainedAccessChain return matchSuccess(); } }; -} // namespace +} // end anonymous namespace void spirv::AccessChainOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { @@ -771,6 +771,35 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) { return success(); } +namespace { + +/// Converts chained `spirv::BitcastOp` operations into one +/// `spirv::BitcastOp` operation. +struct ConvertChainedBitcast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(spirv::BitcastOp bitcastOp, + PatternRewriter &rewriter) const override { + auto parentBitcastOp = dyn_cast_or_null( + bitcastOp.operand()->getDefiningOp()); + + if (!parentBitcastOp) { + return matchFailure(); + } + + rewriter.replaceOpWithNewOp( + /*valuesToRemoveIfDead=*/{parentBitcastOp.result()}, bitcastOp, + bitcastOp.result()->getType(), parentBitcastOp.operand()); + return matchSuccess(); + } +}; +} // end anonymous namespace + +void spirv::BitcastOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // spv.BitFieldInsert //===----------------------------------------------------------------------===// @@ -2278,7 +2307,7 @@ PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection( return matchSuccess(); } -} // namespace +} // end anonymous namespace void spirv::SelectionOp::getCanonicalizationPatterns( OwningRewritePatternList &results, MLIRContext *context) { diff --git a/mlir/test/Dialect/SPIRV/canonicalize.mlir b/mlir/test/Dialect/SPIRV/canonicalize.mlir index 02d8645973f6..87be892ce3bd 100644 --- a/mlir/test/Dialect/SPIRV/canonicalize.mlir +++ b/mlir/test/Dialect/SPIRV/canonicalize.mlir @@ -134,6 +134,34 @@ func @deduplicate_composite_constant() -> (!spv.array<1 x vector<2xi32>>, !spv.a // ----- +//===----------------------------------------------------------------------===// +// spv.Bitcast +//===----------------------------------------------------------------------===// + +func @convert_bitcast_full(%arg0 : vector<2xf32>) -> f64 { + // CHECK: %[[RESULT:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64 + // CHECK-NEXT: spv.ReturnValue %[[RESULT]] + %0 = spv.Bitcast %arg0 : vector<2xf32> to vector<2xi32> + %1 = spv.Bitcast %0 : vector<2xi32> to i64 + %2 = spv.Bitcast %1 : i64 to f64 + spv.ReturnValue %2 : f64 +} + +// ----- + +func @convert_bitcast_multi_use(%arg0 : vector<2xf32>, %arg1 : !spv.ptr) -> f64 { + // CHECK: %[[RESULT_0:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to i64 + // CHECK-NEXT: %[[RESULT_1:.*]] = spv.Bitcast {{%.*}} : vector<2xf32> to f64 + // CHECK-NEXT: spv.Store {{".*"}} {{%.*}}, %[[RESULT_0]] + // CHECK-NEXT: spv.ReturnValue %[[RESULT_1]] + %0 = spv.Bitcast %arg0 : vector<2xf32> to i64 + %1 = spv.Bitcast %0 : i64 to f64 + spv.Store "Uniform" %arg1, %0 : i64 + spv.ReturnValue %1 : f64 +} + +// ----- + //===----------------------------------------------------------------------===// // spv.selection //===----------------------------------------------------------------------===//