Add lowering for module with gpu.kernel_module attribute.

The existing GPU to SPIR-V lowering created a spv.module for every
function with the gpu.kernel attribute. A better approach is to lower the
module that the function lives in (the one carrying the
gpu.kernel_module attribute) to a spv.module operation. This better
captures the host-device separation modeled by the GPU dialect and
simplifies the lowering as well.

PiperOrigin-RevId: 284574688
Author: Mahesh Ravishankar
Date: 2019-12-09 09:51:25 -08:00
Committed by: A. Unique TensorFlower
Parent: 312ccb1c0f
Commit: 4a62019eb8
6 changed files with 148 additions and 43 deletions
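As a rough illustration of the new behavior (this sketch is not part of the commit; the module and function names are made up, and the exact printed form of spv.module may differ):

// Before conversion: device code sits in a nested module tagged with
// gpu.kernel_module (typically produced by GPU kernel outlining).
module attributes {gpu.container_module} {
  module @kernels attributes {gpu.kernel_module} {
    func @kernel(%arg0: memref<12xf32>) attributes {gpu.kernel} {
      return
    }
  }
  // ... host code that launches the kernel via gpu.launch_func ...
}

// After conversion: the kernel module as a whole becomes a single
// spv.module, instead of one spv.module per kernel function.
spv.module "Logical" "GLSL450" {
  // ... converted kernel function with SPIR-V ABI attributes ...
} attributes {capabilities = ["Shader"],
              extensions = ["SPV_KHR_storage_buffer_storage_class"]}

The capabilities and extensions shown here are the ones KernelModuleConversion currently hardwires; the TODO in that pattern notes they should eventually be configurable.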


@@ -329,13 +329,17 @@ def SPV_ModuleOp : SPV_Op<"module",
let regions = (region SizedRegion<1>:$body);
let builders = [OpBuilder<"Builder *, OperationState &state">,
OpBuilder<[{Builder *, OperationState &state,
IntegerAttr addressing_model,
IntegerAttr memory_model,
/*optional*/ArrayAttr capabilities = nullptr,
/*optional*/ArrayAttr extensions = nullptr,
/*optional*/ArrayAttr extended_instruction_sets = nullptr}]>];
let builders =
[OpBuilder<"Builder *, OperationState &state">,
OpBuilder<[{Builder *, OperationState &state,
IntegerAttr addressing_model,
IntegerAttr memory_model}]>,
OpBuilder<[{Builder *, OperationState &state,
spirv::AddressingModel addressing_model,
spirv::MemoryModel memory_model,
/*optional*/ ArrayRef<spirv::Capability> capabilities = {},
/*optional*/ ArrayRef<spirv::Extension> extensions = {},
/*optional*/ ArrayAttr extended_instruction_sets = nullptr}]>];
// We need to ensure the block inside the region is properly terminated;
// the auto-generated builders do not guarantee that.


@@ -23,6 +23,7 @@
#include "mlir/Dialect/SPIRV/SPIRVDialect.h"
#include "mlir/Dialect/SPIRV/SPIRVLowering.h"
#include "mlir/Dialect/SPIRV/SPIRVOps.h"
#include "mlir/IR/Module.h"
using namespace mlir;
@@ -71,8 +72,36 @@ private:
SmallVector<int32_t, 3> workGroupSizeAsInt32;
};
/// Pattern to convert a module with gpu.kernel_module attribute to a
/// spv.module.
class KernelModuleConversion final : public SPIRVOpLowering<ModuleOp> {
public:
using SPIRVOpLowering<ModuleOp>::SPIRVOpLowering;
PatternMatchResult
matchAndRewrite(ModuleOp moduleOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override;
};
/// Pattern to convert a module terminator op to a terminator of spv.module op.
// TODO: Move this into DRR, but that requires ModuleTerminatorOp to be defined
// in ODS.
class KernelModuleTerminatorConversion final
: public SPIRVOpLowering<ModuleTerminatorOp> {
public:
using SPIRVOpLowering<ModuleTerminatorOp>::SPIRVOpLowering;
PatternMatchResult
matchAndRewrite(ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const override;
};
} // namespace
//===----------------------------------------------------------------------===//
// loop::ForOp.
//===----------------------------------------------------------------------===//
PatternMatchResult
ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
@@ -142,6 +171,10 @@ ForOpConversion::matchAndRewrite(loop::ForOp forOp, ArrayRef<Value *> operands,
return matchSuccess();
}
//===----------------------------------------------------------------------===//
// Builtins.
//===----------------------------------------------------------------------===//
template <typename SourceOp, spirv::BuiltIn builtin>
PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
SourceOp op, ArrayRef<Value *> operands,
@@ -170,6 +203,10 @@ PatternMatchResult LaunchConfigConversion<SourceOp, builtin>::matchAndRewrite(
return this->matchSuccess();
}
//===----------------------------------------------------------------------===//
// FuncOp with gpu.kernel attribute.
//===----------------------------------------------------------------------===//
PatternMatchResult
KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
@@ -196,6 +233,51 @@ KernelFnConversion::matchAndRewrite(FuncOp funcOp, ArrayRef<Value *> operands,
return matchSuccess();
}
//===----------------------------------------------------------------------===//
// ModuleOp with gpu.kernel_module.
//===----------------------------------------------------------------------===//
PatternMatchResult KernelModuleConversion::matchAndRewrite(
ModuleOp moduleOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
if (!moduleOp.getAttrOfType<UnitAttr>(
gpu::GPUDialect::getKernelModuleAttrName())) {
return matchFailure();
}
// TODO : Generalize this to account for different extensions,
// capabilities, extended_instruction_sets, other addressing models
// and memory models.
auto spvModule = rewriter.create<spirv::ModuleOp>(
moduleOp.getLoc(), spirv::AddressingModel::Logical,
spirv::MemoryModel::GLSL450, spirv::Capability::Shader,
spirv::Extension::SPV_KHR_storage_buffer_storage_class);
// Move the region from the module op into the SPIR-V module.
Region &spvModuleRegion = spvModule.getOperation()->getRegion(0);
rewriter.inlineRegionBefore(moduleOp.getBodyRegion(), spvModuleRegion,
spvModuleRegion.begin());
// The spv.module build method adds a block with a terminator. Remove that
// block. The terminator of the module op in the remaining block will be
// legalized later.
spvModuleRegion.back().erase();
rewriter.eraseOp(moduleOp);
return matchSuccess();
}
//===----------------------------------------------------------------------===//
// ModuleTerminatorOp for gpu.kernel_module.
//===----------------------------------------------------------------------===//
PatternMatchResult KernelModuleTerminatorConversion::matchAndRewrite(
ModuleTerminatorOp terminatorOp, ArrayRef<Value *> operands,
ConversionPatternRewriter &rewriter) const {
rewriter.replaceOpWithNewOp<spirv::ModuleEndOp>(terminatorOp);
return matchSuccess();
}
//===----------------------------------------------------------------------===//
// GPU To SPIRV Patterns.
//===----------------------------------------------------------------------===//
namespace mlir {
void populateGPUToSPIRVPatterns(MLIRContext *context,
SPIRVTypeConverter &typeConverter,
@@ -203,7 +285,7 @@ void populateGPUToSPIRVPatterns(MLIRContext *context,
ArrayRef<int64_t> workGroupSize) {
patterns.insert<KernelFnConversion>(context, typeConverter, workGroupSize);
patterns.insert<
ForOpConversion,
ForOpConversion, KernelModuleConversion, KernelModuleTerminatorConversion,
LaunchConfigConversion<gpu::BlockDimOp, spirv::BuiltIn::WorkgroupSize>,
LaunchConfigConversion<gpu::BlockIdOp, spirv::BuiltIn::WorkgroupId>,
LaunchConfigConversion<gpu::GridDimOp, spirv::BuiltIn::NumWorkgroups>,


@@ -67,34 +67,19 @@ void GPUToSPIRVPass::runOnModule() {
auto context = &getContext();
auto module = getModule();
SmallVector<Operation *, 4> spirvModules;
module.walk([&module, &spirvModules](FuncOp funcOp) {
if (!gpu::GPUDialect::isKernel(funcOp)) {
return;
SmallVector<Operation *, 1> kernelModules;
OpBuilder builder(context);
module.walk([&builder, &kernelModules](ModuleOp moduleOp) {
if (moduleOp.getAttrOfType<UnitAttr>(
gpu::GPUDialect::getKernelModuleAttrName())) {
// For each kernel module (should be only 1 for now, but that is not a
// requirement here), clone the module for conversion because the
// gpu.launch function still needs the kernel module.
builder.setInsertionPoint(moduleOp.getOperation());
kernelModules.push_back(builder.clone(*moduleOp.getOperation()));
}
OpBuilder builder(funcOp.getOperation());
// Create a new spirv::ModuleOp for this function, and clone the
// function into it.
// TODO : Generalize this to account for different extensions,
// capabilities, extended_instruction_sets, other addressing models
// and memory models.
auto spvModule = builder.create<spirv::ModuleOp>(
funcOp.getLoc(),
builder.getI32IntegerAttr(
static_cast<int32_t>(spirv::AddressingModel::Logical)),
builder.getI32IntegerAttr(
static_cast<int32_t>(spirv::MemoryModel::GLSL450)),
builder.getStrArrayAttr(
spirv::stringifyCapability(spirv::Capability::Shader)),
builder.getStrArrayAttr(spirv::stringifyExtension(
spirv::Extension::SPV_KHR_storage_buffer_storage_class)));
// Hardwire the capability to be Shader.
OpBuilder moduleBuilder(spvModule.getOperation()->getRegion(0));
moduleBuilder.clone(*funcOp.getOperation());
spirvModules.push_back(spvModule);
});
/// Dialect conversion to lower the functions with the spirv::ModuleOps.
SPIRVTypeConverter typeConverter;
OwningRewritePatternList patterns;
populateGPUToSPIRVPatterns(context, typeConverter, patterns, workGroupSize);
@@ -105,7 +90,7 @@ void GPUToSPIRVPass::runOnModule() {
target.addDynamicallyLegalOp<FuncOp>(
[&](FuncOp op) { return typeConverter.isSignatureLegal(op.getType()); });
if (failed(applyFullConversion(spirvModules, target, patterns,
if (failed(applyFullConversion(kernelModules, target, patterns,
&typeConverter))) {
return signalPassFailure();
}


@@ -286,7 +286,7 @@ FuncOp mlir::spirv::lowerAsEntryFunction(
newFuncOp.setType(rewriter.getFunctionType(
signatureConverter.getConvertedTypes(), llvm::None));
rewriter.applySignatureConversion(&newFuncOp.getBody(), signatureConverter);
rewriter.replaceOp(funcOp.getOperation(), llvm::None);
rewriter.eraseOp(funcOp);
// Set the attributes for argument and the function.
StringRef argABIAttrName = spirv::getInterfaceVarABIAttrName();


@@ -75,6 +75,21 @@ static LogicalResult extractValueFromConstOp(Operation *op,
return success();
}
template <typename Ty>
static ArrayAttr
getStrArrayAttrForEnumList(Builder &builder, ArrayRef<Ty> enumValues,
llvm::function_ref<StringRef(Ty)> stringifyFn) {
if (enumValues.empty()) {
return nullptr;
}
SmallVector<StringRef, 1> enumValStrs;
enumValStrs.reserve(enumValues.size());
for (auto val : enumValues) {
enumValStrs.emplace_back(stringifyFn(val));
}
return builder.getStrArrayAttr(enumValStrs);
}
template <typename EnumClass>
static ParseResult
parseEnumAttribute(EnumClass &value, OpAsmParser &parser,
@@ -2039,20 +2054,38 @@ void spirv::ModuleOp::build(Builder *builder, OperationState &state) {
ensureTerminator(*state.addRegion(), *builder, state.location);
}
// TODO(ravishankarm): This is only here for resolving some dependency outside
// of mlir. Remove once it is done.
void spirv::ModuleOp::build(Builder *builder, OperationState &state,
IntegerAttr addressing_model,
IntegerAttr memory_model, ArrayAttr capabilities,
ArrayAttr extensions,
ArrayAttr extended_instruction_sets) {
IntegerAttr memory_model) {
state.addAttribute("addressing_model", addressing_model);
state.addAttribute("memory_model", memory_model);
if (capabilities)
state.addAttribute("capabilities", capabilities);
if (extensions)
state.addAttribute("extensions", extensions);
build(builder, state);
}
void spirv::ModuleOp::build(Builder *builder, OperationState &state,
spirv::AddressingModel addressing_model,
spirv::MemoryModel memory_model,
ArrayRef<spirv::Capability> capabilities,
ArrayRef<spirv::Extension> extensions,
ArrayAttr extended_instruction_sets) {
state.addAttribute(
"addressing_model",
builder->getI32IntegerAttr(static_cast<int32_t>(addressing_model)));
state.addAttribute("memory_model", builder->getI32IntegerAttr(
static_cast<int32_t>(memory_model)));
if (!capabilities.empty())
state.addAttribute("capabilities",
getStrArrayAttrForEnumList<spirv::Capability>(
*builder, capabilities, spirv::stringifyCapability));
if (!extensions.empty())
state.addAttribute("extensions",
getStrArrayAttrForEnumList<spirv::Extension>(
*builder, extensions, spirv::stringifyExtension));
if (extended_instruction_sets)
state.addAttribute("extended_instruction_sets", extended_instruction_sets);
ensureTerminator(*state.addRegion(), *builder, state.location);
build(builder, state);
}
static ParseResult parseModuleOp(OpAsmParser &parser, OperationState &state) {


@@ -13,6 +13,7 @@ module attributes {gpu.container_module} {
// CHECK: spv.Return
return
}
// CHECK: attributes {capabilities = ["Shader"], extensions = ["SPV_KHR_storage_buffer_storage_class"]}
}
func @foo() {