[mlir][spirv] Use SingleBlock + NoTerminator for spv.module

This allows us to remove the `spv.mlir.endmodule` op and
all the code associated with it.

Along the way, tightened the APIs for `spv.module` a bit
by removing some aliases. Now we use `getRegion` to get
the only region, and `getBody` to get the region's only
block.

Reviewed By: mravishankar, hanchung

Differential Revision: https://reviews.llvm.org/D103265
This commit is contained in:
Lei Zhang 2021-06-09 13:58:13 -04:00
parent 64b2fb7967
commit 56f60a1ce7
15 changed files with 45 additions and 141 deletions

View File

@ -92,8 +92,8 @@ The SPIR-V dialect adopts the following conventions for IR:
(de)serialization.
* Ops with `mlir.snake_case` names are those that have no corresponding
instructions (or concepts) in the binary format. They are introduced to
satisfy MLIR structural requirements. For example, `spv.mlir.endmodule` and
`spv.mlir.merge`. They map to no instructions during (de)serialization.
satisfy MLIR structural requirements. For example, `spv.mlir.merge`. They
map to no instructions during (de)serialization.
(TODO: consider merging the last two cases and adopting `spv.mlir.` prefix for
them.)

View File

@ -810,8 +810,6 @@ Module in SPIR-V has one region that contains one block. It is defined via
`spv.module` is converted into `ModuleOp`. This plays a role of enclosing scope
to LLVM ops. At the moment, SPIR-V module attributes are ignored.
`spv.mlir.endmodule` is mapped to an equivalent terminator `ModuleTerminatorOp`.
## `mlir-spirv-cpu-runner`
`mlir-spirv-cpu-runner` allows to execute `gpu` dialect kernel on the CPU via

View File

@ -407,9 +407,8 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
// -----
def SPV_ModuleOp : SPV_Op<"module",
[IsolatedFromAbove,
SingleBlockImplicitTerminator<"ModuleEndOp">,
SymbolTable, Symbol]> {
[IsolatedFromAbove, NoRegionArguments, NoTerminator,
SingleBlock, SymbolTable, Symbol]> {
let summary = "The top-level op that defines a SPIR-V module";
let description = [{
@ -426,7 +425,7 @@ def SPV_ModuleOp : SPV_Op<"module",
implicitly capture values from the enclosing environment.
This op has only one region, which only contains one block. The block
must be terminated via the `spv.mlir.endmodule` op.
has no terminator.
<!-- End of AutoGen section -->
@ -463,7 +462,7 @@ def SPV_ModuleOp : SPV_Op<"module",
let results = (outs);
let regions = (region SizedRegion<1>:$body);
let regions = (region AnyRegion);
let builders = [
OpBuilder<(ins CArg<"Optional<StringRef>", "llvm::None">:$name)>,
@ -487,40 +486,11 @@ def SPV_ModuleOp : SPV_Op<"module",
Optional<StringRef> getName() { return sym_name(); }
static StringRef getVCETripleAttrName() { return "vce_triple"; }
Block& getBlock() {
return this->getOperation()->getRegion(0).front();
}
}];
}
// -----
def SPV_ModuleEndOp : SPV_Op<"mlir.endmodule", [InModuleScope, Terminator]> {
let summary = "The pseudo op that ends a SPIR-V module";
let description = [{
This op terminates the only block inside a `spv.module`'s only region.
This op does not have a corresponding SPIR-V instruction and thus will
not be serialized into the binary format; it is used solely to satisfy
the structual requirement that an block must be ended with a terminator.
}];
let arguments = (ins);
let results = (outs);
let assemblyFormat = "attr-dict";
let verifier = [{ return success(); }];
let hasOpcode = 0;
let autogenSerialization = 0;
}
// -----
def SPV_ReferenceOfOp : SPV_Op<"mlir.referenceof", [NoSideEffect]> {
let summary = "Reference a specialization constant.";

View File

@ -1,14 +1,9 @@
set(LLVM_TARGET_DEFINITIONS GPUToSPIRV.td)
mlir_tablegen(GPUToSPIRV.cpp.inc -gen-rewriters)
add_public_tablegen_target(MLIRGPUToSPIRVIncGen)
add_mlir_conversion_library(MLIRGPUToSPIRV
GPUToSPIRV.cpp
GPUToSPIRVPass.cpp
DEPENDS
MLIRConversionPassIncGen
MLIRGPUToSPIRVIncGen
LINK_LIBS PUBLIC
MLIRGPU

View File

@ -85,6 +85,19 @@ public:
ConversionPatternRewriter &rewriter) const override;
};
class GPUModuleEndConversion final
: public OpConversionPattern<gpu::ModuleEndOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(gpu::ModuleEndOp endOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(endOp);
return success();
}
};
/// Pattern to convert a gpu.return into a SPIR-V return.
// TODO: This can go to DRR when GPU return has operands.
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
@ -301,12 +314,10 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
StringRef(spvModuleName));
// Move the region from the module op into the SPIR-V module.
Region &spvModuleRegion = spvModule.body();
Region &spvModuleRegion = spvModule.getRegion();
rewriter.inlineRegionBefore(moduleOp.body(), spvModuleRegion,
spvModuleRegion.begin());
// The spv.module build method adds a block with a terminator. Remove that
// block. The terminator of the module op in the remaining block will be
// legalized later.
// The spv.module build method adds a block. Remove that.
rewriter.eraseBlock(&spvModuleRegion.back());
rewriter.eraseOp(moduleOp);
return success();
@ -330,15 +341,11 @@ LogicalResult GPUReturnOpConversion::matchAndRewrite(
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//
namespace {
#include "GPUToSPIRV.cpp.inc"
}
void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
RewritePatternSet &patterns) {
populateWithGenerated(patterns);
patterns.add<
GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
GPUFuncOpConversion, GPUModuleConversion, GPUModuleEndConversion,
GPUReturnOpConversion,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
LaunchConfigConversion<gpu::ThreadIdOp,

View File

@ -1,22 +0,0 @@
//===-- GPUToSPIRV.td - GPU to SPIR-V Dialect Lowerings ----*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file contains patterns to lower GPU dialect ops to to SPIR-V ops.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_CONVERSION_GPU_TO_SPIRV
#define MLIR_CONVERSION_GPU_TO_SPIRV
include "mlir/Dialect/GPU/GPUOps.td"
include "mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td"
def : Pat<(GPU_ModuleEndOp), (SPV_ModuleEndOp)>;
#endif // MLIR_CONVERSION_GPU_TO_SPIRV

View File

@ -1342,7 +1342,7 @@ public:
auto newModuleOp =
rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());
rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
// Remove the terminator block that was automatically added by builder
rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
@ -1351,20 +1351,6 @@ public:
}
};
class ModuleEndConversionPattern
: public SPIRVToLLVMConversion<spirv::ModuleEndOp> {
public:
using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion;
LogicalResult
matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.eraseOp(moduleEndOp);
return success();
}
};
} // namespace
//===----------------------------------------------------------------------===//
@ -1507,8 +1493,7 @@ void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
void mlir::populateSPIRVToLLVMModuleConversionPatterns(
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
patterns.add<ModuleConversionPattern, ModuleEndConversionPattern>(
patterns.getContext(), typeConverter);
patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
}
//===----------------------------------------------------------------------===//

View File

@ -2529,7 +2529,8 @@ static LogicalResult verify(spirv::MergeOp mergeOp) {
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
Optional<StringRef> name) {
ensureTerminator(*state.addRegion(), builder, state.location);
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(state.addRegion());
if (name) {
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(*name));
@ -2545,7 +2546,8 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
state.addAttribute("memory_model", builder.getI32IntegerAttr(
static_cast<int32_t>(memoryModel)));
ensureTerminator(*state.addRegion(), builder, state.location);
OpBuilder::InsertionGuard guard(builder);
builder.createBlock(state.addRegion());
if (name) {
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
builder.getStringAttr(*name));
@ -2581,7 +2583,10 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
return failure();
spirv::ModuleOp::ensureTerminator(*body, parser.getBuilder(), state.location);
// Make sure we have at least one block.
if (body->empty())
body->push_back(new Block());
return success();
}
@ -2608,8 +2613,7 @@ static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
}
printer.printOptionalAttrDictWithKeyword(moduleOp->getAttrs(), elidedAttrs);
printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
/*printBlockTerminators=*/false);
printer.printRegion(moduleOp.getRegion());
}
static LogicalResult verify(spirv::ModuleOp moduleOp) {
@ -2619,7 +2623,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
entryPoints;
SymbolTable table(moduleOp);
for (auto &op : moduleOp.getBlock()) {
for (auto &op : *moduleOp.getBody()) {
if (op.getDialect() != dialect)
return op.emitError("'spv.module' can only contain spv.* ops");

View File

@ -134,7 +134,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
modules[0].getLoc(), addressingModel, memoryModel);
combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
// In some cases, a symbol in the (current state of the) combined module is
// renamed in order to maintain the conflicting symbol in the input module
@ -160,7 +160,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
// for spv.funcs. This way, if the conflicting op in the input module is
// non-spv.func, we rename that symbol instead and maintain the spv.func in
// the combined module name as it is.
for (auto &op : combinedModule.getBlock().without_terminator()) {
for (auto &op : *combinedModule.getBody()) {
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
StringRef oldSymName = symbolOp.getName();
@ -195,7 +195,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
// In the current input module, rename all symbols that conflict with
// symbols from the combined module. This includes renaming spv.funcs.
for (auto &op : moduleClone.getBlock().without_terminator()) {
for (auto &op : *moduleClone.getBody()) {
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
StringRef oldSymName = symbolOp.getName();
@ -225,7 +225,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
}
// Clone all the module's ops to the combined module.
for (auto &op : moduleClone.getBlock().without_terminator())
for (auto &op : *moduleClone.getBody())
combinedModuleBuilder.insert(op.clone());
}
@ -233,7 +233,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
SmallVector<SymbolOpInterface, 0> eraseList;
for (auto &op : combinedModule.getBlock().without_terminator()) {
for (auto &op : *combinedModule.getBody()) {
llvm::hash_code hashCode(0);
SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);

View File

@ -115,7 +115,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
OpBuilder::InsertionGuard moduleInsertionGuard(builder);
auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
builder.setInsertionPoint(spirvModule.body().front().getTerminator());
builder.setInsertionPointToEnd(spirvModule.getBody());
// Adds the spv.EntryPointOp after collecting all the interface variables
// needed.

View File

@ -51,7 +51,7 @@ static inline bool isFnEntryBlock(Block *block) {
spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
MLIRContext *context)
: binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
module(createModuleOp()), opBuilder(module->body()) {}
module(createModuleOp()), opBuilder(module->getRegion()) {}
LogicalResult spirv::Deserializer::deserialize() {
LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n");

View File

@ -99,7 +99,7 @@ LogicalResult Serializer::serialize() {
// Iterate over the module body to serialize it. Assumptions are that there is
// only one basic block in the moduleOp
for (auto &op : module.getBlock()) {
for (auto &op : *module.getBody()) {
if (failed(processOperation(&op))) {
return failure();
}
@ -1090,7 +1090,6 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
return processGlobalVariableOp(op);
})
.Case([&](spirv::LoopOp op) { return processLoopOp(op); })
.Case([&](spirv::ModuleEndOp) { return success(); })
.Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
.Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
.Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })

View File

@ -13,12 +13,6 @@ spv.module @foo Logical GLSL450 {}
// CHECK: module
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]> {}
// CHECK: module
spv.module Logical GLSL450 {
// CHECK: }
spv.mlir.endmodule
}
// CHECK: module
spv.module Logical GLSL450 {
// CHECK-LABEL: llvm.func @empty()

View File

@ -425,12 +425,6 @@ spv.module Logical GLSL450
requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]>
attributes {foo = "bar"} { }
// Module with explicit spv.mlir.endmodule
// CHECK: spv.module
spv.module Logical GLSL450 {
spv.mlir.endmodule
}
// Module with function
// CHECK: spv.module
spv.module Logical GLSL450 {
@ -476,15 +470,6 @@ spv.module Logical GLSL450 {
// -----
// Module with wrong terminator
// expected-error@+2 {{expects regions to end with 'spv.mlir.endmodule'}}
// expected-note@+1 {{in custom textual format, the absence of terminator implies 'spv.mlir.endmodule'}}
"spv.module"() ({
%0 = spv.Constant true
}) {addressing_model = 0 : i32, memory_model = 1 : i32} : () -> ()
// -----
// Use non SPIR-V op inside module
spv.module Logical GLSL450 {
// expected-error @+1 {{'spv.module' can only contain spv.* ops}}
@ -511,17 +496,6 @@ spv.module Logical GLSL450 {
// -----
//===----------------------------------------------------------------------===//
// spv.mlir.endmodule
//===----------------------------------------------------------------------===//
func @module_end_not_in_module() -> () {
// expected-error @+1 {{op must appear in a module-like op's block}}
spv.mlir.endmodule
}
// -----
//===----------------------------------------------------------------------===//
// spv.mlir.referenceof
//===----------------------------------------------------------------------===//

View File

@ -59,7 +59,7 @@ protected:
}
Type getFloatStructType() {
OpBuilder opBuilder(module->body());
OpBuilder opBuilder(module->getRegion());
llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
auto structType = spirv::StructType::get(elementTypes, offsetInfo);
@ -67,7 +67,7 @@ protected:
}
void addGlobalVar(Type type, llvm::StringRef name) {
OpBuilder opBuilder(module->body());
OpBuilder opBuilder(module->getRegion());
auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
opBuilder.create<spirv::GlobalVariableOp>(
UnknownLoc::get(&context), TypeAttr::get(ptrType),