forked from OSchip/llvm-project
[AMDGPU] Match udot8 pattern
Summary: D.u32 = S0.u4[0] * S1.u4[0] + S0.u4[1] * S1.u4[1] + S0.u4[2] * S1.u4[2] + S0.u4[3] * S1.u4[3] + S0.u4[4] * S1.u4[4] + S0.u4[5] * S1.u4[5] + S0.u4[6] * S1.u4[6] + S0.u4[7] * S1.u4[7] + S2.u32 Author: FarhanaAleen Reviewed By: arsenm, nhaehnle Differential Revision: https://reviews.llvm.org/D51947 llvm-svn: 342497
This commit is contained in:
parent
b7471814cf
commit
f5a2848376
|
@ -168,34 +168,53 @@ defm : MadFmaMixPats<fma, V_FMA_MIX_F32, V_FMA_MIXLO_F16, V_FMA_MIXHI_F16>;
|
|||
class Srl<int N> : PatFrag<(ops node:$src),
|
||||
(srl node:$src, (i32 N))>;
|
||||
|
||||
foreach Bits = [8, 16, 24] in {
|
||||
def srl#Bits : Srl<Bits>;
|
||||
}
|
||||
foreach Bits = 1-7 in
|
||||
def srl#!shl(Bits, 2) : Srl<!shl(Bits, 2)>;
|
||||
|
||||
def and_255 : PatFrag<
|
||||
(ops node:$src0), (and node:$src0, (i32 255))
|
||||
>;
|
||||
|
||||
class Extract_U8<int FromBitIndex> : PatFrag<(
|
||||
ops node:$src),
|
||||
!if (!eq (FromBitIndex, 24), // last element
|
||||
class Extract_U<int FromBitIndex, int BitMask> : PatFrag<
|
||||
(ops node:$src),
|
||||
!if (!or (!and (!eq (BitMask, 255), !eq (FromBitIndex, 24)),
|
||||
!and (!eq (BitMask, 15), !eq (FromBitIndex, 28))), // last element
|
||||
(!cast<Srl>("srl"#FromBitIndex) node:$src),
|
||||
!if (!eq (FromBitIndex, 0), // first element
|
||||
(and_255 node:$src),
|
||||
(and_255 (!cast<Srl>("srl"#FromBitIndex) node:$src))))>;
|
||||
(and node:$src, (i32 BitMask)),
|
||||
(and (!cast<Srl>("srl"#FromBitIndex) node:$src), (i32 BitMask))))>;
|
||||
|
||||
// Defines patterns that extract each Index'ed 8bit from a 32bit scalar value;
|
||||
foreach Index = [1, 2, 3, 4] in {
|
||||
def UElt#Index : Extract_U8<!shl(!add(Index, -1), 3)>;
|
||||
}
|
||||
foreach Index = 0-3 in {
|
||||
// Defines patterns that extract each Index'ed 8bit from an unsigned
|
||||
// 32bit scalar value;
|
||||
def U#Index#"_8bit" : Extract_U<!shl(Index, 3),
|
||||
255>;
|
||||
|
||||
// Defines multiplication patterns where the multiplication is happening on each
|
||||
// Index'ed 8bit of a 32bit scalar value.
|
||||
foreach Index = [1, 2, 3, 4] in {
|
||||
def MulU_Elt#Index : PatFrag<
|
||||
(ops node:$src0, node:$src1),
|
||||
(AMDGPUmul_u24_oneuse (!cast<Extract_U8>("UElt"#Index) node:$src0),
|
||||
(!cast<Extract_U8>("UElt"#Index) node:$src1))>;
|
||||
(AMDGPUmul_u24_oneuse (!cast<Extract_U>("U"#Index#"_8bit") node:$src0),
|
||||
(!cast<Extract_U>("U"#Index#"_8bit") node:$src1))>;
|
||||
}
|
||||
|
||||
// Different variants of dot8 patterns cause a huge increase in the compile time.
|
||||
// Define non-associative/commutative add/mul to prevent permutation in the dot8
|
||||
// pattern.
|
||||
def NonACAdd : SDNode<"ISD::ADD" , SDTIntBinOp>;
|
||||
def NonACAdd_oneuse : HasOneUseBinOp<NonACAdd>;
|
||||
|
||||
def NonACAMDGPUmul_u24 : SDNode<"AMDGPUISD::MUL_U24" , SDTIntBinOp>;
|
||||
def NonACAMDGPUmul_u24_oneuse : HasOneUseBinOp<NonACAMDGPUmul_u24>;
|
||||
|
||||
foreach Index = 0-7 in {
|
||||
// Defines patterns that extract each Index'ed 4bit from an unsigned
|
||||
// 32bit scalar value;
|
||||
def U#Index#"_4bit" : Extract_U<!shl(Index, 2),
|
||||
15>;
|
||||
|
||||
// Defines multiplication patterns where the multiplication is happening on each
|
||||
// Index'ed 8bit of a 32bit scalar value.
|
||||
def MulU#Index#"_4bit" : PatFrag<
|
||||
(ops node:$src0, node:$src1),
|
||||
(NonACAMDGPUmul_u24_oneuse (!cast<Extract_U>("U"#Index#"_4bit") node:$src0),
|
||||
(!cast<Extract_U>("U"#Index#"_4bit") node:$src1))>;
|
||||
}
|
||||
|
||||
class UDot2Pat<Instruction Inst> : GCNPat <
|
||||
|
@ -246,11 +265,17 @@ def : UDot2Pat<V_DOT2_U32_U16>;
|
|||
def : SDot2Pat<V_DOT2_I32_I16>;
|
||||
|
||||
def : GCNPat <
|
||||
!cast<dag>(!foldl((i32 i32:$src2), [1, 2, 3, 4], lhs, y,
|
||||
!cast<dag>(!foldl((i32 i32:$src2), [0, 1, 2, 3], lhs, y,
|
||||
(add_oneuse lhs, (!cast<PatFrag>("MulU_Elt"#y) i32:$src0, i32:$src1)))),
|
||||
(V_DOT4_U32_U8 (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))
|
||||
>;
|
||||
|
||||
def : GCNPat <
|
||||
!cast<dag>(!foldl((add_oneuse i32:$src2, (MulU0_4bit i32:$src0, i32:$src1)), [1, 2, 3, 4, 5, 6, 7], lhs, y,
|
||||
(NonACAdd_oneuse lhs, (!cast<PatFrag>("MulU"#y#"_4bit") i32:$src0, i32:$src1)))),
|
||||
(V_DOT8_U32_U4 (i32 8), $src0, (i32 8), $src1, (i32 8), $src2, (i1 0))
|
||||
>;
|
||||
|
||||
} // End SubtargetPredicate = HasDLInsts
|
||||
|
||||
multiclass VOP3P_Real_vi<bits<10> op> {
|
||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue