Outline GPU kernel function into a nested module.

Roll forward of commit 5684a12.

When outlining GPU kernels, put the kernel function inside a nested module. Then use a nested pass pipeline to generate the cubins independently for each kernel. In a final pass, move the cubins back to the parent module.

PiperOrigin-RevId: 270639748
Christian Sigg, 2019-09-23 03:16:23 -07:00, committed by A. Unique TensorFlower
parent c900d4994e
commit b8676da1fc
11 changed files with 215 additions and 164 deletions
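To make the new structure concrete, here is a hedged sketch (function names invented, shapes taken from the updated tests below) of what a host module looks like after outlining: the parent module keeps an external stub for the launch site, and the kernel body moves into a nested module tagged with gpu.kernel_module.

module {
  // External stub left in the parent module; the launch site references it,
  // and GpuGenerateCubinAccessorsPass later annotates it with nvvm.cubingetter.
  func @launch_kernel(f32, memref<?xf32, 1>)
      attributes {gpu.kernel}

  // Nested kernel module, compiled to a cubin independently of the host code.
  module attributes {gpu.kernel_module} {
    func @launch_kernel(%arg0: f32, %arg1: memref<?xf32, 1>)
        attributes {gpu.kernel} {
      // ... outlined body of the gpu.launch region ...
      std.return
    }
  }
}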

View File

@@ -26,10 +26,6 @@ class OwningRewritePatternList;
 class ModuleOp;
 template <typename OpT> class OpPassBase;

-/// Collect a set of patterns to convert from the GPU dialect to NVVM.
-void populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
-                                         OwningRewritePatternList &patterns);
-
 /// Creates a pass that lowers GPU dialect operations to NVVM counterparts.
 std::unique_ptr<OpPassBase<ModuleOp>> createLowerGpuOpsToNVVMOpsPass();

View File

@@ -41,9 +41,12 @@ public:
   /// Get the canonical string name of the dialect.
   static StringRef getDialectName();

-  /// Get the name of the attribute used to annotate outlined kernel functions.
+  /// Get the name of the attribute used to annotate external kernel functions.
   static StringRef getKernelFuncAttrName() { return "gpu.kernel"; }

+  /// Get the name of the attribute used to annotate kernel modules.
+  static StringRef getKernelModuleAttrName() { return "gpu.kernel_module"; }
+
   /// Returns whether the given function is a kernel function, i.e., has the
   /// 'gpu.kernel' attribute.
   static bool isKernel(FuncOp function);
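Passes use this attribute to restrict themselves to (or skip) kernel modules; the early-return check below, repeated verbatim in several of the passes in this commit, is the intended usage pattern.

// Skip modules that are not tagged as kernel modules.
if (!getModule().getAttrOfType<UnitAttr>(
        gpu::GPUDialect::getKernelModuleAttrName()))
  return;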

View File

@@ -49,26 +49,37 @@ namespace {
 // TODO(herhut): Move to shared location.
 static constexpr const char *kCubinAnnotation = "nvvm.cubin";

-/// A pass converting tagged kernel functions to cubin blobs.
+/// A pass converting tagged kernel modules to cubin blobs.
+///
+/// If tagged as a kernel module, each contained function is translated to NVVM
+/// IR and further to PTX. A user-provided CubinGenerator compiles the PTX to
+/// GPU binary code, which is then attached as an attribute to the function.
+/// The function body is erased.
 class GpuKernelToCubinPass : public ModulePass<GpuKernelToCubinPass> {
 public:
   GpuKernelToCubinPass(
       CubinGenerator cubinGenerator = compilePtxToCubinForTesting)
       : cubinGenerator(cubinGenerator) {}

-  // Run the dialect converter on the module.
   void runOnModule() override {
+    if (!getModule().getAttrOfType<UnitAttr>(
+            gpu::GPUDialect::getKernelModuleAttrName()))
+      return;
+
     // Make sure the NVPTX target is initialized.
     LLVMInitializeNVPTXTarget();
     LLVMInitializeNVPTXTargetInfo();
     LLVMInitializeNVPTXTargetMC();
     LLVMInitializeNVPTXAsmPrinter();
+
+    auto llvmModule = translateModuleToNVVMIR(getModule());
+    if (!llvmModule)
+      return signalPassFailure();
+
     for (auto function : getModule().getOps<FuncOp>()) {
-      if (!gpu::GPUDialect::isKernel(function) || function.isExternal()) {
+      if (!gpu::GPUDialect::isKernel(function))
         continue;
-      }
-      if (failed(translateGpuKernelToCubinAnnotation(function)))
+      if (failed(translateGpuKernelToCubinAnnotation(*llvmModule, function)))
         signalPassFailure();
     }
   }
@@ -79,8 +90,13 @@ private:
   std::string translateModuleToPtx(llvm::Module &module,
                                    llvm::TargetMachine &target_machine);

+  /// Converts llvmModule to cubin using the user-provided generator.
   OwnedCubin convertModuleToCubin(llvm::Module &llvmModule, FuncOp &function);
-  LogicalResult translateGpuKernelToCubinAnnotation(FuncOp &function);
+
+  /// Translates llvmModule to cubin and assigns it to an attribute of
+  /// function.
+  LogicalResult translateGpuKernelToCubinAnnotation(llvm::Module &llvmModule,
+                                                    FuncOp &function);

   CubinGenerator cubinGenerator;
 };

@@ -135,22 +151,13 @@ OwnedCubin GpuKernelToCubinPass::convertModuleToCubin(llvm::Module &llvmModule,
   return cubinGenerator(ptx, function);
 }

-LogicalResult
-GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(FuncOp &function) {
-  Builder builder(function.getContext());
-
-  OwningModuleRef module = ModuleOp::create(function.getLoc());
-
-  // TODO(herhut): Also handle called functions.
-  module->push_back(function.clone());
-
-  auto llvmModule = translateModuleToNVVMIR(*module);
-  auto cubin = convertModuleToCubin(*llvmModule, function);
-  if (!cubin) {
+LogicalResult GpuKernelToCubinPass::translateGpuKernelToCubinAnnotation(
+    llvm::Module &llvmModule, FuncOp &function) {
+  auto cubin = convertModuleToCubin(llvmModule, function);
+  if (!cubin)
     return function.emitError("translation to CUDA binary failed.");
-  }

+  Builder builder(function.getContext());
   function.setAttr(kCubinAnnotation,
                    builder.getStringAttr({cubin->data(), cubin->size()}));
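For reference, a minimal sketch of a CubinGenerator callback, modeled on the compilePtxToCubinForTesting default above and assuming OwnedCubin is std::unique_ptr<std::vector<char>> (the actual typedef lives outside this diff): it ignores the PTX and returns a fixed blob, which the pass then attaches as nvvm.cubin.

static OwnedCubin compilePtxToCubinForTesting(const std::string ptx,
                                              FuncOp &function) {
  // Return a constant "CUBIN" blob instead of invoking ptxas, so tests can
  // run without a GPU toolchain; see the lower-nvvm-kernel-to-cubin test below.
  const char data[] = "CUBIN";
  return std::make_unique<std::vector<char>>(data, data + sizeof(data) - 1);
}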

View File

@@ -43,8 +43,15 @@ constexpr const char *kCubinGetterAnnotation = "nvvm.cubingetter";
 constexpr const char *kCubinGetterSuffix = "_cubin";
 constexpr const char *kCubinStorageSuffix = "_cubin_cst";

-/// A pass generating global strings and getter functions for all cubin blobs
-/// annotated on functions via the nvvm.cubin attribute.
+/// A pass which moves cubins from function attributes in nested modules
+/// to global strings and generates getter functions.
+///
+/// The GpuKernelToCubinPass annotates kernel functions with compiled device
+/// code blobs. These functions reside in nested modules generated by
+/// GpuKernelOutliningPass. This pass consumes these modules and moves the
+/// cubin blobs back to the parent module as global strings and generates
+/// accessor functions for them. The external kernel functions (also generated
+/// by the outlining pass) are annotated with the symbol of the cubin accessor.
 class GpuGenerateCubinAccessorsPass
     : public ModulePass<GpuGenerateCubinAccessorsPass> {
 private:

@@ -55,18 +62,25 @@ private:
   }

   // Inserts a global constant string containing `blob` into the parent module
-  // of `orig` and generates the function that returns the address of the first
-  // character of this string.
+  // of `kernelFunc` and generates the function that returns the address of
+  // the first character of this string.
   // TODO(herhut): consider fusing this pass with launch-func-to-cuda.
-  void generate(FuncOp orig, StringAttr blob) {
-    Location loc = orig.getLoc();
-    SmallString<128> nameBuffer(orig.getName());
-    auto module = orig.getParentOfType<ModuleOp>();
+  void generate(FuncOp kernelFunc, StringAttr blob) {
+    auto stubFunc = getModule().lookupSymbol<FuncOp>(kernelFunc.getName());
+    if (!stubFunc) {
+      kernelFunc.emitError(
+          "corresponding external function not found in parent module");
+      return signalPassFailure();
+    }
+
+    Location loc = stubFunc.getLoc();
+    SmallString<128> nameBuffer(stubFunc.getName());
+    auto module = stubFunc.getParentOfType<ModuleOp>();
     assert(module && "function must belong to a module");

     // Insert the getter function just after the original function.
     OpBuilder moduleBuilder(module.getBody(), module.getBody()->begin());
-    moduleBuilder.setInsertionPoint(orig.getOperation()->getNextNode());
+    moduleBuilder.setInsertionPoint(stubFunc.getOperation()->getNextNode());
     auto getterType = moduleBuilder.getFunctionType(
         llvm::None, LLVM::LLVMType::getInt8PtrTy(llvmDialect));
     nameBuffer.append(kCubinGetterSuffix);

@@ -75,7 +89,7 @@ private:
   Block *entryBlock = result.addEntryBlock();

   // Drop the getter suffix before appending the storage suffix.
-  nameBuffer.resize(orig.getName().size());
+  nameBuffer.resize(stubFunc.getName().size());
   nameBuffer.append(kCubinStorageSuffix);

   // Obtain the address of the first character of the global string containing

@@ -86,21 +100,23 @@ private:
     builder.create<LLVM::ReturnOp>(loc, startPtr);

     // Store the name of the getter on the function for easier lookup.
-    orig.setAttr(kCubinGetterAnnotation, builder.getSymbolRefAttr(result));
+    stubFunc.setAttr(kCubinGetterAnnotation, builder.getSymbolRefAttr(result));
   }

 public:
-  // Perform the conversion on the module. This may insert globals, so it
-  // cannot be done on multiple functions in parallel.
   void runOnModule() override {
-    llvmDialect =
-        getModule().getContext()->getRegisteredDialect<LLVM::LLVMDialect>();
+    llvmDialect = getContext().getRegisteredDialect<LLVM::LLVMDialect>();

-    for (auto func : getModule().getOps<FuncOp>()) {
-      StringAttr cubinBlob = func.getAttrOfType<StringAttr>(kCubinAnnotation);
-      if (!cubinBlob)
+    auto modules = getModule().getOps<ModuleOp>();
+    for (auto module : llvm::make_early_inc_range(modules)) {
+      if (!module.getAttrOfType<UnitAttr>(
+              gpu::GPUDialect::getKernelModuleAttrName()))
         continue;
-      generate(func, cubinBlob);
+      for (auto func : module.getOps<FuncOp>()) {
+        if (StringAttr blob = func.getAttrOfType<StringAttr>(kCubinAnnotation))
+          generate(func, blob);
+      }
+      module.erase();
     }
   }
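Based on the CHECK lines in the updated generate-cubin-accessors test further down, the output has roughly the following shape. This is a hedged sketch: the getter body and the LLVM dialect types are approximated from the pass code rather than copied from real output.

// Storage global and getter generated in the parent module; the nested
// kernel module itself has been erased.
llvm.mlir.global constant @kernel_cubin_cst("CUBIN")

// External stub, now annotated with the getter's symbol.
func @kernel(!llvm.float, !llvm<"float*">)
    attributes {gpu.kernel, nvvm.cubingetter = @kernel_cubin}

// Generated accessor returning a pointer to the blob's first character.
func @kernel_cubin() -> !llvm<"i8*"> {
  %addr = llvm.mlir.addressof @kernel_cubin_cst : !llvm<"[5 x i8]*">
  %c0 = llvm.mlir.constant(0 : index) : !llvm.i64
  %ptr = llvm.getelementptr %addr[%c0, %c0]
      : (!llvm<"[5 x i8]*">, !llvm.i64, !llvm.i64) -> !llvm<"i8*">
  llvm.return %ptr : !llvm<"i8*">
}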

View File

@@ -23,6 +23,7 @@
 #include "mlir/Conversion/GPUToNVVM/GPUToNVVMPass.h"
 #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h"
+#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h"
 #include "mlir/Dialect/GPU/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/LLVMDialect.h"
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"

@@ -38,23 +39,6 @@ using namespace mlir;
 namespace {

-// Rewriting that replaces the types of a LaunchFunc operation with their
-// LLVM counterparts.
-struct GPULaunchFuncOpLowering : public LLVMOpLowering {
-public:
-  explicit GPULaunchFuncOpLowering(LLVMTypeConverter &lowering_)
-      : LLVMOpLowering(gpu::LaunchFuncOp::getOperationName(),
-                       lowering_.getDialect()->getContext(), lowering_) {}
-
-  // Convert the kernel arguments to an LLVM type, preserve the rest.
-  PatternMatchResult
-  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
-                  ConversionPatternRewriter &rewriter) const override {
-    rewriter.clone(*op)->setOperands(operands);
-    return rewriter.replaceOp(op, llvm::None), matchSuccess();
-  }
-};
-
 // Rewriting that replaces Op with XOp, YOp, or ZOp depending on the dimension
 // that Op operates on. Op is assumed to return an `std.index` value and
 // XOp, YOp and ZOp are assumed to return an `llvm.i32` value. Depending on

@@ -119,20 +103,31 @@ public:
   }
 };

-// A pass that replaces all occurences of GPU operations with their
+// A pass that replaces all occurrences of GPU device operations with their
 // corresponding NVVM equivalent.
 //
-// This pass does not handle launching of kernels. Instead, it is meant to be
-// used on the body region of a launch or the body region of a kernel
-// function.
+// This pass only handles device code and is not meant to be run on GPU host
+// code.
 class LowerGpuOpsToNVVMOpsPass : public ModulePass<LowerGpuOpsToNVVMOpsPass> {
 public:
   void runOnModule() override {
     ModuleOp m = getModule();
+    if (!m.getAttrOfType<UnitAttr>(gpu::GPUDialect::getKernelModuleAttrName()))
+      return;
+
     OwningRewritePatternList patterns;
     LLVMTypeConverter converter(m.getContext());
-    populateGpuToNVVMConversionPatterns(converter, patterns);
+    populateStdToLLVMConversionPatterns(converter, patterns);
+    patterns.insert<
+        GPUIndexIntrinsicOpLowering<gpu::ThreadId, NVVM::ThreadIdXOp,
+                                    NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
+        GPUIndexIntrinsicOpLowering<gpu::BlockDim, NVVM::BlockDimXOp,
+                                    NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
+        GPUIndexIntrinsicOpLowering<gpu::BlockId, NVVM::BlockIdXOp,
+                                    NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
+        GPUIndexIntrinsicOpLowering<gpu::GridDim, NVVM::GridDimXOp,
+                                    NVVM::GridDimYOp, NVVM::GridDimZOp>>(
+        converter);

     ConversionTarget target(getContext());
     target.addLegalDialect<LLVM::LLVMDialect>();

@@ -146,22 +141,6 @@ public:
 } // anonymous namespace

-/// Collect a set of patterns to convert from the GPU dialect to NVVM.
-void mlir::populateGpuToNVVMConversionPatterns(
-    LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
-  patterns
-      .insert<GPULaunchFuncOpLowering,
-              GPUIndexIntrinsicOpLowering<gpu::ThreadId, NVVM::ThreadIdXOp,
-                                          NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
-              GPUIndexIntrinsicOpLowering<gpu::BlockDim, NVVM::BlockDimXOp,
-                                          NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
-              GPUIndexIntrinsicOpLowering<gpu::BlockId, NVVM::BlockIdXOp,
-                                          NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
-              GPUIndexIntrinsicOpLowering<gpu::GridDim, NVVM::GridDimXOp,
-                                          NVVM::GridDimYOp, NVVM::GridDimZOp>>(
-          converter);
-}
-
 std::unique_ptr<OpPassBase<ModuleOp>> mlir::createLowerGpuOpsToNVVMOpsPass() {
   return std::make_unique<LowerGpuOpsToNVVMOpsPass>();
 }
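Since this pass now early-returns unless the module carries gpu.kernel_module, it has to be scheduled on the nested kernel modules. A minimal sketch of the intended pipeline, mirroring runMLIRPasses in mlir-cuda-runner.cpp at the end of this commit:

PassManager pm(m.getContext());
pm.addPass(createGpuKernelOutliningPass());    // creates nested kernel modules

auto &kernelPm = pm.nest<ModuleOp>();          // runs once per nested module
kernelPm.addPass(createLowerGpuOpsToNVVMOpsPass());
kernelPm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));

pm.addPass(createLowerToLLVMPass());           // host code lowering
pm.addPass(createGenerateCubinAccessorPass()); // moves cubins back to parent
pm.addPass(createConvertGpuLaunchFuncToCudaCallsPass());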

View File

@@ -93,7 +93,7 @@ static gpu::LaunchFuncOp inlineConstants(FuncOp kernelFunc,
 }

 // Outline the `gpu.launch` operation body into a kernel function. Replace
-// `gpu.return` operations by `std.return` in the generated functions.
+// `gpu.return` operations by `std.return` in the generated function.
 static FuncOp outlineKernelFunc(gpu::LaunchOp launchOp) {
   Location loc = launchOp.getLoc();
   SmallVector<Type, 4> kernelOperandTypes(launchOp.getKernelOperandTypes());

@@ -107,7 +107,7 @@ static FuncOp outlineKernelFunc(gpu::LaunchOp launchOp) {
   outlinedFunc.setAttr(gpu::GPUDialect::getKernelFuncAttrName(),
                        builder.getUnitAttr());
   injectGpuIndexOperations(loc, outlinedFunc);
-  outlinedFunc.walk([](mlir::gpu::Return op) {
+  outlinedFunc.walk([](gpu::Return op) {
     OpBuilder replacer(op);
     replacer.create<ReturnOp>(op.getLoc());
     op.erase();

@@ -131,15 +131,44 @@ static void convertToLaunchFuncOp(gpu::LaunchOp &launchOp, FuncOp kernelFunc) {
 namespace {

+/// Pass that moves the kernel of each LaunchOp into its separate nested module.
+///
+/// This pass moves the kernel code of each LaunchOp into a function created
+/// inside a nested module. It also creates an external function of the same
+/// name in the parent module.
+///
+/// The kernel modules are intended to be compiled to a cubin blob independently
+/// in a separate pass. The external functions can then be annotated with the
+/// symbol of the cubin accessor function.
 class GpuKernelOutliningPass : public ModulePass<GpuKernelOutliningPass> {
 public:
   void runOnModule() override {
     ModuleManager moduleManager(getModule());
+    auto context = getModule().getContext();
+    Builder builder(context);
     for (auto func : getModule().getOps<FuncOp>()) {
-      func.walk([&](mlir::gpu::LaunchOp op) {
+      // Insert just after the function.
+      Block::iterator insertPt(func.getOperation()->getNextNode());
+      func.walk([&](gpu::LaunchOp op) {
+        // TODO(b/141098412): Handle called functions and globals.
         FuncOp outlinedFunc = outlineKernelFunc(op);
-        moduleManager.insert(outlinedFunc);
+
+        // Potentially renames outlinedFunc to make symbol unique.
+        moduleManager.insert(insertPt, outlinedFunc);
+
+        // Potentially changes signature, pulling in constants.
         convertToLaunchFuncOp(op, outlinedFunc);
+
+        // Create clone and move body from outlinedFunc.
+        auto kernelFunc = outlinedFunc.cloneWithoutRegions();
+        kernelFunc.getBody().takeBody(outlinedFunc.getBody());
+
+        // Create nested module and insert kernelFunc.
+        auto kernelModule = ModuleOp::create(UnknownLoc::get(context));
+        kernelModule.setAttr(gpu::GPUDialect::getKernelModuleAttrName(),
+                             builder.getUnitAttr());
+        kernelModule.push_back(kernelFunc);
+        getModule().insert(insertPt, kernelModule);
       });
     }
   }

View File

@@ -2,9 +2,14 @@
 // CHECK: llvm.mlir.global constant @[[global:.*]]("CUBIN")

+module attributes {gpu.kernel_module} {
+  func @kernel(!llvm.float, !llvm<"float*">)
+    attributes {nvvm.cubin = "CUBIN"}
+}
+
 func @kernel(!llvm.float, !llvm<"float*">)
-// CHECK: attributes {gpu.kernel, nvvm.cubin = "CUBIN", nvvm.cubingetter = @[[getter:.*]]}
-  attributes {gpu.kernel, nvvm.cubin = "CUBIN"}
+// CHECK: attributes {gpu.kernel, nvvm.cubingetter = @[[getter:.*]]}
+  attributes {gpu.kernel}

 // CHECK: func @[[getter]]() -> !llvm<"i8*">
 // CHECK: %[[addressof:.*]] = llvm.mlir.addressof @[[global]]

View File

@@ -1,8 +1,26 @@
-// RUN: mlir-opt %s --test-kernel-to-cubin | FileCheck %s
+// RUN: mlir-opt %s --test-kernel-to-cubin -split-input-file | FileCheck %s

-func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">)
-// CHECK: attributes {gpu.kernel, nvvm.cubin = "CUBIN"}
-  attributes { gpu.kernel } {
-  // CHECK-NOT: llvm.return
-  llvm.return
-}
+module attributes {gpu.kernel_module} {
+  func @kernel(%arg0 : !llvm.float, %arg1 : !llvm<"float*">)
+  // CHECK: attributes {gpu.kernel, nvvm.cubin = "CUBIN"}
+    attributes { gpu.kernel } {
+    // CHECK-NOT: llvm.return
+    llvm.return
+  }
+}
+
+// -----
+
+module attributes {gpu.kernel_module} {
+  // CHECK: func @kernel_a
+  func @kernel_a()
+    attributes { gpu.kernel } {
+    llvm.return
+  }
+
+  // CHECK: func @kernel_b
+  func @kernel_b()
+    attributes { gpu.kernel } {
+    llvm.return
+  }
+}

View File

@@ -1,35 +1,37 @@
 // RUN: mlir-opt %s -lower-gpu-ops-to-nvvm-ops | FileCheck %s

+module attributes {gpu.kernel_module} {
   // CHECK-LABEL: func @gpu_index_ops()
   func @gpu_index_ops()
       attributes { gpu.kernel } {
     // CHECK: = nvvm.read.ptx.sreg.tid.x : !llvm.i32
     %tIdX = "gpu.thread_id"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.tid.y : !llvm.i32
     %tIdY = "gpu.thread_id"() {dimension = "y"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.tid.z : !llvm.i32
     %tIdZ = "gpu.thread_id"() {dimension = "z"} : () -> (index)

     // CHECK: = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
     %bDimX = "gpu.block_dim"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
     %bDimY = "gpu.block_dim"() {dimension = "y"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
     %bDimZ = "gpu.block_dim"() {dimension = "z"} : () -> (index)

     // CHECK: = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
     %bIdX = "gpu.block_id"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
     %bIdY = "gpu.block_id"() {dimension = "y"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
     %bIdZ = "gpu.block_id"() {dimension = "z"} : () -> (index)

     // CHECK: = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
     %gDimX = "gpu.grid_dim"() {dimension = "x"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
     %gDimY = "gpu.grid_dim"() {dimension = "y"} : () -> (index)
     // CHECK: = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
     %gDimZ = "gpu.grid_dim"() {dimension = "z"} : () -> (index)

     std.return
   }
+}

View File

@@ -1,4 +1,4 @@
-// RUN: mlir-opt -gpu-kernel-outlining -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -gpu-kernel-outlining -split-input-file -verify-diagnostics %s | FileCheck %s

 // CHECK-LABEL: func @launch()
 func @launch() {

@@ -35,7 +35,11 @@ func @launch() {
 }

 // CHECK-LABEL: func @launch_kernel
-// CHECK-SAME: (%[[ARG0:.*]]: f32, %[[ARG1:.*]]: memref<?xf32, 1>)
+// CHECK-SAME: (f32, memref<?xf32, 1>)
+// CHECK-NEXT: attributes {gpu.kernel}
+
+// CHECK-LABEL: func @launch_kernel
+// CHECK-SAME: (%[[KERNEL_ARG0:.*]]: f32, %[[KERNEL_ARG1:.*]]: memref<?xf32, 1>)
 // CHECK-NEXT: attributes {gpu.kernel}
 // CHECK-NEXT: %[[BID:.*]] = "gpu.block_id"() {dimension = "x"} : () -> index
 // CHECK-NEXT: = "gpu.block_id"() {dimension = "y"} : () -> index

@@ -49,9 +53,9 @@ func @launch() {
 // CHECK-NEXT: %[[BDIM:.*]] = "gpu.block_dim"() {dimension = "x"} : () -> index
 // CHECK-NEXT: = "gpu.block_dim"() {dimension = "y"} : () -> index
 // CHECK-NEXT: = "gpu.block_dim"() {dimension = "z"} : () -> index
-// CHECK-NEXT: "use"(%[[ARG0]]) : (f32) -> ()
+// CHECK-NEXT: "use"(%[[KERNEL_ARG0]]) : (f32) -> ()
 // CHECK-NEXT: "some_op"(%[[BID]], %[[BDIM]]) : (index, index) -> ()
-// CHECK-NEXT: = load %[[ARG1]][%[[TID]]] : memref<?xf32, 1>
+// CHECK-NEXT: = load %[[KERNEL_ARG1]][%[[TID]]] : memref<?xf32, 1>

 // -----

@@ -75,8 +79,8 @@ func @multiple_launches() {
   return
 }

-// CHECK-LABEL: func @multiple_launches_kernel()
-// CHECK-LABEL: func @multiple_launches_kernel_0()
+// CHECK: func @multiple_launches_kernel()
+// CHECK: func @multiple_launches_kernel_0()

 // -----

@@ -100,3 +104,23 @@ func @extra_constants(%arg0 : memref<?xf32>) {
 // CHECK-LABEL: func @extra_constants_kernel(%{{.*}}: memref<?xf32>)
 // CHECK: constant
 // CHECK: constant
+
+// -----
+
+func @function_call(%arg0 : memref<?xf32>) {
+  %cst = constant 8 : index
+  gpu.launch blocks(%bx, %by, %bz) in (%grid_x = %cst, %grid_y = %cst,
+                                       %grid_z = %cst)
+             threads(%tx, %ty, %tz) in (%block_x = %cst, %block_y = %cst,
+                                        %block_z = %cst) {
+    // TODO(b/141098412): Support function calls.
+    // expected-error @+1 {{'device_function' does not reference a valid function}}
+    call @device_function() : () -> ()
+    gpu.return
+  }
+  return
+}
+
+func @device_function() {
+  gpu.return
+}

View File

@@ -108,50 +108,22 @@ OwnedCubin compilePtxToCubin(const std::string ptx, FuncOp &function) {
   return result;
 }

-namespace {
-
-// A pass that lowers all Standard and Gpu operations to LLVM dialect. It does
-// not lower the GPULaunch operation to actual code but dows translate the
-// signature of its kernel argument.
-class LowerStandardAndGpuToLLVMAndNVVM
-    : public ModulePass<LowerStandardAndGpuToLLVMAndNVVM> {
-public:
-  void runOnModule() override {
-    ModuleOp m = getModule();
-    OwningRewritePatternList patterns;
-    LLVMTypeConverter converter(m.getContext());
-    populateStdToLLVMConversionPatterns(converter, patterns);
-    populateGpuToNVVMConversionPatterns(converter, patterns);
-    ConversionTarget target(getContext());
-    target.addLegalDialect<LLVM::LLVMDialect>();
-    target.addLegalDialect<NVVM::NVVMDialect>();
-    target.addLegalOp<ModuleOp>();
-    target.addLegalOp<ModuleTerminatorOp>();
-    target.addDynamicallyLegalOp<FuncOp>(
-        [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); });
-    if (failed(applyFullConversion(m, target, patterns, &converter)))
-      signalPassFailure();
-  }
-};
-
-} // end anonymous namespace
-
 static LogicalResult runMLIRPasses(ModuleOp m) {
   PassManager pm(m.getContext());
+  applyPassManagerCLOptions(pm);

   pm.addPass(createGpuKernelOutliningPass());
-  pm.addPass(static_cast<std::unique_ptr<OpPassBase<ModuleOp>>>(
-      std::make_unique<LowerStandardAndGpuToLLVMAndNVVM>()));
-  pm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));
+
+  auto &kernelPm = pm.nest<ModuleOp>();
+  kernelPm.addPass(createLowerGpuOpsToNVVMOpsPass());
+  kernelPm.addPass(createConvertGPUKernelToCubinPass(&compilePtxToCubin));
+
+  pm.addPass(createLowerToLLVMPass());
   pm.addPass(createGenerateCubinAccessorPass());
   pm.addPass(createConvertGpuLaunchFuncToCudaCallsPass());

-  if (failed(pm.run(m)))
-    return failure();
-  return success();
+  return pm.run(m);
 }

 int main(int argc, char **argv) {
+  registerPassManagerCLOptions();
   return mlir::JitRunnerMain(argc, argv, &runMLIRPasses);
 }