Add a new NVVM dialect that extends the LLVM dialect with some NVVM specific operations.

Currently, this is limited to operations that give access to the special registers of
    NVIDIA gpus that represent block and thread indices.

--

PiperOrigin-RevId: 245378632
This commit is contained in:
Stephan Herhut 2019-04-26 00:57:10 -07:00 committed by Mehdi Amini
parent 880df8f6ad
commit 65ccb8cfd5
8 changed files with 279 additions and 22 deletions

View File

@ -2,6 +2,13 @@ set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMOps.h.inc -gen-op-decls)
mlir_tablegen(LLVMOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRLLVMOpsIncGen)
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
mlir_tablegen(NVVMOps.h.inc -gen-op-decls)
mlir_tablegen(NVVMOps.cpp.inc -gen-op-defs)
add_public_tablegen_target(MLIRNVVMOpsIncGen)
set(LLVM_TARGET_DEFINITIONS LLVMOps.td)
mlir_tablegen(LLVMConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRLLVMConversionsIncGen)
set(LLVM_TARGET_DEFINITIONS NVVMOps.td)
mlir_tablegen(NVVMConversions.inc -gen-llvmir-conversions)
add_public_tablegen_target(MLIRNVVMConversionsIncGen)

View File

@ -0,0 +1,54 @@
//===-- LLVMOpBase.td - LLVM IR dialect shared definitions -*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file contains shared definitions for the LLVM IR dialect and its
// subdialects.
//
//===----------------------------------------------------------------------===//
#ifdef LLVMIR_OP_BASE
#else
#define LLVMIR_OP_BASE
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
// LLVM IR type wrapped in MLIR.
def LLVM_Type : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
"LLVM dialect type">;
// Base class for LLVM operations. Defines the interface to the llvm::IRBuilder
// used to translate to LLVM IR proper.
class LLVM_OpBase<string mnemonic, list<OpTrait> traits = []> :
Op<mnemonic, traits> {
// A pattern for constructing the LLVM IR Instruction (or other Value) that
// corresponds to this op. This pattern can use `builder` to refer to an
// `llvm::IRBuilder<>` instance, $-names of arguments and results and the
// following special variable names:
// - $_resultType - substituted with the LLVM IR type of the result;
// - $_numOperands - substituted with the number of operands (including
// the variadic ones);
// - $_hasResult - substituted with a check that a variadic-result op does
// have a result (LLVM ops can have 0 or 1 result);
// - $_location - mlir::Location object of the instruction.
// Additionally, `$$` can be used to produce the dollar character.
string llvmBuilder = "";
}
#endif // LLVMIR_OP_BASE

View File

@ -23,33 +23,14 @@
#else
#define LLVMIR_OPS
#ifdef OP_BASE
#else
include "mlir/IR/OpBase.td"
#endif // OP_BASE
// LLVM IR type wrapped in MLIR.
def LLVM_Type : Type<CPred<"$_self.isa<::mlir::LLVM::LLVMType>()">,
"LLVM dialect type">;
include "mlir/LLVMIR/LLVMOpBase.td"
// Base class for LLVM operations. All operations get an "llvm." prefix in
// their name automatically. LLVM operations have either zero or one result,
// this class is specialized below for both cases and should not be used
// directly.
class LLVM_Op<string mnemonic, list<OpTrait> traits = []> :
Op<!strconcat("llvm.", mnemonic), traits> {
// A pattern for constructing the LLVM IR Instruction (or other Value) that
// corresponds to this op. This pattern can use `builder` to refer to an
// `llvm::IRBuilder<>` instance, $-names of arguments and results and the
// following special variable names:
// - $_resultType - substituted with the LLVM IR type of the result;
// - $_numOperands - substituted with the number of operands (including
// the variadic ones);
// - $_hasResult - substituted with a check that a variadic-result op does
// have a result (LLVM ops can have 0 or 1 result);
// - $_location - mlir::Location object of the instruction.
// Additionally, `$$` can be used to produce the dollar character.
string llvmBuilder = "";
LLVM_OpBase<!strconcat("llvm.", mnemonic), traits> {
}
class LLVM_Builder<string builder> {

View File

@ -0,0 +1,43 @@
//===- NVVMDialect.h - MLIR NVVM IR dialect ---------------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the NVVM IR dialect in MLIR, containing NVVM operations and
// NVVM specific extensions to the LLVM type system.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_LLVMIR_NVVMDIALECT_H_
#define MLIR_LLVMIR_NVVMDIALECT_H_
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
namespace mlir {
namespace NVVM {
///// Ops /////
#define GET_OP_CLASSES
#include "mlir/LLVMIR/NVVMOps.h.inc"
class NVVMDialect : public Dialect {
public:
explicit NVVMDialect(MLIRContext *context);
};
} // namespace NVVM
} // namespace mlir
#endif /* MLIR_LLVMIR_NVVMDIALECT_H_ */

View File

@ -0,0 +1,55 @@
//===-- NVVMOps.td - NVVM IR dialect op definition file ----*- tablegen -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This is the NVVM IR operation definition file.
//
//===----------------------------------------------------------------------===//
#ifdef NVVMIR_OPS
#else
#define NVVMIR_OPS
include "mlir/LLVMIR/LLVMOpBase.td"
class NVVM_Op<string mnemonic, list<OpTrait> traits = []> :
LLVM_OpBase<!strconcat("nvvm.", mnemonic), traits> {
}
class NVVM_SpecialRegisterOp<string mnemonic,
list<OpTrait> traits = []> :
NVVM_Op<mnemonic, !listconcat(traits, [NoSideEffect])>,
Results<(outs LLVM_Type:$res)>, Arguments<(ins)> {
string llvmBuilder = "createIntrinsicCall(builder, llvm::Intrinsic::nvvm_"
# !subst(".","_", mnemonic) # ");";
let parser = [{ return parseNVVMSpecialRegisterOp(parser, result); }];
let printer = [{ printNVVMSpecialRegisterOp(p, this->getOperation()); }];
}
def NVVM_ThreadIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.x">;
def NVVM_ThreadIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.y">;
def NVVM_ThreadIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.tid.z">;
def NVVM_ThreadDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.x">;
def NVVM_ThreadDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.y">;
def NVVM_ThreadDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ntid.z">;
def NVVM_BlockIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.x">;
def NVVM_BlockIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.y">;
def NVVM_BlockIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.ctaid.z">;
def NVVM_BlockDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.x">;
def NVVM_BlockDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.y">;
def NVVM_BlockDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nctaid.z">;
#endif // NVVMIR_OPS

View File

@ -0,0 +1,88 @@
//===- NVVMDialect.cpp - NVVM IR Ops and Dialect registration -------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file defines the types and operation details for the NVVM IR dialect in
// MLIR, and the LLVM IR dialect. It also registers the dialect.
//
// The NVVM dialect only contains GPU specific additions on top of the general
// LLVM dialect.
//
//===----------------------------------------------------------------------===//
#include "mlir/LLVMIR/NVVMDialect.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Operation.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/LLVMIR/LLVMDialect.h"
#include "llvm/AsmParser/Parser.h"
#include "llvm/IR/Attributes.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/Type.h"
#include "llvm/Support/SourceMgr.h"
namespace mlir {
namespace NVVM {
//===----------------------------------------------------------------------===//
// Printing/parsing for NVVM ops
//===----------------------------------------------------------------------===//
static void printNVVMSpecialRegisterOp(OpAsmPrinter *p, Operation *op) {
*p << op->getName() << " : ";
if (op->getNumResults() == 1) {
*p << op->getResult(0)->getType();
} else {
*p << "###invalid type###";
}
}
// <operation> ::= `llvm.nvvm.XYZ` : type
static bool parseNVVMSpecialRegisterOp(OpAsmParser *parser,
OperationState *result) {
Type type;
if (parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type))
return true;
result->addTypes(type);
return false;
}
//===----------------------------------------------------------------------===//
// NVVMDialect initialization, type parsing, and registration.
//===----------------------------------------------------------------------===//
// TODO(herhut): This should be the llvm.nvvm dialect once this is supported.
NVVMDialect::NVVMDialect(MLIRContext *context) : Dialect("nvvm", context) {
addOperations<
#define GET_OP_LIST
#include "mlir/LLVMIR/NVVMOps.cpp.inc"
>();
// Support unknown operations because not all NVVM operations are registered.
allowUnknownOperations();
}
#define GET_OP_CLASSES
#include "mlir/LLVMIR/NVVMOps.cpp.inc"
static DialectRegistration<NVVMDialect> nvvmDialect;
} // namespace NVVM
} // namespace mlir

View File

@ -0,0 +1,29 @@
// RUN: mlir-opt %s | FileCheck %s
func @nvvm_special_regs() -> !llvm.i32 {
// CHECK: %0 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
%0 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
// CHECK: %1 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
%1 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
// CHECK: %2 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
%2 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
// CHECK: %3 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
%3 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
// CHECK: %4 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
%4 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
// CHECK: %5 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
%5 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
// CHECK: %6 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
%6 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
// CHECK: %7 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
%7 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
// CHECK: %8 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
%8 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
// CHECK: %9 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
%9 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
// CHECK: %10 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
%10 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
// CHECK: %11 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
%11 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
llvm.return %0 : !llvm.i32
}

View File

@ -173,7 +173,7 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) {
// Emit all builders. Returns false on success because of the generator
// registration requirements.
static bool emitBuilders(const RecordKeeper &recordKeeper, raw_ostream &os) {
for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_Op")) {
for (const auto *def : recordKeeper.getAllDerivedDefinitions("LLVM_OpBase")) {
if (!emitOneBuilder(*def, os))
return true;
}