forked from OSchip/llvm-project
[AMDGPU] Match signed dot4/8 pattern.
Summary: This patch matches signed dot4 and dot8 pattern. Author: FarhanaAleen Reviewed By: msearles Differential Revision: https://reviews.llvm.org/D52520 llvm-svn: 343798
This commit is contained in:
parent
8920428376
commit
4bc597bff5
|
@ -165,34 +165,40 @@ def V_FMA_MIXHI_F16 : VOP3_VOP3PInst<"v_fma_mixhi_f16", VOP3_Profile<VOP_F16_F16
|
|||
defm : MadFmaMixPats<fma, V_FMA_MIX_F32, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>;
|
||||
}
|
||||
|
||||
class Srl<int N> : PatFrag<(ops node:$src),
|
||||
(srl node:$src, (i32 N))>;
|
||||
// Defines patterns that extract signed 4bit from each Idx[0].
|
||||
foreach Idx = [[0,28],[4,24],[8,20],[12,16],[16,12],[20,8],[24,4]] in
|
||||
def ExtractSigned4bit_#Idx[0] : PatFrag<(ops node:$src),
|
||||
(sra (shl node:$src, (i32 Idx[1])), (i32 28))>;
|
||||
|
||||
foreach Bits = 1-7 in
|
||||
def srl#!shl(Bits, 2) : Srl<!shl(Bits, 2)>;
|
||||
|
||||
class Extract_U<int FromBitIndex, int BitMask> : PatFrag<
|
||||
// Defines code pattern that extracts U(unsigned/signed) 4/8bit from FromBitIndex.
|
||||
class Extract<int FromBitIndex, int BitMask, bit U>: PatFrag<
|
||||
(ops node:$src),
|
||||
!if (!or (!and (!eq (BitMask, 255), !eq (FromBitIndex, 24)),
|
||||
!and (!eq (BitMask, 15), !eq (FromBitIndex, 28))), // last element
|
||||
(!cast<Srl>("srl"#FromBitIndex) node:$src),
|
||||
!if (!or (!and (!eq (BitMask, 255), !eq (FromBitIndex, 24)), !eq (FromBitIndex, 28)), // last element
|
||||
!if (U, (srl node:$src, (i32 FromBitIndex)), (sra node:$src, (i32 FromBitIndex))),
|
||||
!if (!eq (FromBitIndex, 0), // first element
|
||||
(and node:$src, (i32 BitMask)),
|
||||
(and (!cast<Srl>("srl"#FromBitIndex) node:$src), (i32 BitMask))))>;
|
||||
!if (U, (and node:$src, (i32 BitMask)),
|
||||
!if (!eq (BitMask, 15), (!cast<PatFrag>("ExtractSigned4bit_"#FromBitIndex) node:$src),
|
||||
(sext_inreg node:$src, i8))),
|
||||
!if (U, (and (srl node:$src, (i32 FromBitIndex)), (i32 BitMask)),
|
||||
!if (!eq (BitMask, 15), (!cast<PatFrag>("ExtractSigned4bit_"#FromBitIndex) node:$src),
|
||||
(sext_inreg (srl node:$src, (i32 FromBitIndex)), i8)))))>;
|
||||
|
||||
foreach Index = 0-3 in {
|
||||
|
||||
foreach Type = ["I", "U"] in
|
||||
foreach Index = 0-3 in {
|
||||
// Defines patterns that extract each Index'ed 8bit from an unsigned
|
||||
// 32bit scalar value;
|
||||
def U#Index#"_8bit" : Extract_U<!shl(Index, 3),
|
||||
255>;
|
||||
def #Type#Index#"_8bit" : Extract<!shl(Index, 3), 255, !if (!eq (Type, "U"), 1, 0)>;
|
||||
|
||||
// Defines multiplication patterns where the multiplication is happening on each
|
||||
// Index'ed 8bit of a 32bit scalar value.
|
||||
def MulU_Elt#Index : PatFrag<
|
||||
|
||||
def Mul#Type#_Elt#Index : PatFrag<
|
||||
(ops node:$src0, node:$src1),
|
||||
(AMDGPUmul_u24_oneuse (!cast<Extract_U>("U"#Index#"_8bit") node:$src0),
|
||||
(!cast<Extract_U>("U"#Index#"_8bit") node:$src1))>;
|
||||
}
|
||||
(!cast<HasOneUseBinOp>(!if (!eq (Type, "I"), AMDGPUmul_i24_oneuse, AMDGPUmul_u24_oneuse))
|
||||
(!cast<Extract>(#Type#Index#"_8bit") node:$src0),
|
||||
(!cast<Extract>(#Type#Index#"_8bit") node:$src1))>;
|
||||
}
|
||||
|
||||
// Different variants of dot8 patterns cause a huge increase in the compile time.
|
||||
// Define non-associative/commutative add/mul to prevent permutation in the dot8
|
||||
|
@ -203,19 +209,23 @@ def NonACAdd_oneuse : HasOneUseBinOp<NonACAdd>;
|
|||
def NonACAMDGPUmul_u24 : SDNode<"AMDGPUISD::MUL_U24" , SDTIntBinOp>;
|
||||
def NonACAMDGPUmul_u24_oneuse : HasOneUseBinOp<NonACAMDGPUmul_u24>;
|
||||
|
||||
foreach Index = 0-7 in {
|
||||
def NonACAMDGPUmul_i24 : SDNode<"AMDGPUISD::MUL_I24" , SDTIntBinOp>;
|
||||
def NonACAMDGPUmul_i24_oneuse : HasOneUseBinOp<NonACAMDGPUmul_i24>;
|
||||
|
||||
foreach Type = ["I", "U"] in
|
||||
foreach Index = 0-7 in {
|
||||
// Defines patterns that extract each Index'ed 4bit from an unsigned
|
||||
// 32bit scalar value;
|
||||
def U#Index#"_4bit" : Extract_U<!shl(Index, 2),
|
||||
15>;
|
||||
def #Type#Index#"_4bit" : Extract<!shl(Index, 2), 15, !if (!eq (Type, "U"), 1, 0)>;
|
||||
|
||||
// Defines multiplication patterns where the multiplication is happening on each
|
||||
// Index'ed 8bit of a 32bit scalar value.
|
||||
def MulU#Index#"_4bit" : PatFrag<
|
||||
def Mul#Type#Index#"_4bit" : PatFrag<
|
||||
(ops node:$src0, node:$src1),
|
||||
(NonACAMDGPUmul_u24_oneuse (!cast<Extract_U>("U"#Index#"_4bit") node:$src0),
|
||||
(!cast<Extract_U>("U"#Index#"_4bit") node:$src1))>;
|
||||
}
|
||||
(!cast<HasOneUseBinOp>(!if (!eq (Type, "I"), NonACAMDGPUmul_i24_oneuse, NonACAMDGPUmul_u24_oneuse))
|
||||
(!cast<Extract>(#Type#Index#"_4bit") node:$src0),
|
||||
(!cast<Extract>(#Type#Index#"_4bit") node:$src1))>;
|
||||
}
|
||||
|
||||
class UDot2Pat<Instruction Inst> : GCNPat <
|
||||
(add (add_oneuse (AMDGPUmul_u24_oneuse (srl i32:$src0, (i32 16)),
|
||||
|
@ -264,17 +274,18 @@ defm : DotPats<int_amdgcn_udot8, V_DOT8_U32_U4>;
|
|||
def : UDot2Pat<V_DOT2_U32_U16>;
|
||||
def : SDot2Pat<V_DOT2_I32_I16>;
|
||||
|
||||
def : GCNPat <
|
||||
foreach Type = ["U", "I"] in
|
||||
def : GCNPat <
|
||||
!cast<dag>(!foldl((i32 i32:$src2), [0, 1, 2, 3], lhs, y,
|
||||
(add_oneuse lhs, (!cast<PatFrag>("MulU_Elt"#y) i32:$src0, i32:$src1)))),
|
||||
(V_DOT4_U32_U8 (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))
|
||||
>;
|
||||
(add_oneuse lhs, (!cast<PatFrag>("Mul"#Type#"_Elt"#y) i32:$src0, i32:$src1)))),
|
||||
(!cast<VOP3PInst>("V_DOT4_"#Type#"32_"#Type#8) (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))>;
|
||||
|
||||
def : GCNPat <
|
||||
!cast<dag>(!foldl((add_oneuse i32:$src2, (MulU0_4bit i32:$src0, i32:$src1)), [1, 2, 3, 4, 5, 6, 7], lhs, y,
|
||||
(NonACAdd_oneuse lhs, (!cast<PatFrag>("MulU"#y#"_4bit") i32:$src0, i32:$src1)))),
|
||||
(V_DOT8_U32_U4 (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))
|
||||
>;
|
||||
foreach Type = ["U", "I"] in
|
||||
def : GCNPat <
|
||||
!cast<dag>(!foldl((add_oneuse i32:$src2, (!cast<PatFrag>("Mul"#Type#"0_4bit") i32:$src0, i32:$src1)),
|
||||
[1, 2, 3, 4, 5, 6, 7], lhs, y,
|
||||
(NonACAdd_oneuse lhs, (!cast<PatFrag>("Mul"#Type#y#"_4bit") i32:$src0, i32:$src1)))),
|
||||
(!cast<VOP3PInst>("V_DOT8_"#Type#"32_"#Type#4) (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))>;
|
||||
|
||||
} // End SubtargetPredicate = HasDLInsts
|
||||
|
||||
|
|
File diff suppressed because it is too large
Load Diff
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue