forked from OSchip/llvm-project
Add lowering for module with gpu.kernel_module attribute.
The existing GPU to SPIR-V lowering created a spv.module for every function with the gpu.kernel attribute. A better approach is to lower the module that the function lives in (which has the gpu.kernel_module attribute) to a spv.module operation. This better captures the host-device separation modeled by the GPU dialect and simplifies the lowering as well. PiperOrigin-RevId: 284574688
This commit is contained in:
parent
312ccb1c0f
commit
4a62019eb8
|
@ -329,13 +329,17 @@ def SPV_ModuleOp : SPV_Op<"module",
|
|||
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let builders = [OpBuilder<"Builder *, OperationState &state">,
|
||||
OpBuilder<[{Builder *, OperationState &state,
|
||||
IntegerAttr addressing_model,
|
||||
IntegerAttr memory_model,
|
||||
/*optional*/ArrayAttr capabilities = nullptr,
|
||||
/*optional*/ArrayAttr extensions = nullptr,
|
||||
/*optional*/ArrayAttr extended_instruction_sets = nullptr}]>];
|
||||
let builders =
|
||||
[OpBuilder<"Builder *, OperationState &state">,
|
||||
OpBuilder<[{Builder *, OperationState &state,
|
||||
IntegerAttr addressing_model,
|
||||
IntegerAttr memory_model}]>,
|
||||
OpBuilder<[{Builder *, OperationState &state,
|
||||
spirv::AddressingModel addressing_model,
|
||||
spirv::MemoryModel memory_model,
|
||||
/*optional*/ ArrayRef<spirv::Capability> capabilities = {},
|
||||
/*optional*/ ArrayRef<spirv::Extension> extensions = {},
|
||||
/*optional*/ ArrayAttr extended_instruction_sets = nullptr}]>];
|
||||
|
||||
// We need to ensure the block inside the region is properly terminated;
|
||||
// the auto-generated builders do not guarantee that.
|
||||
|
|
|
@ -23,6 +23,7 @@
|
|||
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
|
||||
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
|
||||
#include "mlir/IR/Module.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
|
@ -71,8 +72,36 @@ private:
|
|||
SmallVector<int32_t, 3> workGroupSizeAsInt32;
|
||||
};
|
||||
|
||||
/// Pattern to convert a module with gpu.kernel_module attribute to a
|
||||
/// spv.module.
///
/// The pattern is anchored on `ModuleOp`; the out-of-line `matchAndRewrite`
/// fails to match for modules that do not carry the gpu.kernel_module
/// attribute, so other modules are left untouched by this pattern.
|
||||
class KernelModuleConversion final : public SPIRVOpLowering<ModuleOp> {
|
||||
public:
|
||||
// Reuse the constructors of the SPIRVOpLowering base class.
using SPIRVOpLowering<ModuleOp>::SPIRVOpLowering;
|
||||
|
||||
/// Creates a spv.module, moves the body region of `moduleOp` into it and
/// erases `moduleOp`. Returns a match failure when `moduleOp` lacks the
/// gpu.kernel_module attribute.
PatternMatchResult
|
||||
matchAndRewrite(ModuleOp moduleOp, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
/// Pattern to convert a module terminator op to a terminator of spv.module op.
///
/// The out-of-line `matchAndRewrite` replaces every matched
/// `ModuleTerminatorOp` with a `spirv::ModuleEndOp`; it performs no check of
/// the enclosing module.
|
||||
// TODO: Move this into DRR, but that requires ModuleTerminatorOp to be defined
|
||||
// in ODS.
|
||||
class KernelModuleTerminatorConversion final
|
||||
: public SPIRVOpLowering<ModuleTerminatorOp> {
|
||||
public:
|
||||
// Reuse the constructors of the SPIRVOpLowering base class.
using SPIRVOpLowering<ModuleTerminatorOp>::SPIRVOpLowering;
|
||||
|
||||
/// Replaces `terminatorOp` with a `spirv::ModuleEndOp` and reports a
/// successful match. `operands` is not used by the implementation.
PatternMatchResult
|
||||
matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const override;
|
||||
};
|
||||
|
||||
} // namespace
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// loop::ForOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
@ -142,6 +171,10 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
|
|||
return matchSuccess();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Builtins.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
template <typename SourceOp, spirv::BuiltIn builtin>
|
||||
PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
|
||||
SourceOp op, ArrayRef<Value *> operands,
|
||||
|
@ -170,6 +203,10 @@ PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
|
|||
return this->matchSuccess();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FuncOp with gpu.kernel attribute.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
PatternMatchResult
|
||||
KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
|
@ -196,6 +233,51 @@ KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
|
|||
return matchSuccess();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModuleOp with gpu.kernel_module.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Lowers a module tagged with the gpu.kernel_module attribute to a
/// spv.module. Modules without the attribute are rejected with a match
/// failure. `operands` is unused.
PatternMatchResult KernelModuleConversion::matchAndRewrite(
    ModuleOp moduleOp, ArrayRef<Value *> operands,
    ConversionPatternRewriter &rewriter) const {
  // Guard: this pattern only applies to GPU kernel modules.
  auto kernelModuleAttr = moduleOp.getAttrOfType<UnitAttr>(
      gpu::GPUDialect::getKernelModuleAttrName());
  if (!kernelModuleAttr)
    return matchFailure();

  // TODO : Generalize this to account for different extensions,
  // capabilities, extended_instruction_sets, other addressing models
  // and memory models.
  auto loc = moduleOp.getLoc();
  auto spvModule = rewriter.create<spirv::ModuleOp>(
      loc, spirv::AddressingModel::Logical, spirv::MemoryModel::GLSL450,
      spirv::Capability::Shader,
      spirv::Extension::SPV_KHR_storage_buffer_storage_class);

  // Transfer the blocks of the source module to the front of the SPIR-V
  // module's region.
  Region &spvBody = spvModule.getOperation()->getRegion(0);
  rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvBody,
                              spvBody.begin());

  // The spv.module builder already produced a block holding a terminator;
  // it now sits at the back of the region and is dropped here. The
  // terminator that came along with the source module's block is left for a
  // later legalization step.
  spvBody.back().erase();

  rewriter.eraseOp(moduleOp);
  return matchSuccess();
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ModuleTerminatorOp for gpu.kernel_module.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Replaces the matched module terminator with a `spirv::ModuleEndOp`.
/// Always reports a successful match; `operands` is unused.
PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite(
|
||||
ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
|
||||
ConversionPatternRewriter &rewriter) const {
|
||||
// Swap the terminator in place for the SPIR-V module terminator op.
rewriter.replaceOpWithNewOp<spirv::ModuleEndOp>(terminatorOp);
|
||||
return matchSuccess();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GPU To SPIRV Patterns.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
void populateGPUToSPIRVPatterns(MLIRContext *context,
|
||||
SPIRVTypeConverter &typeConverter,
|
||||
|
@ -203,7 +285,7 @@ void populateGPUToSPIRVPatterns(MLIRContext *context,
|
|||
ArrayRef<int64_t> workGroupSize) {
|
||||
patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize);
|
||||
patterns.insert<
|
||||
ForOpConversion,
|
||||
ForOpConversion, KernelModuleConversion, KernelModuleTerminatorConversion,
|
||||
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
|
||||
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
|
||||
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,
|
||||
|
|
|
@ -67,34 +67,19 @@ void GPUToSPIRVPass::runOnModule() {
|
|||
auto context = &getContext();
|
||||
auto module = getModule();
|
||||
|
||||
SmallVector<Operation *, 4> spirvModules;
|
||||
module.walk([&module, &spirvModules](FuncOp funcOp) {
|
||||
if (!gpu::GPUDialect::isKernel(funcOp)) {
|
||||
return;
|
||||
SmallVector<Operation *, 1> kernelModules;
|
||||
OpBuilder builder(context);
|
||||
module.walk([&builder, &kernelModules](ModuleOp moduleOp) {
|
||||
if (moduleOp.getAttrOfType<UnitAttr>(
|
||||
gpu::GPUDialect::getKernelModuleAttrName())) {
|
||||
// For each kernel module (should be only 1 for now, but that is not a
|
||||
// requirement here), clone the module for conversion because the
|
||||
// gpu.launch function still needs the kernel module.
|
||||
builder.setInsertionPoint(moduleOp.getOperation());
|
||||
kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
|
||||
}
|
||||
OpBuilder builder(funcOp.getOperation());
|
||||
// Create a new spirv::ModuleOp for this function, and clone the
|
||||
// function into it.
|
||||
// TODO : Generalize this to account for different extensions,
|
||||
// capabilities, extended_instruction_sets, other addressing models
|
||||
// and memory models.
|
||||
auto spvModule = builder.create<spirv::ModuleOp>(
|
||||
funcOp.getLoc(),
|
||||
builder.getI32IntegerAttr(
|
||||
static_cast<int32_t>(spirv::AddressingModel::Logical)),
|
||||
builder.getI32IntegerAttr(
|
||||
static_cast<int32_t>(spirv::MemoryModel::GLSL450)),
|
||||
builder.getStrArrayAttr(
|
||||
spirv::stringifyCapability(spirv::Capability::Shader)),
|
||||
builder.getStrArrayAttr(spirv::stringifyExtension(
|
||||
spirv::Extension::SPV_KHR_storage_buffer_storage_class)));
|
||||
// Hardwire the capability to be Shader.
|
||||
OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
|
||||
moduleBuilder.clone(*funcOp.getOperation());
|
||||
spirvModules.push_back(spvModule);
|
||||
});
|
||||
|
||||
/// Dialect conversion to lower the functions with the spirv::ModuleOps.
|
||||
SPIRVTypeConverter typeConverter;
|
||||
OwningRewritePatternList patterns;
|
||||
populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
|
||||
|
@ -105,7 +90,7 @@ void GPUToSPIRVPass::runOnModule() {
|
|||
target.addDynamicallyLegalOp<FuncOp>(
|
||||
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
|
||||
|
||||
if (failed(applyFullConversion(spirvModules, target, patterns,
|
||||
if (failed(applyFullConversion(kernelModules, target, patterns,
|
||||
&typeConverter))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
|
|
|
@ -286,7 +286,7 @@ FuncOp mlir::spirv::lowerAsEntryFunction(
|
|||
newFuncOp.setType(rewriter.getFunctionType(
|
||||
signatureConverter.getConvertedTypes(), llvm::None));
|
||||
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
|
||||
rewriter.replaceOp(funcOp.getOperation(), llvm::None);
|
||||
rewriter.eraseOp(funcOp);
|
||||
|
||||
// Set the attributes for argument and the function.
|
||||
StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();
|
||||
|
|
|
@ -75,6 +75,21 @@ static LogicalResult extractValueFromConstOp(Operation *op,
|
|||
return success();
|
||||
}
|
||||
|
||||
template <typename Ty>
|
||||
static ArrayAttr
|
||||
getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
|
||||
llvm::function_ref<StringRef(Ty)> stringifyFn) {
|
||||
if (enumValues.empty()) {
|
||||
return nullptr;
|
||||
}
|
||||
SmallVector<StringRef, 1> enumValStrs;
|
||||
enumValStrs.reserve(enumValues.size());
|
||||
for (auto val : enumValues) {
|
||||
enumValStrs.emplace_back(stringifyFn(val));
|
||||
}
|
||||
return builder.getStrArrayAttr(enumValStrs);
|
||||
}
|
||||
|
||||
template <typename EnumClass>
|
||||
static ParseResult
|
||||
parseEnumAttribute(EnumClass &value, OpAsmParser &parser,
|
||||
|
@ -2039,20 +2054,38 @@ void spirv::ModuleOp::build(Builder *builder, OperationState &state) {
|
|||
ensureTerminator(*state.addRegion(), *builder, state.location);
|
||||
}
|
||||
|
||||
// TODO(ravishankarm): This is only here for resolving some dependency outside
|
||||
// of mlir. Remove once it is done.
|
||||
void spirv::ModuleOp::build(Builder *builder, OperationState &state,
|
||||
IntegerAttr addressing_model,
|
||||
IntegerAttr memory_model, ArrayAttr capabilities,
|
||||
ArrayAttr extensions,
|
||||
ArrayAttr extended_instruction_sets) {
|
||||
IntegerAttr memory_model) {
|
||||
state.addAttribute("addressing_model", addressing_model);
|
||||
state.addAttribute("memory_model", memory_model);
|
||||
if (capabilities)
|
||||
state.addAttribute("capabilities", capabilities);
|
||||
if (extensions)
|
||||
state.addAttribute("extensions", extensions);
|
||||
build(builder, state);
|
||||
}
|
||||
|
||||
/// Builds a spv.module from enum-typed addressing and memory models.
///
/// `addressing_model` and `memory_model` are stored as 32-bit integer
/// attributes. `capabilities` and `extensions` are stored as string array
/// attributes and are omitted entirely when the corresponding list is empty;
/// `extended_instruction_sets` is attached only when it is non-null.
void spirv::ModuleOp::build(Builder *builder, OperationState &state,
                            spirv::AddressingModel addressing_model,
                            spirv::MemoryModel memory_model,
                            ArrayRef<spirv::Capability> capabilities,
                            ArrayRef<spirv::Extension> extensions,
                            ArrayAttr extended_instruction_sets) {
  state.addAttribute(
      "addressing_model",
      builder->getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
  state.addAttribute("memory_model", builder->getI32IntegerAttr(
                                         static_cast<int32_t>(memory_model)));
  if (!capabilities.empty())
    state.addAttribute("capabilities",
                       getStrArrayAttrForEnumList<spirv::Capability>(
                           *builder, capabilities, spirv::stringifyCapability));
  if (!extensions.empty())
    state.addAttribute("extensions",
                       getStrArrayAttrForEnumList<spirv::Extension>(
                           *builder, extensions, spirv::stringifyExtension));
  if (extended_instruction_sets)
    state.addAttribute("extended_instruction_sets", extended_instruction_sets);
  // Region creation is delegated to the attribute-free overload, which adds
  // the body region and ensures it is terminated. Adding a region here as
  // well would leave the op with two regions instead of its single body.
  build(builder, state);
}
|
||||
|
||||
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {
|
||||
|
|
|
@ -13,6 +13,7 @@ module attributes {gpu.container_module} {
|
|||
// CHECK: spv.Return
|
||||
return
|
||||
}
|
||||
// CHECK: attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}
|
||||
}
|
||||
|
||||
func @foo() {
|
||||
|
|
Loading…
Reference in New Issue