[mlir][nvvm] Add async copy ops to nvvm dialect

Differential Revision: https://reviews.llvm.org/D115314
This commit is contained in:
Thomas Raoux 2021-12-07 19:28:14 -08:00
parent 824ddeb994
commit 579c1ff67d
5 changed files with 86 additions and 2 deletions

View File

@ -94,12 +94,12 @@ class LLVM_PointerTo<Type pointee> : Type<
"LLVM pointer to " # pointee.summary>;
// Type constraints accepting LLVM pointer type to integer of a specific width.
//
// Parameters:
//   width        - bit width of the pointee integer type (`I<width>`).
//   addressSpace - address space passed to `LLVMPointerType::get` when the
//                  type is built from this constraint; defaults to 0.
//
// NOTE(review): the predicate and the summary string are derived from `width`
// only. As written here, `addressSpace` appears to influence just the
// BuildableType expression and not what the constraint accepts -- confirm
// against the definition of LLVM_PointerTo.
// NOTE(review): this hunk seems to carry both the pre-change and post-change
// spelling of the class header and of the `# width #` line -- confirm which
// lines belong to the current revision of the file.
class LLVM_IntPtrBase<int width> : Type<
class LLVM_IntPtrBase<int width, int addressSpace = 0> : Type<
LLVM_PointerTo<I<width>>.predicate,
"LLVM pointer to " # I<width>.summary>,
BuildableType<"::mlir::LLVM::LLVMPointerType::get("
"::mlir::IntegerType::get($_builder.getContext(), "
# width #"))">;
# width #"), "# addressSpace #")">;
// Pointer to an 8-bit integer, instantiated with the default `addressSpace`.
def LLVM_i8Ptr : LLVM_IntPtrBase<8>;

View File

@ -16,6 +16,9 @@
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
// Pointer-to-i8 constraints instantiated with address space 1 ("global") and
// address space 3 ("shared"). In this file they are used as the source and
// destination operand types of NVVM_CpAsyncOp.
def LLVM_i8Ptr_global : LLVM_IntPtrBase<8, 1>;
def LLVM_i8Ptr_shared : LLVM_IntPtrBase<8, 3>;
//===----------------------------------------------------------------------===//
// NVVM dialect definitions
//===----------------------------------------------------------------------===//
@ -157,6 +160,56 @@ def NVVM_VoteBallotOp :
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
}
// `nvvm.cp.async.shared.global`: async copy op.
//
// Arguments:
//   dst  - operand constrained by LLVM_i8Ptr_shared.
//   src  - operand constrained by LLVM_i8Ptr_global.
//   size - 32-bit integer attribute holding the byte size of the copy; the
//          verifier only admits 4, 8 and 16.
// The op declares no results.
def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
Arguments<(ins LLVM_i8Ptr_shared:$dst,
LLVM_i8Ptr_global:$src,
I32Attr:$size)> {
// Translation to LLVM IR: picks the
// nvvm_cp_async_ca_shared_global_{4,8,16} intrinsic that matches `size` and
// emits a call to it with (dst, src). Any other size reaches
// llvm_unreachable, so this relies on the verifier having accepted the op.
string llvmBuilder = [{
llvm::Intrinsic::ID id;
switch ($size) {
case 4:
id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4;
break;
case 8:
id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8;
break;
case 16:
id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
break;
default:
llvm_unreachable("unsupported async copy size");
}
createIntrinsicCall(builder, id, {$dst, $src});
}];
// Rejects every `size` other than 4, 8 or 16 with an error diagnostic.
let verifier = [{
if (size() != 4 && size() != 8 && size() != 16)
return emitError("expected byte size to be either 4, 8 or 16.");
return success();
}];
// Textual form: `dst, src, size` followed by the attribute dictionary.
let assemblyFormat = "$dst `,` $src `,` $size attr-dict";
}
// `nvvm.cp.async.commit.group`: declares no arguments and no results here.
// Translated to a call to the nvvm_cp_async_commit_group intrinsic with no
// call arguments.
def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_commit_group);
}];
// Textual form: the op name followed only by the attribute dictionary.
let assemblyFormat = "attr-dict";
}
// `nvvm.cp.async.wait.group`: takes a single 32-bit integer attribute `n` and
// declares no results.
// NOTE(review): `n` is presumably the number of most recent copy groups that
// may still be pending when the wait returns -- confirm against the PTX
// `cp.async.wait_group` documentation.
def NVVM_CpAsyncWaitGroupOp : NVVM_Op<"cp.async.wait.group">,
Arguments<(ins I32Attr:$n)> {
// Translation to LLVM IR: materializes `n` as an i32 constant in the target
// LLVM context and passes it as the only argument of the
// nvvm_cp_async_wait_group intrinsic call.
string llvmBuilder = [{
createIntrinsicCall(
builder,
llvm::Intrinsic::nvvm_cp_async_wait_group,
llvm::ConstantInt::get(
llvm::Type::getInt32Ty(moduleTranslation.getLLVMContext()),
$n));
}];
// Textual form: `n` followed by the attribute dictionary.
let assemblyFormat = "$n attr-dict";
}
def NVVM_MmaOp :
NVVM_Op<"mma.sync">,
Results<(outs LLVM_Type:$res)>,

View File

@ -1226,3 +1226,11 @@ func @bitcast(%arg0: vector<2x3xf32>) {
llvm.bitcast %arg0 : vector<2x3xf32> to vector<2x3xi32>
return
}
// -----
// Negative test: a copy size of 32 is outside the accepted set {4, 8, 16}, so
// the op below must be rejected with the size diagnostic.
func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
// expected-error @below {{expected byte size to be either 4, 8 or 16.}}
nvvm.cp.async.shared.global %arg0, %arg1, 32
return
}

View File

@ -95,6 +95,15 @@ func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 : i32,
llvm.return %r : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
}
// Exercises the custom assembly of the three async copy ops: a 16-byte copy
// from an address-space-1 pointer to an address-space-3 pointer, a commit
// group, and a wait group with n = 0. The FileCheck patterns match the ops in
// the same textual form in which they are written here.
llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16
nvvm.cp.async.shared.global %arg0, %arg1, 16
// CHECK: nvvm.cp.async.commit.group
nvvm.cp.async.commit.group
// CHECK: nvvm.cp.async.wait.group 0
nvvm.cp.async.wait.group 0
llvm.return
}
// -----

View File

@ -162,6 +162,20 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
llvm.return
}
// Translation test for the async copy ops: each accepted copy size (4, 8, 16)
// must become a call to the matching llvm.nvvm.cp.async.ca.shared.global.<size>
// intrinsic taking the addrspace(3) and addrspace(1) pointers, the commit op
// must become an argument-less intrinsic call, and the wait op must pass its
// attribute as an i32 constant.
llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
nvvm.cp.async.shared.global %arg0, %arg1, 4
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
nvvm.cp.async.shared.global %arg0, %arg1, 8
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
nvvm.cp.async.shared.global %arg0, %arg1, 16
// CHECK: call void @llvm.nvvm.cp.async.commit.group()
nvvm.cp.async.commit.group
// CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
nvvm.cp.async.wait.group 0
llvm.return
}
// This function has the "kernel" attribute attached and should appear in the
// NVVM annotations after conversion.
llvm.func @kernel_func() attributes {nvvm.kernel} {