NFC: Wire up DRR settings for SPIR-V canonicalization patterns

This CL added necessary files and settings for using DRR to
write SPIR-V canonicalization patterns and also converted the
patterns for spv.Bitcast and spv.LogicalNot.

PiperOrigin-RevId: 282132786
This commit is contained in:
Lei Zhang 2019-11-23 06:08:50 -08:00 committed by A. Unique TensorFlower
parent aaafeac89b
commit ae821fe626
3 changed files with 55 additions and 58 deletions

View File

@ -1,3 +1,7 @@
set(LLVM_TARGET_DEFINITIONS SPIRVCanonicalization.td)
mlir_tablegen(SPIRVCanonicalization.inc -gen-rewriters)
add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
add_llvm_library(MLIRSPIRV
DialectRegistration.cpp
LayoutUtils.cpp
@ -11,8 +15,9 @@ add_llvm_library(MLIRSPIRV
)
add_dependencies(MLIRSPIRV
MLIRSPIRVOpsIncGen
MLIRSPIRVCanonicalizationIncGen
MLIRSPIRVEnumsIncGen
MLIRSPIRVOpsIncGen
MLIRSPIRVOpUtilsGen)
target_link_libraries(MLIRSPIRV

View File

@ -0,0 +1,40 @@
//==- SPIRVCanonicalization.td - Canonicalization Patterns ---*- tablegen -*==//
// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines SPIR-V canonicalization patterns.
//
//===----------------------------------------------------------------------===//
include "mlir/Dialect/SPIRV/SPIRVOps.td"
//===----------------------------------------------------------------------===//
// spv.Bitcast
//===----------------------------------------------------------------------===//
def ConvertChainedBitcast : Pat<(SPV_BitcastOp (SPV_BitcastOp $operand)),
(SPV_BitcastOp $operand)>;
//===----------------------------------------------------------------------===//
// spv.LogicalNot
//===----------------------------------------------------------------------===//
def ConvertLogicalNotOfIEqual : Pat<
(SPV_LogicalNotOp (SPV_IEqualOp $lhs, $rhs)),
(SPV_INotEqualOp $lhs, $rhs)>;
def ConvertLogicalNotOfINotEqual : Pat<
(SPV_LogicalNotOp (SPV_INotEqualOp $lhs, $rhs)),
(SPV_IEqualOp $lhs, $rhs)>;
def ConvertLogicalNotOfLogicalEqual : Pat<
(SPV_LogicalNotOp (SPV_LogicalEqualOp $lhs, $rhs)),
(SPV_LogicalNotEqualOp $lhs, $rhs)>;
def ConvertLogicalNotOfLogicalNotEqual : Pat<
(SPV_LogicalNotOp (SPV_LogicalNotEqualOp $lhs, $rhs)),
(SPV_LogicalEqualOp $lhs, $rhs)>;

View File

@ -377,6 +377,12 @@ static inline bool isMergeBlock(Block &block) {
isa<spirv::MergeOp>(block.front());
}
//===----------------------------------------------------------------------===//
// TableGen'erated canonicalizers
//===----------------------------------------------------------------------===//
#include "SPIRVCanonicalization.inc"
//===----------------------------------------------------------------------===//
// Common parsers and printers
//===----------------------------------------------------------------------===//
@ -771,30 +777,6 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
return success();
}
namespace {
/// Converts chained `spirv::BitcastOp` operations into one
/// `spirv::BitcastOp` operation.
struct ConvertChainedBitcast : public OpRewritePattern<spirv::BitcastOp> {
using OpRewritePattern<spirv::BitcastOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(spirv::BitcastOp bitcastOp,
PatternRewriter &rewriter) const override {
auto parentBitcastOp = dyn_cast_or_null<spirv::BitcastOp>(
bitcastOp.operand()->getDefiningOp());
if (!parentBitcastOp) {
return matchFailure();
}
rewriter.replaceOpWithNewOp<spirv::BitcastOp>(
/*valuesToRemoveIfDead=*/{parentBitcastOp.result()}, bitcastOp,
bitcastOp.result()->getType(), parentBitcastOp.operand());
return matchSuccess();
}
};
} // end anonymous namespace
void spirv::BitcastOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<ConvertChainedBitcast>(context);
@ -1587,41 +1569,11 @@ static LogicalResult verify(spirv::LoadOp loadOp) {
// spv.LogicalNot
//===----------------------------------------------------------------------===//
namespace {
/// Converts `spirv::LogicalNotOp` to the given `NewOp` using the first and the
/// second operands from the given `ParentOp`.
template <typename NewOp, typename ParentOp>
struct ConvertLogicalNotOp : public OpRewritePattern<spirv::LogicalNotOp> {
using OpRewritePattern<spirv::LogicalNotOp>::OpRewritePattern;
PatternMatchResult matchAndRewrite(spirv::LogicalNotOp logicalNotOp,
PatternRewriter &rewriter) const override {
auto parentOp =
dyn_cast_or_null<ParentOp>(logicalNotOp.operand()->getDefiningOp());
if (!parentOp) {
return this->matchFailure();
}
rewriter.replaceOpWithNewOp<NewOp>(
/*valuesToRemoveIfDead=*/{parentOp.result()}, logicalNotOp,
logicalNotOp.result()->getType(), parentOp.operand1(),
parentOp.operand2());
return this->matchSuccess();
}
};
} // end anonymous namespace
void spirv::LogicalNotOp::getCanonicalizationPatterns(
OwningRewritePatternList &results, MLIRContext *context) {
results.insert<
ConvertLogicalNotOp<spirv::INotEqualOp, spirv::IEqualOp>,
ConvertLogicalNotOp<spirv::IEqualOp, spirv::INotEqualOp>,
ConvertLogicalNotOp<spirv::LogicalNotEqualOp, spirv::LogicalEqualOp>,
ConvertLogicalNotOp<spirv::LogicalEqualOp, spirv::LogicalNotEqualOp>>(
context);
results.insert<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
ConvertLogicalNotOfLogicalEqual,
ConvertLogicalNotOfLogicalNotEqual>(context);
}
//===----------------------------------------------------------------------===//