forked from OSchip/llvm-project
[mlir][nvvm] Add async copy ops to nvvm dialect
Differential Revision: https://reviews.llvm.org/D115314
This commit is contained in:
parent
824ddeb994
commit
579c1ff67d
|
@ -94,12 +94,12 @@ class LLVM_PointerTo<Type pointee> : Type<
|
|||
"LLVM pointer to " # pointee.summary>;
|
||||
|
||||
// Type constraints accepting LLVM pointer type to integer of a specific width.
|
||||
class LLVM_IntPtrBase<int width> : Type<
|
||||
class LLVM_IntPtrBase<int width, int addressSpace = 0> : Type<
|
||||
LLVM_PointerTo<I<width>>.predicate,
|
||||
"LLVM pointer to " # I<width>.summary>,
|
||||
BuildableType<"::mlir::LLVM::LLVMPointerType::get("
|
||||
"::mlir::IntegerType::get($_builder.getContext(), "
|
||||
# width #"))">;
|
||||
# width #"), "# addressSpace #")">;
|
||||
|
||||
def LLVM_i8Ptr : LLVM_IntPtrBase<8>;
|
||||
|
||||
|
|
|
@ -16,6 +16,9 @@
|
|||
include "mlir/Dialect/LLVMIR/LLVMOpBase.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
def LLVM_i8Ptr_global : LLVM_IntPtrBase<8, 1>;
|
||||
def LLVM_i8Ptr_shared : LLVM_IntPtrBase<8, 3>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NVVM dialect definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -157,6 +160,56 @@ def NVVM_VoteBallotOp :
|
|||
let printer = [{ printNVVMIntrinsicOp(p, this->getOperation()); }];
|
||||
}
|
||||
|
||||
|
||||
def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
|
||||
Arguments<(ins LLVM_i8Ptr_shared:$dst,
|
||||
LLVM_i8Ptr_global:$src,
|
||||
I32Attr:$size)> {
|
||||
string llvmBuilder = [{
|
||||
llvm::Intrinsic::ID id;
|
||||
switch ($size) {
|
||||
case 4:
|
||||
id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_4;
|
||||
break;
|
||||
case 8:
|
||||
id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_8;
|
||||
break;
|
||||
case 16:
|
||||
id = llvm::Intrinsic::nvvm_cp_async_ca_shared_global_16;
|
||||
break;
|
||||
default:
|
||||
llvm_unreachable("unsupported async copy size");
|
||||
}
|
||||
createIntrinsicCall(builder, id, {$dst, $src});
|
||||
}];
|
||||
let verifier = [{
|
||||
if (size() != 4 && size() != 8 && size() != 16)
|
||||
return emitError("expected byte size to be either 4, 8 or 16.");
|
||||
return success();
|
||||
}];
|
||||
let assemblyFormat = "$dst `,` $src `,` $size attr-dict";
|
||||
}
|
||||
|
||||
def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> {
|
||||
string llvmBuilder = [{
|
||||
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_cp_async_commit_group);
|
||||
}];
|
||||
let assemblyFormat = "attr-dict";
|
||||
}
|
||||
|
||||
def NVVM_CpAsyncWaitGroupOp : NVVM_Op<"cp.async.wait.group">,
|
||||
Arguments<(ins I32Attr:$n)> {
|
||||
string llvmBuilder = [{
|
||||
createIntrinsicCall(
|
||||
builder,
|
||||
llvm::Intrinsic::nvvm_cp_async_wait_group,
|
||||
llvm::ConstantInt::get(
|
||||
llvm::Type::getInt32Ty(moduleTranslation.getLLVMContext()),
|
||||
$n));
|
||||
}];
|
||||
let assemblyFormat = "$n attr-dict";
|
||||
}
|
||||
|
||||
def NVVM_MmaOp :
|
||||
NVVM_Op<"mma.sync">,
|
||||
Results<(outs LLVM_Type:$res)>,
|
||||
|
|
|
@ -1226,3 +1226,11 @@ func @bitcast(%arg0: vector<2x3xf32>) {
|
|||
llvm.bitcast %arg0 : vector<2x3xf32> to vector<2x3xi32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
|
||||
// expected-error @below {{expected byte size to be either 4, 8 or 16.}}
|
||||
nvvm.cp.async.shared.global %arg0, %arg1, 32
|
||||
return
|
||||
}
|
||||
|
|
|
@ -95,6 +95,15 @@ func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 : i32,
|
|||
llvm.return %r : !llvm.struct<(f32, f32, f32, f32, f32, f32, f32, f32)>
|
||||
}
|
||||
|
||||
llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
|
||||
// CHECK: nvvm.cp.async.shared.global %{{.*}}, %{{.*}}, 16
|
||||
nvvm.cp.async.shared.global %arg0, %arg1, 16
|
||||
// CHECK: nvvm.cp.async.commit.group
|
||||
nvvm.cp.async.commit.group
|
||||
// CHECK: nvvm.cp.async.wait.group 0
|
||||
nvvm.cp.async.wait.group 0
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
|
|
|
@ -162,6 +162,20 @@ llvm.func @nvvm_wmma_mma(%0 : i32, %1 : i32, %2 : i32, %3 : i32, %4 : i32, %5 :
|
|||
llvm.return
|
||||
}
|
||||
|
||||
llvm.func @cp_async(%arg0: !llvm.ptr<i8, 3>, %arg1: !llvm.ptr<i8, 1>) {
|
||||
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.4(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
|
||||
nvvm.cp.async.shared.global %arg0, %arg1, 4
|
||||
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.8(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
|
||||
nvvm.cp.async.shared.global %arg0, %arg1, 8
|
||||
// CHECK: call void @llvm.nvvm.cp.async.ca.shared.global.16(i8 addrspace(3)* %{{.*}}, i8 addrspace(1)* %{{.*}})
|
||||
nvvm.cp.async.shared.global %arg0, %arg1, 16
|
||||
// CHECK: call void @llvm.nvvm.cp.async.commit.group()
|
||||
nvvm.cp.async.commit.group
|
||||
// CHECK: call void @llvm.nvvm.cp.async.wait.group(i32 0)
|
||||
nvvm.cp.async.wait.group 0
|
||||
llvm.return
|
||||
}
|
||||
|
||||
// This function has the "kernel" attribute attached and should appear in the
|
||||
// NVVM annotations after conversion.
|
||||
llvm.func @kernel_func() attributes {nvvm.kernel} {
|
||||
|
|
Loading…
Reference in New Issue