forked from OSchip/llvm-project
[mlir][spirv] NFC: move folders and canonicalizers in a separate file
This gives us better file organization and faster compilation time by avoid having a gigantic SPIRVOps.cpp file.
This commit is contained in:
parent
387c3f74fd
commit
1e9321e97a
|
@ -4,6 +4,7 @@ add_public_tablegen_target(MLIRSPIRVCanonicalizationIncGen)
|
||||||
|
|
||||||
add_llvm_library(MLIRSPIRV
|
add_llvm_library(MLIRSPIRV
|
||||||
LayoutUtils.cpp
|
LayoutUtils.cpp
|
||||||
|
SPIRVCanonicalization.cpp
|
||||||
SPIRVDialect.cpp
|
SPIRVDialect.cpp
|
||||||
SPIRVOps.cpp
|
SPIRVOps.cpp
|
||||||
SPIRVLowering.cpp
|
SPIRVLowering.cpp
|
||||||
|
|
|
@ -0,0 +1,367 @@
|
||||||
|
//===- SPIRVCanonicalization.cpp - MLIR SPIR-V canonicalization patterns --===//
|
||||||
|
//
|
||||||
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||||
|
// See https://llvm.org/LICENSE.txt for license information.
|
||||||
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
//
|
||||||
|
// This file defines the folders and canonicalization patterns for SPIR-V ops.
|
||||||
|
//
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||||
|
|
||||||
|
#include "mlir/Dialect/CommonFolders.h"
|
||||||
|
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||||
|
#include "mlir/IR/Matchers.h"
|
||||||
|
#include "mlir/IR/PatternMatch.h"
|
||||||
|
#include "mlir/Support/Functional.h"
|
||||||
|
|
||||||
|
using namespace mlir;
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// Common utility functions
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
// Extracts an element from the given `composite` by following the given
|
||||||
|
// `indices`. Returns a null Attribute if error happens.
|
||||||
|
static Attribute extractCompositeElement(Attribute composite,
|
||||||
|
ArrayRef<unsigned> indices) {
|
||||||
|
// Check that given composite is a constant.
|
||||||
|
if (!composite)
|
||||||
|
return {};
|
||||||
|
// Return composite itself if we reach the end of the index chain.
|
||||||
|
if (indices.empty())
|
||||||
|
return composite;
|
||||||
|
|
||||||
|
if (auto vector = composite.dyn_cast<ElementsAttr>()) {
|
||||||
|
assert(indices.size() == 1 && "must have exactly one index for a vector");
|
||||||
|
return vector.getValue({indices[0]});
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto array = composite.dyn_cast<ArrayAttr>()) {
|
||||||
|
assert(!indices.empty() && "must have at least one index for an array");
|
||||||
|
return extractCompositeElement(array.getValue()[indices[0]],
|
||||||
|
indices.drop_front());
|
||||||
|
}
|
||||||
|
|
||||||
|
return {};
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// TableGen'erated canonicalizers
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
#include "SPIRVCanonicalization.inc"
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.AccessChainOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
|
||||||
|
/// Combines chained `spirv::AccessChainOp` operations into one
|
||||||
|
/// `spirv::AccessChainOp` operation.
|
||||||
|
struct CombineChainedAccessChain
|
||||||
|
: public OpRewritePattern<spirv::AccessChainOp> {
|
||||||
|
using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
|
||||||
|
accessChainOp.base_ptr().getDefiningOp());
|
||||||
|
|
||||||
|
if (!parentAccessChainOp) {
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Combine indices.
|
||||||
|
SmallVector<Value, 4> indices(parentAccessChainOp.indices());
|
||||||
|
indices.append(accessChainOp.indices().begin(),
|
||||||
|
accessChainOp.indices().end());
|
||||||
|
|
||||||
|
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
||||||
|
accessChainOp, parentAccessChainOp.base_ptr(), indices);
|
||||||
|
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
void spirv::AccessChainOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<CombineChainedAccessChain>(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.BitcastOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void spirv::BitcastOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<ConvertChainedBitcast>(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.CompositeExtractOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
|
||||||
|
auto indexVector = functional::map(
|
||||||
|
[](Attribute attr) {
|
||||||
|
return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
|
||||||
|
},
|
||||||
|
indices());
|
||||||
|
return extractCompositeElement(operands[0], indexVector);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.constant
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
assert(operands.empty() && "spv.constant has no operands");
|
||||||
|
return value();
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.IAdd
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
assert(operands.size() == 2 && "spv.IAdd expects two operands");
|
||||||
|
// x + 0 = x
|
||||||
|
if (matchPattern(operand2(), m_Zero()))
|
||||||
|
return operand1();
|
||||||
|
|
||||||
|
// According to the SPIR-V spec:
|
||||||
|
//
|
||||||
|
// The resulting value will equal the low-order N bits of the correct result
|
||||||
|
// R, where N is the component width and R is computed with enough precision
|
||||||
|
// to avoid overflow and underflow.
|
||||||
|
return constFoldBinaryOp<IntegerAttr>(operands,
|
||||||
|
[](APInt a, APInt b) { return a + b; });
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.IMul
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
assert(operands.size() == 2 && "spv.IMul expects two operands");
|
||||||
|
// x * 0 == 0
|
||||||
|
if (matchPattern(operand2(), m_Zero()))
|
||||||
|
return operand2();
|
||||||
|
// x * 1 = x
|
||||||
|
if (matchPattern(operand2(), m_One()))
|
||||||
|
return operand1();
|
||||||
|
|
||||||
|
// According to the SPIR-V spec:
|
||||||
|
//
|
||||||
|
// The resulting value will equal the low-order N bits of the correct result
|
||||||
|
// R, where N is the component width and R is computed with enough precision
|
||||||
|
// to avoid overflow and underflow.
|
||||||
|
return constFoldBinaryOp<IntegerAttr>(operands,
|
||||||
|
[](APInt a, APInt b) { return a * b; });
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.ISub
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
|
||||||
|
// x - x = 0
|
||||||
|
if (operand1() == operand2())
|
||||||
|
return Builder(getContext()).getIntegerAttr(getType(), 0);
|
||||||
|
|
||||||
|
// According to the SPIR-V spec:
|
||||||
|
//
|
||||||
|
// The resulting value will equal the low-order N bits of the correct result
|
||||||
|
// R, where N is the component width and R is computed with enough precision
|
||||||
|
// to avoid overflow and underflow.
|
||||||
|
return constFoldBinaryOp<IntegerAttr>(operands,
|
||||||
|
[](APInt a, APInt b) { return a - b; });
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.LogicalNot
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
void spirv::LogicalNotOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
|
||||||
|
ConvertLogicalNotOfLogicalEqual,
|
||||||
|
ConvertLogicalNotOfLogicalNotEqual>(context);
|
||||||
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// spv.selection
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
// Blocks from the given `spv.selection` operation must satisfy the following
|
||||||
|
// layout:
|
||||||
|
//
|
||||||
|
// +-----------------------------------------------+
|
||||||
|
// | header block |
|
||||||
|
// | spv.BranchConditionalOp %cond, ^case0, ^case1 |
|
||||||
|
// +-----------------------------------------------+
|
||||||
|
// / \
|
||||||
|
// ...
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// +------------------------+ +------------------------+
|
||||||
|
// | case #0 | | case #1 |
|
||||||
|
// | spv.Store %ptr %value0 | | spv.Store %ptr %value1 |
|
||||||
|
// | spv.Branch ^merge | | spv.Branch ^merge |
|
||||||
|
// +------------------------+ +------------------------+
|
||||||
|
//
|
||||||
|
//
|
||||||
|
// ...
|
||||||
|
// \ /
|
||||||
|
// v
|
||||||
|
// +-------------+
|
||||||
|
// | merge block |
|
||||||
|
// +-------------+
|
||||||
|
//
|
||||||
|
struct ConvertSelectionOpToSelect
|
||||||
|
: public OpRewritePattern<spirv::SelectionOp> {
|
||||||
|
using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
|
||||||
|
|
||||||
|
PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
|
||||||
|
PatternRewriter &rewriter) const override {
|
||||||
|
auto *op = selectionOp.getOperation();
|
||||||
|
auto &body = op->getRegion(0);
|
||||||
|
// Verifier allows an empty region for `spv.selection`.
|
||||||
|
if (body.empty()) {
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that region consists of 4 blocks:
|
||||||
|
// header block, `true` block, `false` block and merge block.
|
||||||
|
if (std::distance(body.begin(), body.end()) != 4) {
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto *headerBlock = selectionOp.getHeaderBlock();
|
||||||
|
if (!onlyContainsBranchConditionalOp(headerBlock)) {
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto brConditionalOp =
|
||||||
|
cast<spirv::BranchConditionalOp>(headerBlock->front());
|
||||||
|
|
||||||
|
auto *trueBlock = brConditionalOp.getSuccessor(0);
|
||||||
|
auto *falseBlock = brConditionalOp.getSuccessor(1);
|
||||||
|
auto *mergeBlock = selectionOp.getMergeBlock();
|
||||||
|
|
||||||
|
if (!canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)) {
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto trueValue = getSrcValue(trueBlock);
|
||||||
|
auto falseValue = getSrcValue(falseBlock);
|
||||||
|
auto ptrValue = getDstPtr(trueBlock);
|
||||||
|
auto storeOpAttributes =
|
||||||
|
cast<spirv::StoreOp>(trueBlock->front()).getOperation()->getAttrs();
|
||||||
|
|
||||||
|
auto selectOp = rewriter.create<spirv::SelectOp>(
|
||||||
|
selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
|
||||||
|
trueValue, falseValue);
|
||||||
|
rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
|
||||||
|
selectOp.getResult(), storeOpAttributes);
|
||||||
|
|
||||||
|
// `spv.selection` is not needed anymore.
|
||||||
|
rewriter.eraseOp(op);
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
|
||||||
|
private:
|
||||||
|
// Checks that given blocks follow the following rules:
|
||||||
|
// 1. Each conditional block consists of two operations, the first operation
|
||||||
|
// is a `spv.Store` and the last operation is a `spv.Branch`.
|
||||||
|
// 2. Each `spv.Store` uses the same pointer and the same memory attributes.
|
||||||
|
// 3. A control flow goes into the given merge block from the given
|
||||||
|
// conditional blocks.
|
||||||
|
PatternMatchResult canCanonicalizeSelection(Block *trueBlock,
|
||||||
|
Block *falseBlock,
|
||||||
|
Block *mergeBlock) const;
|
||||||
|
|
||||||
|
bool onlyContainsBranchConditionalOp(Block *block) const {
|
||||||
|
return std::next(block->begin()) == block->end() &&
|
||||||
|
isa<spirv::BranchConditionalOp>(block->front());
|
||||||
|
}
|
||||||
|
|
||||||
|
bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
|
||||||
|
return lhs.getOperation()->getAttrList().getDictionary() ==
|
||||||
|
rhs.getOperation()->getAttrList().getDictionary();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Checks that given type is valid for `spv.SelectOp`.
|
||||||
|
// According to SPIR-V spec:
|
||||||
|
// "Before version 1.4, Result Type must be a pointer, scalar, or vector.
|
||||||
|
// Starting with version 1.4, Result Type can additionally be a composite type
|
||||||
|
// other than a vector."
|
||||||
|
bool isValidType(Type type) const {
|
||||||
|
return spirv::SPIRVDialect::isValidScalarType(type) ||
|
||||||
|
type.isa<VectorType>();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a source value for the given block.
|
||||||
|
Value getSrcValue(Block *block) const {
|
||||||
|
auto storeOp = cast<spirv::StoreOp>(block->front());
|
||||||
|
return storeOp.value();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Returns a destination value for the given block.
|
||||||
|
Value getDstPtr(Block *block) const {
|
||||||
|
auto storeOp = cast<spirv::StoreOp>(block->front());
|
||||||
|
return storeOp.ptr();
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
|
||||||
|
Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
|
||||||
|
// Each block must consists of 2 operations.
|
||||||
|
if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
|
||||||
|
(std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
|
||||||
|
auto trueBrBranchOp =
|
||||||
|
dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
|
||||||
|
auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
|
||||||
|
auto falseBrBranchOp =
|
||||||
|
dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
|
||||||
|
|
||||||
|
if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
|
||||||
|
!falseBrBranchOp) {
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that each `spv.Store` uses the same pointer, memory access
|
||||||
|
// attributes and a valid type of the value.
|
||||||
|
if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
|
||||||
|
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
|
||||||
|
!isValidType(trueBrStoreOp.value().getType())) {
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
if ((trueBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock) ||
|
||||||
|
(falseBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock)) {
|
||||||
|
return matchFailure();
|
||||||
|
}
|
||||||
|
|
||||||
|
return matchSuccess();
|
||||||
|
}
|
||||||
|
} // end anonymous namespace
|
||||||
|
|
||||||
|
void spirv::SelectionOp::getCanonicalizationPatterns(
|
||||||
|
OwningRewritePatternList &results, MLIRContext *context) {
|
||||||
|
results.insert<ConvertSelectionOpToSelect>(context);
|
||||||
|
}
|
|
@ -13,17 +13,13 @@
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||||
|
|
||||||
#include "mlir/Analysis/CallInterfaces.h"
|
#include "mlir/Analysis/CallInterfaces.h"
|
||||||
#include "mlir/Dialect/CommonFolders.h"
|
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||||
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
#include "mlir/Dialect/SPIRV/SPIRVTypes.h"
|
||||||
#include "mlir/IR/Builders.h"
|
#include "mlir/IR/Builders.h"
|
||||||
#include "mlir/IR/Function.h"
|
#include "mlir/IR/Function.h"
|
||||||
#include "mlir/IR/FunctionImplementation.h"
|
#include "mlir/IR/FunctionImplementation.h"
|
||||||
#include "mlir/IR/Matchers.h"
|
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/PatternMatch.h"
|
|
||||||
#include "mlir/IR/StandardTypes.h"
|
#include "mlir/IR/StandardTypes.h"
|
||||||
#include "mlir/Support/Functional.h"
|
|
||||||
#include "mlir/Support/StringExtras.h"
|
#include "mlir/Support/StringExtras.h"
|
||||||
#include "llvm/ADT/bit.h"
|
#include "llvm/ADT/bit.h"
|
||||||
|
|
||||||
|
@ -360,31 +356,6 @@ static void printVariableDecorations(Operation *op, OpAsmPrinter &printer,
|
||||||
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs);
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extracts an element from the given `composite` by following the given
|
|
||||||
// `indices`. Returns a null Attribute if error happens.
|
|
||||||
static Attribute extractCompositeElement(Attribute composite,
|
|
||||||
ArrayRef<unsigned> indices) {
|
|
||||||
// Check that given composite is a constant.
|
|
||||||
if (!composite)
|
|
||||||
return {};
|
|
||||||
// Return composite itself if we reach the end of the index chain.
|
|
||||||
if (indices.empty())
|
|
||||||
return composite;
|
|
||||||
|
|
||||||
if (auto vector = composite.dyn_cast<ElementsAttr>()) {
|
|
||||||
assert(indices.size() == 1 && "must have exactly one index for a vector");
|
|
||||||
return vector.getValue({indices[0]});
|
|
||||||
}
|
|
||||||
|
|
||||||
if (auto array = composite.dyn_cast<ArrayAttr>()) {
|
|
||||||
assert(!indices.empty() && "must have at least one index for an array");
|
|
||||||
return extractCompositeElement(array.getValue()[indices[0]],
|
|
||||||
indices.drop_front());
|
|
||||||
}
|
|
||||||
|
|
||||||
return {};
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get bit width of types.
|
// Get bit width of types.
|
||||||
static unsigned getBitWidth(Type type) {
|
static unsigned getBitWidth(Type type) {
|
||||||
if (type.isa<spirv::PointerType>()) {
|
if (type.isa<spirv::PointerType>()) {
|
||||||
|
@ -477,14 +448,6 @@ static inline bool isMergeBlock(Block &block) {
|
||||||
isa<spirv::MergeOp>(block.front());
|
isa<spirv::MergeOp>(block.front());
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// TableGen'erated canonicalizers
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
namespace {
|
|
||||||
#include "SPIRVCanonicalization.inc"
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// Common parsers and printers
|
// Common parsers and printers
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -848,41 +811,6 @@ static LogicalResult verify(spirv::AccessChainOp accessChainOp) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
/// Combines chained `spirv::AccessChainOp` operations into one
|
|
||||||
/// `spirv::AccessChainOp` operation.
|
|
||||||
struct CombineChainedAccessChain
|
|
||||||
: public OpRewritePattern<spirv::AccessChainOp> {
|
|
||||||
using OpRewritePattern<spirv::AccessChainOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(spirv::AccessChainOp accessChainOp,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto parentAccessChainOp = dyn_cast_or_null<spirv::AccessChainOp>(
|
|
||||||
accessChainOp.base_ptr().getDefiningOp());
|
|
||||||
|
|
||||||
if (!parentAccessChainOp) {
|
|
||||||
return matchFailure();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Combine indices.
|
|
||||||
SmallVector<Value, 4> indices(parentAccessChainOp.indices());
|
|
||||||
indices.append(accessChainOp.indices().begin(),
|
|
||||||
accessChainOp.indices().end());
|
|
||||||
|
|
||||||
rewriter.replaceOpWithNewOp<spirv::AccessChainOp>(
|
|
||||||
accessChainOp, parentAccessChainOp.base_ptr(), indices);
|
|
||||||
|
|
||||||
return matchSuccess();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
} // end anonymous namespace
|
|
||||||
|
|
||||||
void spirv::AccessChainOp::getCanonicalizationPatterns(
|
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
|
||||||
results.insert<CombineChainedAccessChain>(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// spv._address_of
|
// spv._address_of
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1013,11 +941,6 @@ static LogicalResult verify(spirv::BitcastOp bitcastOp) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
void spirv::BitcastOp::getCanonicalizationPatterns(
|
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
|
||||||
results.insert<ConvertChainedBitcast>(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// spv.BranchConditionalOp
|
// spv.BranchConditionalOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1230,16 +1153,6 @@ static LogicalResult verify(spirv::CompositeExtractOp compExOp) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult spirv::CompositeExtractOp::fold(ArrayRef<Attribute> operands) {
|
|
||||||
assert(operands.size() == 1 && "spv.CompositeExtract expects one operand");
|
|
||||||
auto indexVector = functional::map(
|
|
||||||
[](Attribute attr) {
|
|
||||||
return static_cast<unsigned>(attr.cast<IntegerAttr>().getInt());
|
|
||||||
},
|
|
||||||
indices());
|
|
||||||
return extractCompositeElement(operands[0], indexVector);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// spv.CompositeInsert
|
// spv.CompositeInsert
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -1390,11 +1303,6 @@ static LogicalResult verify(spirv::ConstantOp constOp) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
OpFoldResult spirv::ConstantOp::fold(ArrayRef<Attribute> operands) {
|
|
||||||
assert(operands.empty() && "spv.constant has no operands");
|
|
||||||
return value();
|
|
||||||
}
|
|
||||||
|
|
||||||
bool spirv::ConstantOp::isBuildableWith(Type type) {
|
bool spirv::ConstantOp::isBuildableWith(Type type) {
|
||||||
// Must be valid SPIR-V type first.
|
// Must be valid SPIR-V type first.
|
||||||
if (!SPIRVDialect::isValidType(type))
|
if (!SPIRVDialect::isValidType(type))
|
||||||
|
@ -1890,65 +1798,6 @@ static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// spv.IAdd
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
OpFoldResult spirv::IAddOp::fold(ArrayRef<Attribute> operands) {
|
|
||||||
assert(operands.size() == 2 && "spv.IAdd expects two operands");
|
|
||||||
// x + 0 = x
|
|
||||||
if (matchPattern(operand2(), m_Zero()))
|
|
||||||
return operand1();
|
|
||||||
|
|
||||||
// According to the SPIR-V spec:
|
|
||||||
//
|
|
||||||
// The resulting value will equal the low-order N bits of the correct result
|
|
||||||
// R, where N is the component width and R is computed with enough precision
|
|
||||||
// to avoid overflow and underflow.
|
|
||||||
return constFoldBinaryOp<IntegerAttr>(operands,
|
|
||||||
[](APInt a, APInt b) { return a + b; });
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// spv.IMul
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
OpFoldResult spirv::IMulOp::fold(ArrayRef<Attribute> operands) {
|
|
||||||
assert(operands.size() == 2 && "spv.IMul expects two operands");
|
|
||||||
// x * 0 == 0
|
|
||||||
if (matchPattern(operand2(), m_Zero()))
|
|
||||||
return operand2();
|
|
||||||
// x * 1 = x
|
|
||||||
if (matchPattern(operand2(), m_One()))
|
|
||||||
return operand1();
|
|
||||||
|
|
||||||
// According to the SPIR-V spec:
|
|
||||||
//
|
|
||||||
// The resulting value will equal the low-order N bits of the correct result
|
|
||||||
// R, where N is the component width and R is computed with enough precision
|
|
||||||
// to avoid overflow and underflow.
|
|
||||||
return constFoldBinaryOp<IntegerAttr>(operands,
|
|
||||||
[](APInt a, APInt b) { return a * b; });
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// spv.ISub
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
OpFoldResult spirv::ISubOp::fold(ArrayRef<Attribute> operands) {
|
|
||||||
// x - x = 0
|
|
||||||
if (operand1() == operand2())
|
|
||||||
return Builder(getContext()).getIntegerAttr(getType(), 0);
|
|
||||||
|
|
||||||
// According to the SPIR-V spec:
|
|
||||||
//
|
|
||||||
// The resulting value will equal the low-order N bits of the correct result
|
|
||||||
// R, where N is the component width and R is computed with enough precision
|
|
||||||
// to avoid overflow and underflow.
|
|
||||||
return constFoldBinaryOp<IntegerAttr>(operands,
|
|
||||||
[](APInt a, APInt b) { return a - b; });
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// spv.LoadOp
|
// spv.LoadOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -2008,17 +1857,6 @@ static LogicalResult verify(spirv::LoadOp loadOp) {
|
||||||
return verifyMemoryAccessAttribute(loadOp);
|
return verifyMemoryAccessAttribute(loadOp);
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
// spv.LogicalNot
|
|
||||||
//===----------------------------------------------------------------------===//
|
|
||||||
|
|
||||||
void spirv::LogicalNotOp::getCanonicalizationPatterns(
|
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
|
||||||
results.insert<ConvertLogicalNotOfIEqual, ConvertLogicalNotOfINotEqual,
|
|
||||||
ConvertLogicalNotOfLogicalEqual,
|
|
||||||
ConvertLogicalNotOfLogicalNotEqual>(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// spv.loop
|
// spv.loop
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -2547,170 +2385,6 @@ spirv::SelectionOp spirv::SelectionOp::createIfThen(
|
||||||
return selectionOp;
|
return selectionOp;
|
||||||
}
|
}
|
||||||
|
|
||||||
namespace {
|
|
||||||
// Blocks from the given `spv.selection` operation must satisfy the following
|
|
||||||
// layout:
|
|
||||||
//
|
|
||||||
// +-----------------------------------------------+
|
|
||||||
// | header block |
|
|
||||||
// | spv.BranchConditionalOp %cond, ^case0, ^case1 |
|
|
||||||
// +-----------------------------------------------+
|
|
||||||
// / \
|
|
||||||
// ...
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// +------------------------+ +------------------------+
|
|
||||||
// | case #0 | | case #1 |
|
|
||||||
// | spv.Store %ptr %value0 | | spv.Store %ptr %value1 |
|
|
||||||
// | spv.Branch ^merge | | spv.Branch ^merge |
|
|
||||||
// +------------------------+ +------------------------+
|
|
||||||
//
|
|
||||||
//
|
|
||||||
// ...
|
|
||||||
// \ /
|
|
||||||
// v
|
|
||||||
// +-------------+
|
|
||||||
// | merge block |
|
|
||||||
// +-------------+
|
|
||||||
//
|
|
||||||
struct ConvertSelectionOpToSelect
|
|
||||||
: public OpRewritePattern<spirv::SelectionOp> {
|
|
||||||
using OpRewritePattern<spirv::SelectionOp>::OpRewritePattern;
|
|
||||||
|
|
||||||
PatternMatchResult matchAndRewrite(spirv::SelectionOp selectionOp,
|
|
||||||
PatternRewriter &rewriter) const override {
|
|
||||||
auto *op = selectionOp.getOperation();
|
|
||||||
auto &body = op->getRegion(0);
|
|
||||||
// Verifier allows an empty region for `spv.selection`.
|
|
||||||
if (body.empty()) {
|
|
||||||
return matchFailure();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that region consists of 4 blocks:
|
|
||||||
// header block, `true` block, `false` block and merge block.
|
|
||||||
if (std::distance(body.begin(), body.end()) != 4) {
|
|
||||||
return matchFailure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto *headerBlock = selectionOp.getHeaderBlock();
|
|
||||||
if (!onlyContainsBranchConditionalOp(headerBlock)) {
|
|
||||||
return matchFailure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto brConditionalOp =
|
|
||||||
cast<spirv::BranchConditionalOp>(headerBlock->front());
|
|
||||||
|
|
||||||
auto *trueBlock = brConditionalOp.getSuccessor(0);
|
|
||||||
auto *falseBlock = brConditionalOp.getSuccessor(1);
|
|
||||||
auto *mergeBlock = selectionOp.getMergeBlock();
|
|
||||||
|
|
||||||
if (!canCanonicalizeSelection(trueBlock, falseBlock, mergeBlock)) {
|
|
||||||
return matchFailure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto trueValue = getSrcValue(trueBlock);
|
|
||||||
auto falseValue = getSrcValue(falseBlock);
|
|
||||||
auto ptrValue = getDstPtr(trueBlock);
|
|
||||||
auto storeOpAttributes =
|
|
||||||
cast<spirv::StoreOp>(trueBlock->front()).getOperation()->getAttrs();
|
|
||||||
|
|
||||||
auto selectOp = rewriter.create<spirv::SelectOp>(
|
|
||||||
selectionOp.getLoc(), trueValue.getType(), brConditionalOp.condition(),
|
|
||||||
trueValue, falseValue);
|
|
||||||
rewriter.create<spirv::StoreOp>(selectOp.getLoc(), ptrValue,
|
|
||||||
selectOp.getResult(), storeOpAttributes);
|
|
||||||
|
|
||||||
// `spv.selection` is not needed anymore.
|
|
||||||
rewriter.eraseOp(op);
|
|
||||||
return matchSuccess();
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
// Checks that given blocks follow the following rules:
|
|
||||||
// 1. Each conditional block consists of two operations, the first operation
|
|
||||||
// is a `spv.Store` and the last operation is a `spv.Branch`.
|
|
||||||
// 2. Each `spv.Store` uses the same pointer and the same memory attributes.
|
|
||||||
// 3. A control flow goes into the given merge block from the given
|
|
||||||
// conditional blocks.
|
|
||||||
PatternMatchResult canCanonicalizeSelection(Block *trueBlock,
|
|
||||||
Block *falseBlock,
|
|
||||||
Block *mergeBlock) const;
|
|
||||||
|
|
||||||
bool onlyContainsBranchConditionalOp(Block *block) const {
|
|
||||||
return std::next(block->begin()) == block->end() &&
|
|
||||||
isa<spirv::BranchConditionalOp>(block->front());
|
|
||||||
}
|
|
||||||
|
|
||||||
bool isSameAttrList(spirv::StoreOp lhs, spirv::StoreOp rhs) const {
|
|
||||||
return lhs.getOperation()->getAttrList().getDictionary() ==
|
|
||||||
rhs.getOperation()->getAttrList().getDictionary();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Checks that given type is valid for `spv.SelectOp`.
|
|
||||||
// According to SPIR-V spec:
|
|
||||||
// "Before version 1.4, Result Type must be a pointer, scalar, or vector.
|
|
||||||
// Starting with version 1.4, Result Type can additionally be a composite type
|
|
||||||
// other than a vector."
|
|
||||||
bool isValidType(Type type) const {
|
|
||||||
return spirv::SPIRVDialect::isValidScalarType(type) ||
|
|
||||||
type.isa<VectorType>();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a source value for the given block.
|
|
||||||
Value getSrcValue(Block *block) const {
|
|
||||||
auto storeOp = cast<spirv::StoreOp>(block->front());
|
|
||||||
return storeOp.value();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Returns a destination value for the given block.
|
|
||||||
Value getDstPtr(Block *block) const {
|
|
||||||
auto storeOp = cast<spirv::StoreOp>(block->front());
|
|
||||||
return storeOp.ptr();
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
PatternMatchResult ConvertSelectionOpToSelect::canCanonicalizeSelection(
|
|
||||||
Block *trueBlock, Block *falseBlock, Block *mergeBlock) const {
|
|
||||||
// Each block must consists of 2 operations.
|
|
||||||
if ((std::distance(trueBlock->begin(), trueBlock->end()) != 2) ||
|
|
||||||
(std::distance(falseBlock->begin(), falseBlock->end()) != 2)) {
|
|
||||||
return matchFailure();
|
|
||||||
}
|
|
||||||
|
|
||||||
auto trueBrStoreOp = dyn_cast<spirv::StoreOp>(trueBlock->front());
|
|
||||||
auto trueBrBranchOp =
|
|
||||||
dyn_cast<spirv::BranchOp>(*std::next(trueBlock->begin()));
|
|
||||||
auto falseBrStoreOp = dyn_cast<spirv::StoreOp>(falseBlock->front());
|
|
||||||
auto falseBrBranchOp =
|
|
||||||
dyn_cast<spirv::BranchOp>(*std::next(falseBlock->begin()));
|
|
||||||
|
|
||||||
if (!trueBrStoreOp || !trueBrBranchOp || !falseBrStoreOp ||
|
|
||||||
!falseBrBranchOp) {
|
|
||||||
return matchFailure();
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check that each `spv.Store` uses the same pointer, memory access
|
|
||||||
// attributes and a valid type of the value.
|
|
||||||
if ((trueBrStoreOp.ptr() != falseBrStoreOp.ptr()) ||
|
|
||||||
!isSameAttrList(trueBrStoreOp, falseBrStoreOp) ||
|
|
||||||
!isValidType(trueBrStoreOp.value().getType())) {
|
|
||||||
return matchFailure();
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((trueBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock) ||
|
|
||||||
(falseBrBranchOp.getOperation()->getSuccessor(0) != mergeBlock)) {
|
|
||||||
return matchFailure();
|
|
||||||
}
|
|
||||||
|
|
||||||
return matchSuccess();
|
|
||||||
}
|
|
||||||
} // end anonymous namespace
|
|
||||||
|
|
||||||
void spirv::SelectionOp::getCanonicalizationPatterns(
|
|
||||||
OwningRewritePatternList &results, MLIRContext *context) {
|
|
||||||
results.insert<ConvertSelectionOpToSelect>(context);
|
|
||||||
}
|
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// spv.specConstant
|
// spv.specConstant
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
Loading…
Reference in New Issue