From 1e9321e97aba43e41ccd7ab2f1bef41d5bcf65af Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Wed, 26 Feb 2020 09:12:56 -0500 Subject: [PATCH] [mlir][spirv] NFC: move folders and canonicalizers in a separate file This gives us better file organization and faster compilation time by avoid having a gigantic SPIRVOps.cpp file. --- mlir/lib/Dialect/SPIRV/CMakeLists.txt | 1 + .../Dialect/SPIRV/SPIRVCanonicalization.cpp | 367 ++++++++++++++++++ mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 326 ---------------- 3 files changed, 368 insertions(+), 326 deletions(-) create mode 100644 mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp diff --git a/mlir/lib/Dialect/SPIRV/CMakeLists.txt b/mlir/lib/Dialect/SPIRV/CMakeLists.txt index d0ff25ef68f0..85bb7390b716 100644 --- a/mlir/lib/Dialect/SPIRV/CMakeLists.txt +++ b/mlir/lib/Dialect/SPIRV/CMakeLists.txt @@ -4,6 +4,7 @@ add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen) add_llvm_library(MLIRSPIRV LayoutUtils.cpp + SPIRVCanonicalization.cpp SPIRVDialect.cpp SPIRVOps.cpp SPIRVLowering.cpp diff --git a/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp new file mode 100644 index 000000000000..32090f3d1ec0 --- /dev/null +++ b/mlir/lib/Dialect/SPIRV/SPIRVCanonicalization.cpp @@ -0,0 +1,367 @@ +//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +// +// This file defines the folders and canonicalization patterns for SPIR-V ops. +// +//===----------------------------------------------------------------------===// + +#include "mlir/Dialect/SPIRV/SPIRVOps.h" + +#include "mlir/Dialect/CommonFolders.h" +#include "mlir/Dialect/SPIRV/SPIRVDialect.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Support/Functional.h" + +using namespace mlir; + +//===----------------------------------------------------------------------===// +// Common utility functions +//===----------------------------------------------------------------------===// + +// Extracts an element from the given `composite` by following the given +// `indices`. Returns a null Attribute if error happens. +static Attribute extractCompositeElement(Attribute composite, + ArrayRef indices) { + // Check that given composite is a constant. + if (!composite) + return {}; + // Return composite itself if we reach the end of the index chain. + if (indices.empty()) + return composite; + + if (auto vector = composite.dyn_cast()) { + assert(indices.size() == 1 && "must have exactly one index for a vector"); + return vector.getValue({indices[0]}); + } + + if (auto array = composite.dyn_cast()) { + assert(!indices.empty() && "must have at least one index for an array"); + return extractCompositeElement(array.getValue()[indices[0]], + indices.drop_front()); + } + + return {}; +} + +//===----------------------------------------------------------------------===// +// TableGen'erated canonicalizers +//===----------------------------------------------------------------------===// + +namespace { +#include "SPIRVCanonicalization.inc" +} + +//===----------------------------------------------------------------------===// +// spv.AccessChainOp +//===----------------------------------------------------------------------===// + +namespace { + +/// Combines chained `spirv::AccessChainOp` operations into one +/// `spirv::AccessChainOp` operation. +struct CombineChainedAccessChain + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp, + PatternRewriter &rewriter) const override { + auto parentAccessChainOp = dyn_cast_or_null( + accessChainOp.base_ptr().getDefiningOp()); + + if (!parentAccessChainOp) { + return matchFailure(); + } + + // Combine indices. + SmallVector indices(parentAccessChainOp.indices()); + indices.append(accessChainOp.indices().begin(), + accessChainOp.indices().end()); + + rewriter.replaceOpWithNewOp( + accessChainOp, parentAccessChainOp.base_ptr(), indices); + + return matchSuccess(); + } +}; +} // end anonymous namespace + +void spirv::AccessChainOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// spv.BitcastOp +//===----------------------------------------------------------------------===// + +void spirv::BitcastOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// spv.CompositeExtractOp +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef operands) { + assert(operands.size() == 1 && "spv.CompositeExtract expects one operand"); + auto indexVector = functional::map( + [](Attribute attr) { + return static_cast(attr.cast().getInt()); + }, + indices()); + return extractCompositeElement(operands[0], indexVector); +} + +//===----------------------------------------------------------------------===// +// spv.constant +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::ConstantOp::fold(ArrayRef operands) { + assert(operands.empty() && "spv.constant has no operands"); + return value(); +} + +//===----------------------------------------------------------------------===// +// spv.IAdd +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::IAddOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "spv.IAdd expects two operands"); + // x + 0 = x + if (matchPattern(operand2(), m_Zero())) + return operand1(); + + // According to the SPIR-V spec: + // + // The resulting value will equal the low-order N bits of the correct result + // R, where N is the component width and R is computed with enough precision + // to avoid overflow and underflow. + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a + b; }); +} + +//===----------------------------------------------------------------------===// +// spv.IMul +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::IMulOp::fold(ArrayRef operands) { + assert(operands.size() == 2 && "spv.IMul expects two operands"); + // x * 0 == 0 + if (matchPattern(operand2(), m_Zero())) + return operand2(); + // x * 1 = x + if (matchPattern(operand2(), m_One())) + return operand1(); + + // According to the SPIR-V spec: + // + // The resulting value will equal the low-order N bits of the correct result + // R, where N is the component width and R is computed with enough precision + // to avoid overflow and underflow. + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a * b; }); +} + +//===----------------------------------------------------------------------===// +// spv.ISub +//===----------------------------------------------------------------------===// + +OpFoldResult spirv::ISubOp::fold(ArrayRef operands) { + // x - x = 0 + if (operand1() == operand2()) + return Builder(getContext()).getIntegerAttr(getType(), 0); + + // According to the SPIR-V spec: + // + // The resulting value will equal the low-order N bits of the correct result + // R, where N is the component width and R is computed with enough precision + // to avoid overflow and underflow. + return constFoldBinaryOp(operands, + [](APInt a, APInt b) { return a - b; }); +} + +//===----------------------------------------------------------------------===// +// spv.LogicalNot +//===----------------------------------------------------------------------===// + +void spirv::LogicalNotOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} + +//===----------------------------------------------------------------------===// +// spv.selection +//===----------------------------------------------------------------------===// + +namespace { +// Blocks from the given `spv.selection` operation must satisfy the following +// layout: +// +// +-----------------------------------------------+ +// | header block | +// | spv.BranchConditionalOp %cond, ^case0, ^case1 | +// +-----------------------------------------------+ +// / \ +// ... +// +// +// +------------------------+ +------------------------+ +// | case #0 | | case #1 | +// | spv.Store %ptr %value0 | | spv.Store %ptr %value1 | +// | spv.Branch ^merge | | spv.Branch ^merge | +// +------------------------+ +------------------------+ +// +// +// ... +// \ / +// v +// +-------------+ +// | merge block | +// +-------------+ +// +struct ConvertSelectionOpToSelect + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp, + PatternRewriter &rewriter) const override { + auto *op = selectionOp.getOperation(); + auto &body = op->getRegion(0); + // Verifier allows an empty region for `spv.selection`. + if (body.empty()) { + return matchFailure(); + } + + // Check that region consists of 4 blocks: + // header block, `true` block, `false` block and merge block. + if (std::distance(body.begin(), body.end()) != 4) { + return matchFailure(); + } + + auto *headerBlock = selectionOp.getHeaderBlock(); + if (!onlyContainsBranchConditionalOp(headerBlock)) { + return matchFailure(); + } + + auto brConditionalOp = + cast(headerBlock->front()); + + auto *trueBlock = brConditionalOp.getSuccessor(0); + auto *falseBlock = brConditionalOp.getSuccessor(1); + auto *mergeBlock = selectionOp.getMergeBlock(); + + if (!canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)) { + return matchFailure(); + } + + auto trueValue = getSrcValue(trueBlock); + auto falseValue = getSrcValue(falseBlock); + auto ptrValue = getDstPtr(trueBlock); + auto storeOpAttributes = + cast(trueBlock->front()).getOperation()->getAttrs(); + + auto selectOp = rewriter.create( + selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(), + trueValue, falseValue); + rewriter.create(selectOp.getLoc(), ptrValue, + selectOp.getResult(), storeOpAttributes); + + // `spv.selection` is not needed anymore. + rewriter.eraseOp(op); + return matchSuccess(); + } + +private: + // Checks that given blocks follow the following rules: + // 1. Each conditional block consists of two operations, the first operation + // is a `spv.Store` and the last operation is a `spv.Branch`. + // 2. Each `spv.Store` uses the same pointer and the same memory attributes. + // 3. A control flow goes into the given merge block from the given + // conditional blocks. + PatternMatchResult canCanonicalizeSelection(Block *trueBlock, + Block *falseBlock, + Block *mergeBlock) const; + + bool onlyContainsBranchConditionalOp(Block *block) const { + return std::next(block->begin()) == block->end() && + isa(block->front()); + } + + bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const { + return lhs.getOperation()->getAttrList().getDictionary() == + rhs.getOperation()->getAttrList().getDictionary(); + } + + // Checks that given type is valid for `spv.SelectOp`. + // According to SPIR-V spec: + // "Before version 1.4, Result Type must be a pointer, scalar, or vector. + // Starting with version 1.4, Result Type can additionally be a composite type + // other than a vector." + bool isValidType(Type type) const { + return spirv::SPIRVDialect::isValidScalarType(type) || + type.isa(); + } + + // Returns a source value for the given block. + Value getSrcValue(Block *block) const { + auto storeOp = cast(block->front()); + return storeOp.value(); + } + + // Returns a destination value for the given block. + Value getDstPtr(Block *block) const { + auto storeOp = cast(block->front()); + return storeOp.ptr(); + } +}; + +PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection( + Block *trueBlock, Block *falseBlock, Block *mergeBlock) const { + // Each block must consists of 2 operations. + if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) || + (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) { + return matchFailure(); + } + + auto trueBrStoreOp = dyn_cast(trueBlock->front()); + auto trueBrBranchOp = + dyn_cast(*std::next(trueBlock->begin())); + auto falseBrStoreOp = dyn_cast(falseBlock->front()); + auto falseBrBranchOp = + dyn_cast(*std::next(falseBlock->begin())); + + if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp || + !falseBrBranchOp) { + return matchFailure(); + } + + // Check that each `spv.Store` uses the same pointer, memory access + // attributes and a valid type of the value. + if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) || + !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || + !isValidType(trueBrStoreOp.value().getType())) { + return matchFailure(); + } + + if ((trueBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock) || + (falseBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock)) { + return matchFailure(); + } + + return matchSuccess(); +} +} // end anonymous namespace + +void spirv::SelectionOp::getCanonicalizationPatterns( + OwningRewritePatternList &results, MLIRContext *context) { + results.insert(context); +} diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index 01197498a704..1dc4dd9aee0a 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -13,17 +13,13 @@ #include "mlir/Dialect/SPIRV/SPIRVOps.h" #include "mlir/Analysis/CallInterfaces.h" -#include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/SPIRV/SPIRVDialect.h" #include "mlir/Dialect/SPIRV/SPIRVTypes.h" #include "mlir/IR/Builders.h" #include "mlir/IR/Function.h" #include "mlir/IR/FunctionImplementation.h" -#include "mlir/IR/Matchers.h" #include "mlir/IR/OpImplementation.h" -#include "mlir/IR/PatternMatch.h" #include "mlir/IR/StandardTypes.h" -#include "mlir/Support/Functional.h" #include "mlir/Support/StringExtras.h" #include "llvm/ADT/bit.h" @@ -360,31 +356,6 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter &printer, printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } -// Extracts an element from the given `composite` by following the given -// `indices`. Returns a null Attribute if error happens. -static Attribute extractCompositeElement(Attribute composite, - ArrayRef indices) { - // Check that given composite is a constant. - if (!composite) - return {}; - // Return composite itself if we reach the end of the index chain. - if (indices.empty()) - return composite; - - if (auto vector = composite.dyn_cast()) { - assert(indices.size() == 1 && "must have exactly one index for a vector"); - return vector.getValue({indices[0]}); - } - - if (auto array = composite.dyn_cast()) { - assert(!indices.empty() && "must have at least one index for an array"); - return extractCompositeElement(array.getValue()[indices[0]], - indices.drop_front()); - } - - return {}; -} - // Get bit width of types. static unsigned getBitWidth(Type type) { if (type.isa()) { @@ -477,14 +448,6 @@ static inline bool isMergeBlock(Block &block) { isa(block.front()); } -//===----------------------------------------------------------------------===// -// TableGen'erated canonicalizers -//===----------------------------------------------------------------------===// - -namespace { -#include "SPIRVCanonicalization.inc" -} - //===----------------------------------------------------------------------===// // Common parsers and printers //===----------------------------------------------------------------------===// @@ -848,41 +811,6 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) { return success(); } -namespace { - -/// Combines chained `spirv::AccessChainOp` operations into one -/// `spirv::AccessChainOp` operation. -struct CombineChainedAccessChain - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp, - PatternRewriter &rewriter) const override { - auto parentAccessChainOp = dyn_cast_or_null( - accessChainOp.base_ptr().getDefiningOp()); - - if (!parentAccessChainOp) { - return matchFailure(); - } - - // Combine indices. - SmallVector indices(parentAccessChainOp.indices()); - indices.append(accessChainOp.indices().begin(), - accessChainOp.indices().end()); - - rewriter.replaceOpWithNewOp( - accessChainOp, parentAccessChainOp.base_ptr(), indices); - - return matchSuccess(); - } -}; -} // end anonymous namespace - -void spirv::AccessChainOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - //===----------------------------------------------------------------------===// // spv._address_of //===----------------------------------------------------------------------===// @@ -1013,11 +941,6 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) { return success(); } -void spirv::BitcastOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - //===----------------------------------------------------------------------===// // spv.BranchConditionalOp //===----------------------------------------------------------------------===// @@ -1230,16 +1153,6 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) { return success(); } -OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef operands) { - assert(operands.size() == 1 && "spv.CompositeExtract expects one operand"); - auto indexVector = functional::map( - [](Attribute attr) { - return static_cast(attr.cast().getInt()); - }, - indices()); - return extractCompositeElement(operands[0], indexVector); -} - //===----------------------------------------------------------------------===// // spv.CompositeInsert //===----------------------------------------------------------------------===// @@ -1390,11 +1303,6 @@ static LogicalResult verify(spirv::ConstantOp constOp) { return success(); } -OpFoldResult spirv::ConstantOp::fold(ArrayRef operands) { - assert(operands.empty() && "spv.constant has no operands"); - return value(); -} - bool spirv::ConstantOp::isBuildableWith(Type type) { // Must be valid SPIR-V type first. if (!SPIRVDialect::isValidType(type)) @@ -1890,65 +1798,6 @@ static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) { return success(); } -//===----------------------------------------------------------------------===// -// spv.IAdd -//===----------------------------------------------------------------------===// - -OpFoldResult spirv::IAddOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "spv.IAdd expects two operands"); - // x + 0 = x - if (matchPattern(operand2(), m_Zero())) - return operand1(); - - // According to the SPIR-V spec: - // - // The resulting value will equal the low-order N bits of the correct result - // R, where N is the component width and R is computed with enough precision - // to avoid overflow and underflow. - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a + b; }); -} - -//===----------------------------------------------------------------------===// -// spv.IMul -//===----------------------------------------------------------------------===// - -OpFoldResult spirv::IMulOp::fold(ArrayRef operands) { - assert(operands.size() == 2 && "spv.IMul expects two operands"); - // x * 0 == 0 - if (matchPattern(operand2(), m_Zero())) - return operand2(); - // x * 1 = x - if (matchPattern(operand2(), m_One())) - return operand1(); - - // According to the SPIR-V spec: - // - // The resulting value will equal the low-order N bits of the correct result - // R, where N is the component width and R is computed with enough precision - // to avoid overflow and underflow. - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a * b; }); -} - -//===----------------------------------------------------------------------===// -// spv.ISub -//===----------------------------------------------------------------------===// - -OpFoldResult spirv::ISubOp::fold(ArrayRef operands) { - // x - x = 0 - if (operand1() == operand2()) - return Builder(getContext()).getIntegerAttr(getType(), 0); - - // According to the SPIR-V spec: - // - // The resulting value will equal the low-order N bits of the correct result - // R, where N is the component width and R is computed with enough precision - // to avoid overflow and underflow. - return constFoldBinaryOp(operands, - [](APInt a, APInt b) { return a - b; }); -} - //===----------------------------------------------------------------------===// // spv.LoadOp //===----------------------------------------------------------------------===// @@ -2008,17 +1857,6 @@ static LogicalResult verify(spirv::LoadOp loadOp) { return verifyMemoryAccessAttribute(loadOp); } -//===----------------------------------------------------------------------===// -// spv.LogicalNot -//===----------------------------------------------------------------------===// - -void spirv::LogicalNotOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - //===----------------------------------------------------------------------===// // spv.loop //===----------------------------------------------------------------------===// @@ -2547,170 +2385,6 @@ spirv::SelectionOp spirv::SelectionOp::createIfThen( return selectionOp; } -namespace { -// Blocks from the given `spv.selection` operation must satisfy the following -// layout: -// -// +-----------------------------------------------+ -// | header block | -// | spv.BranchConditionalOp %cond, ^case0, ^case1 | -// +-----------------------------------------------+ -// / \ -// ... -// -// -// +------------------------+ +------------------------+ -// | case #0 | | case #1 | -// | spv.Store %ptr %value0 | | spv.Store %ptr %value1 | -// | spv.Branch ^merge | | spv.Branch ^merge | -// +------------------------+ +------------------------+ -// -// -// ... -// \ / -// v -// +-------------+ -// | merge block | -// +-------------+ -// -struct ConvertSelectionOpToSelect - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp, - PatternRewriter &rewriter) const override { - auto *op = selectionOp.getOperation(); - auto &body = op->getRegion(0); - // Verifier allows an empty region for `spv.selection`. - if (body.empty()) { - return matchFailure(); - } - - // Check that region consists of 4 blocks: - // header block, `true` block, `false` block and merge block. - if (std::distance(body.begin(), body.end()) != 4) { - return matchFailure(); - } - - auto *headerBlock = selectionOp.getHeaderBlock(); - if (!onlyContainsBranchConditionalOp(headerBlock)) { - return matchFailure(); - } - - auto brConditionalOp = - cast(headerBlock->front()); - - auto *trueBlock = brConditionalOp.getSuccessor(0); - auto *falseBlock = brConditionalOp.getSuccessor(1); - auto *mergeBlock = selectionOp.getMergeBlock(); - - if (!canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)) { - return matchFailure(); - } - - auto trueValue = getSrcValue(trueBlock); - auto falseValue = getSrcValue(falseBlock); - auto ptrValue = getDstPtr(trueBlock); - auto storeOpAttributes = - cast(trueBlock->front()).getOperation()->getAttrs(); - - auto selectOp = rewriter.create( - selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(), - trueValue, falseValue); - rewriter.create(selectOp.getLoc(), ptrValue, - selectOp.getResult(), storeOpAttributes); - - // `spv.selection` is not needed anymore. - rewriter.eraseOp(op); - return matchSuccess(); - } - -private: - // Checks that given blocks follow the following rules: - // 1. Each conditional block consists of two operations, the first operation - // is a `spv.Store` and the last operation is a `spv.Branch`. - // 2. Each `spv.Store` uses the same pointer and the same memory attributes. - // 3. A control flow goes into the given merge block from the given - // conditional blocks. - PatternMatchResult canCanonicalizeSelection(Block *trueBlock, - Block *falseBlock, - Block *mergeBlock) const; - - bool onlyContainsBranchConditionalOp(Block *block) const { - return std::next(block->begin()) == block->end() && - isa(block->front()); - } - - bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const { - return lhs.getOperation()->getAttrList().getDictionary() == - rhs.getOperation()->getAttrList().getDictionary(); - } - - // Checks that given type is valid for `spv.SelectOp`. - // According to SPIR-V spec: - // "Before version 1.4, Result Type must be a pointer, scalar, or vector. - // Starting with version 1.4, Result Type can additionally be a composite type - // other than a vector." - bool isValidType(Type type) const { - return spirv::SPIRVDialect::isValidScalarType(type) || - type.isa(); - } - - // Returns a source value for the given block. - Value getSrcValue(Block *block) const { - auto storeOp = cast(block->front()); - return storeOp.value(); - } - - // Returns a destination value for the given block. - Value getDstPtr(Block *block) const { - auto storeOp = cast(block->front()); - return storeOp.ptr(); - } -}; - -PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection( - Block *trueBlock, Block *falseBlock, Block *mergeBlock) const { - // Each block must consists of 2 operations. - if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) || - (std::distance(falseBlock->begin(), falseBlock->end()) != 2)) { - return matchFailure(); - } - - auto trueBrStoreOp = dyn_cast(trueBlock->front()); - auto trueBrBranchOp = - dyn_cast(*std::next(trueBlock->begin())); - auto falseBrStoreOp = dyn_cast(falseBlock->front()); - auto falseBrBranchOp = - dyn_cast(*std::next(falseBlock->begin())); - - if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp || - !falseBrBranchOp) { - return matchFailure(); - } - - // Check that each `spv.Store` uses the same pointer, memory access - // attributes and a valid type of the value. - if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) || - !isSameAttrList(trueBrStoreOp, falseBrStoreOp) || - !isValidType(trueBrStoreOp.value().getType())) { - return matchFailure(); - } - - if ((trueBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock) || - (falseBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock)) { - return matchFailure(); - } - - return matchSuccess(); -} -} // end anonymous namespace - -void spirv::SelectionOp::getCanonicalizationPatterns( - OwningRewritePatternList &results, MLIRContext *context) { - results.insert(context); -} - //===----------------------------------------------------------------------===// // spv.specConstant //===----------------------------------------------------------------------===//