[mlir][LLVMIR] Add vector predication binary intrinsic ops.

Differential Revision: https://reviews.llvm.org/D122971
This commit is contained in:
jacquesguan 2022-04-02 17:35:51 +08:00
parent b389354b28
commit 2420d42925
3 changed files with 120 additions and 0 deletions

View File

@ -365,4 +365,14 @@ class LLVM_VectorReductionAcc<string mnem>
}];
}
// LLVM vector predication intrinsics.
class LLVM_VPBinaryBase<string mnem, Type element>
: LLVM_OneResultIntrOp<"vp." # mnem, [0], [], [NoSideEffect]>,
Arguments<(ins LLVM_VectorOf<element>:$lhs, LLVM_VectorOf<element>:$rhs,
LLVM_VectorOf<I1>:$mask, I32:$evl)>;
class LLVM_VPBinaryI<string mnem> : LLVM_VPBinaryBase<mnem, AnyInteger>;
class LLVM_VPBinaryF<string mnem> : LLVM_VPBinaryBase<mnem, AnyFloat>;
#endif // LLVMIR_OP_BASE

View File

@ -1921,4 +1921,34 @@ def LLVM_InlineAsmOp : LLVM_Op<"inline_asm", []> {
}
}];
}
//
// LLVM Vector Predication operations.
//
// Integer Binary
def LLVM_VPAddOp : LLVM_VPBinaryI<"add">;
def LLVM_VPSubOp : LLVM_VPBinaryI<"sub">;
def LLVM_VPMulOp : LLVM_VPBinaryI<"mul">;
def LLVM_VPSDivOp : LLVM_VPBinaryI<"sdiv">;
def LLVM_VPUDivOp : LLVM_VPBinaryI<"udiv">;
def LLVM_VPSRemOp : LLVM_VPBinaryI<"srem">;
def LLVM_VPURemOp : LLVM_VPBinaryI<"urem">;
def LLVM_VPAShrOp : LLVM_VPBinaryI<"ashr">;
def LLVM_VPLShrOp : LLVM_VPBinaryI<"lshr">;
def LLVM_VPShlOp : LLVM_VPBinaryI<"shl">;
def LLVM_VPOrOp : LLVM_VPBinaryI<"or">;
def LLVM_VPAndOp : LLVM_VPBinaryI<"and">;
def LLVM_VPXorOp : LLVM_VPBinaryI<"xor">;
// Float Binary
def LLVM_VPFAddOp : LLVM_VPBinaryF<"fadd">;
def LLVM_VPFSubOp : LLVM_VPBinaryF<"fsub">;
def LLVM_VPFMulOp : LLVM_VPBinaryF<"fmul">;
def LLVM_VPFDivOp : LLVM_VPBinaryF<"fdiv">;
def LLVM_VPFRemOp : LLVM_VPBinaryF<"frem">;
#endif // LLVMIR_OPS

View File

@ -515,6 +515,68 @@ llvm.func @stack_restore(%arg0: !llvm.ptr<i8>) {
llvm.return
}
// CHECK-LABEL: @vector_predication_intrinsics
llvm.func @vector_predication_intrinsics(%A: vector<8xi32>, %B: vector<8xi32>,
%C: vector<8xf32>, %D: vector<8xf32>,
%mask: vector<8xi1>, %evl: i32) {
// CHECK: call <8 x i32> @llvm.vp.add.v8i32
"llvm.intr.vp.add" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x i32> @llvm.vp.sub.v8i32
"llvm.intr.vp.sub" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x i32> @llvm.vp.mul.v8i32
"llvm.intr.vp.mul" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x i32> @llvm.vp.sdiv.v8i32
"llvm.intr.vp.sdiv" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x i32> @llvm.vp.udiv.v8i32
"llvm.intr.vp.udiv" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x i32> @llvm.vp.srem.v8i32
"llvm.intr.vp.srem" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x i32> @llvm.vp.urem.v8i32
"llvm.intr.vp.urem" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x i32> @llvm.vp.ashr.v8i32
"llvm.intr.vp.ashr" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x i32> @llvm.vp.lshr.v8i32
"llvm.intr.vp.lshr" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x i32> @llvm.vp.shl.v8i32
"llvm.intr.vp.shl" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x i32> @llvm.vp.or.v8i32
"llvm.intr.vp.or" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x i32> @llvm.vp.and.v8i32
"llvm.intr.vp.and" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x i32> @llvm.vp.xor.v8i32
"llvm.intr.vp.xor" (%A, %B, %mask, %evl) :
(vector<8xi32>, vector<8xi32>, vector<8xi1>, i32) -> vector<8xi32>
// CHECK: call <8 x float> @llvm.vp.fadd.v8f32
"llvm.intr.vp.fadd" (%C, %D, %mask, %evl) :
(vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32>
// CHECK: call <8 x float> @llvm.vp.fsub.v8f32
"llvm.intr.vp.fsub" (%C, %D, %mask, %evl) :
(vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32>
// CHECK: call <8 x float> @llvm.vp.fmul.v8f32
"llvm.intr.vp.fmul" (%C, %D, %mask, %evl) :
(vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32>
// CHECK: call <8 x float> @llvm.vp.fdiv.v8f32
"llvm.intr.vp.fdiv" (%C, %D, %mask, %evl) :
(vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32>
// CHECK: call <8 x float> @llvm.vp.frem.v8f32
"llvm.intr.vp.frem" (%C, %D, %mask, %evl) :
(vector<8xf32>, vector<8xf32>, vector<8xi1>, i32) -> vector<8xf32>
llvm.return
}
// Check that intrinsics are declared with appropriate types.
// CHECK-DAG: declare float @llvm.fma.f32(float, float, float)
// CHECK-DAG: declare <8 x float> @llvm.fma.v8f32(<8 x float>, <8 x float>, <8 x float>) #0
@ -570,3 +632,21 @@ llvm.func @stack_restore(%arg0: !llvm.ptr<i8>) {
// CHECK-DAG: declare i1 @llvm.coro.end(i8*, i1)
// CHECK-DAG: declare i8* @llvm.coro.free(token, i8* nocapture readonly)
// CHECK-DAG: declare void @llvm.coro.resume(i8*)
// CHECK-DAG: declare <8 x i32> @llvm.vp.add.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x i32> @llvm.vp.sub.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x i32> @llvm.vp.mul.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x i32> @llvm.vp.sdiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #2
// CHECK-DAG: declare <8 x i32> @llvm.vp.udiv.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #2
// CHECK-DAG: declare <8 x i32> @llvm.vp.srem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #2
// CHECK-DAG: declare <8 x i32> @llvm.vp.urem.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #2
// CHECK-DAG: declare <8 x i32> @llvm.vp.ashr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x i32> @llvm.vp.lshr.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x i32> @llvm.vp.shl.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x i32> @llvm.vp.or.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x i32> @llvm.vp.and.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x i32> @llvm.vp.xor.v8i32(<8 x i32>, <8 x i32>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x float> @llvm.vp.fadd.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x float> @llvm.vp.fsub.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x float> @llvm.vp.fmul.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x float> @llvm.vp.fdiv.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0
// CHECK-DAG: declare <8 x float> @llvm.vp.frem.v8f32(<8 x float>, <8 x float>, <8 x i1>, i32) #0