forked from OSchip/llvm-project
[mlir][spirv] Use SingleBlock + NoTerminator for spv.module
This allows us to remove the `spv.mlir.endmodule` op and all the code associated with it. Along the way, tightened the APIs for `spv.module` a bit by removing some aliases. Now we use `getRegion` to get the only region, and `getBody` to get the region's only block. Reviewed By: mravishankar, hanchung Differential Revision: https://reviews.llvm.org/D103265
This commit is contained in:
parent
64b2fb7967
commit
56f60a1ce7
|
@ -92,8 +92,8 @@ The SPIR-V dialect adopts the following conventions for IR:
|
|||
(de)serialization.
|
||||
* Ops with `mlir.snake_case` names are those that have no corresponding
|
||||
instructions (or concepts) in the binary format. They are introduced to
|
||||
satisfy MLIR structural requirements. For example, `spv.mlir.endmodule` and
|
||||
`spv.mlir.merge`. They map to no instructions during (de)serialization.
|
||||
satisfy MLIR structural requirements. For example, `spv.mlir.merge`. They
|
||||
map to no instructions during (de)serialization.
|
||||
|
||||
(TODO: consider merging the last two cases and adopting `spv.mlir.` prefix for
|
||||
them.)
|
||||
|
|
|
@ -810,8 +810,6 @@ Module in SPIR-V has one region that contains one block. It is defined via
|
|||
`spv.module` is converted into `ModuleOp`. This plays a role of enclosing scope
|
||||
to LLVM ops. At the moment, SPIR-V module attributes are ignored.
|
||||
|
||||
`spv.mlir.endmodule` is mapped to an equivalent terminator `ModuleTerminatorOp`.
|
||||
|
||||
## `mlir-spirv-cpu-runner`
|
||||
|
||||
`mlir-spirv-cpu-runner` allows to execute `gpu` dialect kernel on the CPU via
|
||||
|
|
|
@ -407,9 +407,8 @@ def SPV_GlobalVariableOp : SPV_Op<"GlobalVariable", [InModuleScope, Symbol]> {
|
|||
// -----
|
||||
|
||||
def SPV_ModuleOp : SPV_Op<"module",
|
||||
[IsolatedFromAbove,
|
||||
SingleBlockImplicitTerminator<"ModuleEndOp">,
|
||||
SymbolTable, Symbol]> {
|
||||
[IsolatedFromAbove, NoRegionArguments, NoTerminator,
|
||||
SingleBlock, SymbolTable, Symbol]> {
|
||||
let summary = "The top-level op that defines a SPIR-V module";
|
||||
|
||||
let description = [{
|
||||
|
@ -426,7 +425,7 @@ def SPV_ModuleOp : SPV_Op<"module",
|
|||
implicitly capture values from the enclosing environment.
|
||||
|
||||
This op has only one region, which only contains one block. The block
|
||||
must be terminated via the `spv.mlir.endmodule` op.
|
||||
has no terminator.
|
||||
|
||||
<!-- End of AutoGen section -->
|
||||
|
||||
|
@ -463,7 +462,7 @@ def SPV_ModuleOp : SPV_Op<"module",
|
|||
|
||||
let results = (outs);
|
||||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
let regions = (region AnyRegion);
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins CArg<"Optional<StringRef>", "llvm::None">:$name)>,
|
||||
|
@ -487,40 +486,11 @@ def SPV_ModuleOp : SPV_Op<"module",
|
|||
Optional<StringRef> getName() { return sym_name(); }
|
||||
|
||||
static StringRef getVCETripleAttrName() { return "vce_triple"; }
|
||||
|
||||
Block& getBlock() {
|
||||
return this->getOperation()->getRegion(0).front();
|
||||
}
|
||||
}];
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_ModuleEndOp : SPV_Op<"mlir.endmodule", [InModuleScope, Terminator]> {
|
||||
let summary = "The pseudo op that ends a SPIR-V module";
|
||||
|
||||
let description = [{
|
||||
This op terminates the only block inside a `spv.module`'s only region.
|
||||
This op does not have a corresponding SPIR-V instruction and thus will
|
||||
not be serialized into the binary format; it is used solely to satisfy
|
||||
the structual requirement that an block must be ended with a terminator.
|
||||
}];
|
||||
|
||||
let arguments = (ins);
|
||||
|
||||
let results = (outs);
|
||||
|
||||
let assemblyFormat = "attr-dict";
|
||||
|
||||
let verifier = [{ return success(); }];
|
||||
|
||||
let hasOpcode = 0;
|
||||
|
||||
let autogenSerialization = 0;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
def SPV_ReferenceOfOp : SPV_Op<"mlir.referenceof", [NoSideEffect]> {
|
||||
let summary = "Reference a specialization constant.";
|
||||
|
||||
|
|
|
@ -1,14 +1,9 @@
|
|||
set(LLVM_TARGET_DEFINITIONS GPUToSPIRV.td)
|
||||
mlir_tablegen(GPUToSPIRV.cpp.inc -gen-rewriters)
|
||||
add_public_tablegen_target(MLIRGPUToSPIRVIncGen)
|
||||
|
||||
add_mlir_conversion_library(MLIRGPUToSPIRV
|
||||
GPUToSPIRV.cpp
|
||||
GPUToSPIRVPass.cpp
|
||||
|
||||
DEPENDS
|
||||
MLIRConversionPassIncGen
|
||||
MLIRGPUToSPIRVIncGen
|
||||
|
||||
LINK_LIBS PUBLIC
|
||||
MLIRGPU
|
||||
|
|
|
@ -85,6 +85,19 @@ public:
|
|||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
class GPUModuleEndConversion final
|
||||
: public OpConversionPattern<gpu::ModuleEndOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(gpu::ModuleEndOp endOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.eraseOp(endOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
/// Pattern to convert a gpu.return into a SPIR-V return.
|
||||
// TODO: This can go to DRR when GPU return has operands.
|
||||
class GPUReturnOpConversion final : public OpConversionPattern<gpu::ReturnOp> {
|
||||
|
@ -301,12 +314,10 @@ LogicalResult GPUModuleConversion::matchAndRewrite(
|
|||
StringRef(spvModuleName));
|
||||
|
||||
// Move the region from the module op into the SPIR-V module.
|
||||
Region &spvModuleRegion = spvModule.body();
|
||||
Region &spvModuleRegion = spvModule.getRegion();
|
||||
rewriter.inlineRegionBefore(moduleOp.body(), spvModuleRegion,
|
||||
spvModuleRegion.begin());
|
||||
// The spv.module build method adds a block with a terminator. Remove that
|
||||
// block. The terminator of the module op in the remaining block will be
|
||||
// legalized later.
|
||||
// The spv.module build method adds a block. Remove that.
|
||||
rewriter.eraseBlock(&spvModuleRegion.back());
|
||||
rewriter.eraseOp(moduleOp);
|
||||
return success();
|
||||
|
@ -330,15 +341,11 @@ LogicalResult GPUReturnOpConversion::matchAndRewrite(
|
|||
// GPU To SPIRV Patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
#include "GPUToSPIRV.cpp.inc"
|
||||
}
|
||||
|
||||
void mlir::populateGPUToSPIRVPatterns(SPIRVTypeConverter &typeConverter,
|
||||
RewritePatternSet &patterns) {
|
||||
populateWithGenerated(patterns);
|
||||
patterns.add<
|
||||
GPUFuncOpConversion, GPUModuleConversion, GPUReturnOpConversion,
|
||||
GPUFuncOpConversion, GPUModuleConversion, GPUModuleEndConversion,
|
||||
GPUReturnOpConversion,
|
||||
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
|
||||
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
|
||||
LaunchConfigConversion<gpu::ThreadIdOp,
|
||||
|
|
|
@ -1,22 +0,0 @@
|
|||
//===-- GPUToSPIRV.td - GPU to SPIR-V Dialect Lowerings ----*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file contains patterns to lower GPU dialect ops to to SPIR-V ops.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
|
||||
#ifndef MLIR_CONVERSION_GPU_TO_SPIRV
|
||||
#define MLIR_CONVERSION_GPU_TO_SPIRV
|
||||
|
||||
include "mlir/Dialect/GPU/GPUOps.td"
|
||||
include "mlir/Dialect/SPIRV/IR/SPIRVStructureOps.td"
|
||||
|
||||
def : Pat<(GPU_ModuleEndOp), (SPV_ModuleEndOp)>;
|
||||
|
||||
#endif // MLIR_CONVERSION_GPU_TO_SPIRV
|
|
@ -1342,7 +1342,7 @@ public:
|
|||
|
||||
auto newModuleOp =
|
||||
rewriter.create<ModuleOp>(spvModuleOp.getLoc(), spvModuleOp.getName());
|
||||
rewriter.inlineRegionBefore(spvModuleOp.body(), newModuleOp.getBody());
|
||||
rewriter.inlineRegionBefore(spvModuleOp.getRegion(), newModuleOp.getBody());
|
||||
|
||||
// Remove the terminator block that was automatically added by builder
|
||||
rewriter.eraseBlock(&newModuleOp.getBodyRegion().back());
|
||||
|
@ -1351,20 +1351,6 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
class ModuleEndConversionPattern
|
||||
: public SPIRVToLLVMConversion<spirv::ModuleEndOp> {
|
||||
public:
|
||||
using SPIRVToLLVMConversion<spirv::ModuleEndOp>::SPIRVToLLVMConversion;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(spirv::ModuleEndOp moduleEndOp, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
|
||||
rewriter.eraseOp(moduleEndOp);
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1507,8 +1493,7 @@ void mlir::populateSPIRVToLLVMFunctionConversionPatterns(
|
|||
|
||||
void mlir::populateSPIRVToLLVMModuleConversionPatterns(
|
||||
LLVMTypeConverter &typeConverter, RewritePatternSet &patterns) {
|
||||
patterns.add<ModuleConversionPattern, ModuleEndConversionPattern>(
|
||||
patterns.getContext(), typeConverter);
|
||||
patterns.add<ModuleConversionPattern>(patterns.getContext(), typeConverter);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -2529,7 +2529,8 @@ static LogicalResult verify(spirv::MergeOp mergeOp) {
|
|||
|
||||
void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
|
||||
Optional<StringRef> name) {
|
||||
ensureTerminator(*state.addRegion(), builder, state.location);
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
builder.createBlock(state.addRegion());
|
||||
if (name) {
|
||||
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
|
||||
builder.getStringAttr(*name));
|
||||
|
@ -2545,7 +2546,8 @@ void spirv::ModuleOp::build(OpBuilder &builder, OperationState &state,
|
|||
builder.getI32IntegerAttr(static_cast<int32_t>(addressingModel)));
|
||||
state.addAttribute("memory_model", builder.getI32IntegerAttr(
|
||||
static_cast<int32_t>(memoryModel)));
|
||||
ensureTerminator(*state.addRegion(), builder, state.location);
|
||||
OpBuilder::InsertionGuard guard(builder);
|
||||
builder.createBlock(state.addRegion());
|
||||
if (name) {
|
||||
state.attributes.append(mlir::SymbolTable::getSymbolAttrName(),
|
||||
builder.getStringAttr(*name));
|
||||
|
@ -2581,7 +2583,10 @@ static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
|
|||
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
|
||||
return failure();
|
||||
|
||||
spirv::ModuleOp::ensureTerminator(*body, parser.getBuilder(), state.location);
|
||||
// Make sure we have at least one block.
|
||||
if (body->empty())
|
||||
body->push_back(new Block());
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -2608,8 +2613,7 @@ static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) {
|
|||
}
|
||||
|
||||
printer.printOptionalAttrDictWithKeyword(moduleOp->getAttrs(), elidedAttrs);
|
||||
printer.printRegion(moduleOp.body(), /*printEntryBlockArgs=*/false,
|
||||
/*printBlockTerminators=*/false);
|
||||
printer.printRegion(moduleOp.getRegion());
|
||||
}
|
||||
|
||||
static LogicalResult verify(spirv::ModuleOp moduleOp) {
|
||||
|
@ -2619,7 +2623,7 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) {
|
|||
entryPoints;
|
||||
SymbolTable table(moduleOp);
|
||||
|
||||
for (auto &op : moduleOp.getBlock()) {
|
||||
for (auto &op : *moduleOp.getBody()) {
|
||||
if (op.getDialect() != dialect)
|
||||
return op.emitError("'spv.module' can only contain spv.* ops");
|
||||
|
||||
|
|
|
@ -134,7 +134,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
|
|||
|
||||
auto combinedModule = combinedModuleBuilder.create<spirv::ModuleOp>(
|
||||
modules[0].getLoc(), addressingModel, memoryModel);
|
||||
combinedModuleBuilder.setInsertionPointToStart(&*combinedModule.getBody());
|
||||
combinedModuleBuilder.setInsertionPointToStart(combinedModule.getBody());
|
||||
|
||||
// In some cases, a symbol in the (current state of the) combined module is
|
||||
// renamed in order to maintain the conflicting symbol in the input module
|
||||
|
@ -160,7 +160,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
|
|||
// for spv.funcs. This way, if the conflicting op in the input module is
|
||||
// non-spv.func, we rename that symbol instead and maintain the spv.func in
|
||||
// the combined module name as it is.
|
||||
for (auto &op : combinedModule.getBlock().without_terminator()) {
|
||||
for (auto &op : *combinedModule.getBody()) {
|
||||
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
|
||||
StringRef oldSymName = symbolOp.getName();
|
||||
|
||||
|
@ -195,7 +195,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
|
|||
|
||||
// In the current input module, rename all symbols that conflict with
|
||||
// symbols from the combined module. This includes renaming spv.funcs.
|
||||
for (auto &op : moduleClone.getBlock().without_terminator()) {
|
||||
for (auto &op : *moduleClone.getBody()) {
|
||||
if (auto symbolOp = dyn_cast<SymbolOpInterface>(op)) {
|
||||
StringRef oldSymName = symbolOp.getName();
|
||||
|
||||
|
@ -225,7 +225,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
|
|||
}
|
||||
|
||||
// Clone all the module's ops to the combined module.
|
||||
for (auto &op : moduleClone.getBlock().without_terminator())
|
||||
for (auto &op : *moduleClone.getBody())
|
||||
combinedModuleBuilder.insert(op.clone());
|
||||
}
|
||||
|
||||
|
@ -233,7 +233,7 @@ combine(llvm::MutableArrayRef<spirv::ModuleOp> modules,
|
|||
DenseMap<llvm::hash_code, SymbolOpInterface> hashToSymbolOp;
|
||||
SmallVector<SymbolOpInterface, 0> eraseList;
|
||||
|
||||
for (auto &op : combinedModule.getBlock().without_terminator()) {
|
||||
for (auto &op : *combinedModule.getBody()) {
|
||||
llvm::hash_code hashCode(0);
|
||||
SymbolOpInterface symbolOp = dyn_cast<SymbolOpInterface>(op);
|
||||
|
||||
|
|
|
@ -115,7 +115,7 @@ static LogicalResult lowerEntryPointABIAttr(spirv::FuncOp funcOp,
|
|||
|
||||
OpBuilder::InsertionGuard moduleInsertionGuard(builder);
|
||||
auto spirvModule = funcOp->getParentOfType<spirv::ModuleOp>();
|
||||
builder.setInsertionPoint(spirvModule.body().front().getTerminator());
|
||||
builder.setInsertionPointToEnd(spirvModule.getBody());
|
||||
|
||||
// Adds the spv.EntryPointOp after collecting all the interface variables
|
||||
// needed.
|
||||
|
|
|
@ -51,7 +51,7 @@ static inline bool isFnEntryBlock(Block *block) {
|
|||
spirv::Deserializer::Deserializer(ArrayRef<uint32_t> binary,
|
||||
MLIRContext *context)
|
||||
: binary(binary), context(context), unknownLoc(UnknownLoc::get(context)),
|
||||
module(createModuleOp()), opBuilder(module->body()) {}
|
||||
module(createModuleOp()), opBuilder(module->getRegion()) {}
|
||||
|
||||
LogicalResult spirv::Deserializer::deserialize() {
|
||||
LLVM_DEBUG(llvm::dbgs() << "+++ starting deserialization +++\n");
|
||||
|
|
|
@ -99,7 +99,7 @@ LogicalResult Serializer::serialize() {
|
|||
|
||||
// Iterate over the module body to serialize it. Assumptions are that there is
|
||||
// only one basic block in the moduleOp
|
||||
for (auto &op : module.getBlock()) {
|
||||
for (auto &op : *module.getBody()) {
|
||||
if (failed(processOperation(&op))) {
|
||||
return failure();
|
||||
}
|
||||
|
@ -1090,7 +1090,6 @@ LogicalResult Serializer::processOperation(Operation *opInst) {
|
|||
return processGlobalVariableOp(op);
|
||||
})
|
||||
.Case([&](spirv::LoopOp op) { return processLoopOp(op); })
|
||||
.Case([&](spirv::ModuleEndOp) { return success(); })
|
||||
.Case([&](spirv::ReferenceOfOp op) { return processReferenceOfOp(op); })
|
||||
.Case([&](spirv::SelectionOp op) { return processSelectionOp(op); })
|
||||
.Case([&](spirv::SpecConstantOp op) { return processSpecConstantOp(op); })
|
||||
|
|
|
@ -13,12 +13,6 @@ spv.module @foo Logical GLSL450 {}
|
|||
// CHECK: module
|
||||
spv.module Logical GLSL450 requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]> {}
|
||||
|
||||
// CHECK: module
|
||||
spv.module Logical GLSL450 {
|
||||
// CHECK: }
|
||||
spv.mlir.endmodule
|
||||
}
|
||||
|
||||
// CHECK: module
|
||||
spv.module Logical GLSL450 {
|
||||
// CHECK-LABEL: llvm.func @empty()
|
||||
|
|
|
@ -425,12 +425,6 @@ spv.module Logical GLSL450
|
|||
requires #spv.vce<v1.0, [Shader], [SPV_KHR_16bit_storage]>
|
||||
attributes {foo = "bar"} { }
|
||||
|
||||
// Module with explicit spv.mlir.endmodule
|
||||
// CHECK: spv.module
|
||||
spv.module Logical GLSL450 {
|
||||
spv.mlir.endmodule
|
||||
}
|
||||
|
||||
// Module with function
|
||||
// CHECK: spv.module
|
||||
spv.module Logical GLSL450 {
|
||||
|
@ -476,15 +470,6 @@ spv.module Logical GLSL450 {
|
|||
|
||||
// -----
|
||||
|
||||
// Module with wrong terminator
|
||||
// expected-error@+2 {{expects regions to end with 'spv.mlir.endmodule'}}
|
||||
// expected-note@+1 {{in custom textual format, the absence of terminator implies 'spv.mlir.endmodule'}}
|
||||
"spv.module"() ({
|
||||
%0 = spv.Constant true
|
||||
}) {addressing_model = 0 : i32, memory_model = 1 : i32} : () -> ()
|
||||
|
||||
// -----
|
||||
|
||||
// Use non SPIR-V op inside module
|
||||
spv.module Logical GLSL450 {
|
||||
// expected-error @+1 {{'spv.module' can only contain spv.* ops}}
|
||||
|
@ -511,17 +496,6 @@ spv.module Logical GLSL450 {
|
|||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.mlir.endmodule
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @module_end_not_in_module() -> () {
|
||||
// expected-error @+1 {{op must appear in a module-like op's block}}
|
||||
spv.mlir.endmodule
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// spv.mlir.referenceof
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -59,7 +59,7 @@ protected:
|
|||
}
|
||||
|
||||
Type getFloatStructType() {
|
||||
OpBuilder opBuilder(module->body());
|
||||
OpBuilder opBuilder(module->getRegion());
|
||||
llvm::SmallVector<Type, 1> elementTypes{opBuilder.getF32Type()};
|
||||
llvm::SmallVector<spirv::StructType::OffsetInfo, 1> offsetInfo{0};
|
||||
auto structType = spirv::StructType::get(elementTypes, offsetInfo);
|
||||
|
@ -67,7 +67,7 @@ protected:
|
|||
}
|
||||
|
||||
void addGlobalVar(Type type, llvm::StringRef name) {
|
||||
OpBuilder opBuilder(module->body());
|
||||
OpBuilder opBuilder(module->getRegion());
|
||||
auto ptrType = spirv::PointerType::get(type, spirv::StorageClass::Uniform);
|
||||
opBuilder.create<spirv::GlobalVariableOp>(
|
||||
UnknownLoc::get(&context), TypeAttr::get(ptrType),
|
||||
|
|
Loading…
Reference in New Issue