[mlir] Add bar.warp.sync to NVVM

It adds the missing `bar.warp.sync` to the nvvm dialect. It is a barrier to synchronize for threads in a warp.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D135253
This commit is contained in:
Guray Ozen 2022-10-05 12:23:32 +02:00
parent 1b9a6e58a8
commit 040805dc47
2 changed files with 16 additions and 0 deletions

View File

@ -168,6 +168,15 @@ def NVVM_VoteBallotOp :
let hasCustomAssemblyFormat = 1;
}
def NVVM_SyncWarpOp :
NVVM_Op<"bar.warp.sync">,
Arguments<(ins LLVM_Type:$mask)> {
string llvmBuilder = [{
createIntrinsicCall(builder, llvm::Intrinsic::nvvm_bar_warp_sync, {$mask});
}];
let assemblyFormat = "$mask attr-dict `:` type($mask)";
}
def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
Arguments<(ins LLVM_i8Ptr_shared:$dst,

View File

@ -78,6 +78,13 @@ func.func @nvvm_vote(%arg0 : i32, %arg1 : i1) -> i32 {
llvm.return %0 : i32
}
// CHECK-LABEL: @llvm_nvvm_bar_warp_sync
func.func @llvm_nvvm_bar_warp_sync(%mask : i32) {
// CHECK: nvvm.bar.warp.sync %{{.*}}
nvvm.bar.warp.sync %mask : i32
llvm.return
}
// CHECK-LABEL: @nvvm_mma_m8n8k4_row_col_f32_f32
func.func @nvvm_mma_m8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
%b0 : vector<2xf16>, %b1 : vector<2xf16>,