[WebAssembly] Custom optimization for truncate
When possible, optimize TRUNCATE to generate Wasm SIMD narrow instructions (i16x8.narrow_i32x4_u, i8x16.narrow_i16x8_u) rather than emitting long chains of extract_lane and replace_lane. Closes #50350.
commit 2a4a229d6d (parent 0b9b1c8c49)
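
Background for the change: i16x8.narrow_i32x4_u and i8x16.narrow_i16x8_u pack the lanes of two input vectors into one vector of half-width lanes using unsigned saturation, so they only behave like a plain lane-wise truncate when every input lane already fits in the narrow type. A minimal scalar sketch of one lane, assuming only the published Wasm SIMD semantics (plain C++; the helper name narrowLaneU is ours, not LLVM's):

  #include <cstdint>
  #include <cstdio>

  // One output lane of i16x8.narrow_i32x4_u: clamp an i32 lane to
  // [0, 65535], then emit it as an unsigned 16-bit lane.
  static uint16_t narrowLaneU(int32_t V) {
    if (V < 0)
      return 0;
    if (V > 0xFFFF)
      return 0xFFFF;
    return static_cast<uint16_t>(V);
  }

  int main() {
    printf("%u\n", narrowLaneU(70000));          // saturates to 65535
    printf("%u\n", narrowLaneU(-5));             // saturates to 0
    printf("%u\n", narrowLaneU(70000 & 0xFFFF)); // 4464: masked first, a true truncate
  }

Masking each lane down to the destination width up front is exactly what the new combine does before emitting NARROW_U, so the saturation can never fire.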
llvm/lib/Target/WebAssembly/WebAssemblyISD.def
@@ -31,6 +31,7 @@ HANDLE_NODETYPE(SWIZZLE)
 HANDLE_NODETYPE(VEC_SHL)
 HANDLE_NODETYPE(VEC_SHR_S)
 HANDLE_NODETYPE(VEC_SHR_U)
+HANDLE_NODETYPE(NARROW_U)
 HANDLE_NODETYPE(EXTEND_LOW_S)
 HANDLE_NODETYPE(EXTEND_LOW_U)
 HANDLE_NODETYPE(EXTEND_HIGH_S)
llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp
@@ -176,6 +176,8 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
   setTargetDAGCombine(ISD::FP_ROUND);
   setTargetDAGCombine(ISD::CONCAT_VECTORS);
 
+  setTargetDAGCombine(ISD::TRUNCATE);
+
   // Support saturating add for i8x16 and i16x8
   for (auto Op : {ISD::SADDSAT, ISD::UADDSAT})
     for (auto T : {MVT::v16i8, MVT::v8i16})
@@ -2609,6 +2611,114 @@ performVectorTruncZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
   return DAG.getNode(Op, SDLoc(N), ResVT, Source);
 }
 
+// Helper to extract VectorWidth bits from Vec, starting from IdxVal.
+static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
+                                const SDLoc &DL, unsigned VectorWidth) {
+  EVT VT = Vec.getValueType();
+  EVT ElVT = VT.getVectorElementType();
+  unsigned Factor = VT.getSizeInBits() / VectorWidth;
+  EVT ResultVT = EVT::getVectorVT(*DAG.getContext(), ElVT,
+                                  VT.getVectorNumElements() / Factor);
+
+  // Extract the relevant VectorWidth bits. Generate an EXTRACT_SUBVECTOR
+  unsigned ElemsPerChunk = VectorWidth / ElVT.getSizeInBits();
+  assert(isPowerOf2_32(ElemsPerChunk) && "Elements per chunk not power of 2");
+
+  // This is the index of the first element of the VectorWidth-bit chunk
+  // we want. Since ElemsPerChunk is a power of 2 just need to clear bits.
+  IdxVal &= ~(ElemsPerChunk - 1);
+
+  // If the input is a buildvector just emit a smaller one.
+  if (Vec.getOpcode() == ISD::BUILD_VECTOR)
+    return DAG.getBuildVector(ResultVT, DL,
+                              Vec->ops().slice(IdxVal, ElemsPerChunk));
+
+  SDValue VecIdx = DAG.getIntPtrConstant(IdxVal, DL);
+  return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResultVT, Vec, VecIdx);
+}
+
+// Helper to recursively truncate vector elements in half with NARROW_U. DstVT
+// is the expected destination value type after recursion. In is the initial
+// input. Note that the input should have enough leading zero bits to prevent
+// NARROW_U from saturating results.
+static SDValue truncateVectorWithNARROW(EVT DstVT, SDValue In, const SDLoc &DL,
+                                        SelectionDAG &DAG) {
+  EVT SrcVT = In.getValueType();
+
+  // No truncation required, we might get here due to recursive calls.
+  if (SrcVT == DstVT)
+    return In;
+
+  unsigned SrcSizeInBits = SrcVT.getSizeInBits();
+  unsigned NumElems = SrcVT.getVectorNumElements();
+  if (!isPowerOf2_32(NumElems))
+    return SDValue();
+  assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation");
+  assert(SrcSizeInBits > DstVT.getSizeInBits() && "Illegal truncation");
+
+  LLVMContext &Ctx = *DAG.getContext();
+  EVT PackedSVT = EVT::getIntegerVT(Ctx, SrcVT.getScalarSizeInBits() / 2);
+
+  // Narrow to the largest type possible:
+  // vXi64/vXi32 -> i16x8.narrow_i32x4_u and vXi16 -> i8x16.narrow_i16x8_u.
+  EVT InVT = MVT::i16, OutVT = MVT::i8;
+  if (SrcVT.getScalarSizeInBits() > 16) {
+    InVT = MVT::i32;
+    OutVT = MVT::i16;
+  }
+  unsigned SubSizeInBits = SrcSizeInBits / 2;
+  InVT = EVT::getVectorVT(Ctx, InVT, SubSizeInBits / InVT.getSizeInBits());
+  OutVT = EVT::getVectorVT(Ctx, OutVT, SubSizeInBits / OutVT.getSizeInBits());
+
+  // Split lower/upper subvectors.
+  SDValue Lo = extractSubVector(In, 0, DAG, DL, SubSizeInBits);
+  SDValue Hi = extractSubVector(In, NumElems / 2, DAG, DL, SubSizeInBits);
+
+  // 256bit -> 128bit truncate - Narrow lower/upper 128-bit subvectors.
+  if (SrcVT.is256BitVector() && DstVT.is128BitVector()) {
+    Lo = DAG.getBitcast(InVT, Lo);
+    Hi = DAG.getBitcast(InVT, Hi);
+    SDValue Res = DAG.getNode(WebAssemblyISD::NARROW_U, DL, OutVT, Lo, Hi);
+    return DAG.getBitcast(DstVT, Res);
+  }
+
+  // Recursively narrow lower/upper subvectors, concat result and narrow again.
+  EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems / 2);
+  Lo = truncateVectorWithNARROW(PackedVT, Lo, DL, DAG);
+  Hi = truncateVectorWithNARROW(PackedVT, Hi, DL, DAG);
+
+  PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems);
+  SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, PackedVT, Lo, Hi);
+  return truncateVectorWithNARROW(DstVT, Res, DL, DAG);
+}
+
+static SDValue performTruncateCombine(SDNode *N,
+                                      TargetLowering::DAGCombinerInfo &DCI) {
+  auto &DAG = DCI.DAG;
+
+  SDValue In = N->getOperand(0);
+  EVT InVT = In.getValueType();
+  if (!InVT.isSimple())
+    return SDValue();
+
+  EVT OutVT = N->getValueType(0);
+  if (!OutVT.isVector())
+    return SDValue();
+
+  EVT OutSVT = OutVT.getVectorElementType();
+  EVT InSVT = InVT.getVectorElementType();
+  // Currently only cover truncate to v16i8 or v8i16.
+  if (!((InSVT == MVT::i16 || InSVT == MVT::i32 || InSVT == MVT::i64) &&
+        (OutSVT == MVT::i8 || OutSVT == MVT::i16) && OutVT.is128BitVector()))
+    return SDValue();
+
+  SDLoc DL(N);
+  APInt Mask = APInt::getLowBitsSet(InVT.getScalarSizeInBits(),
+                                    OutVT.getScalarSizeInBits());
+  In = DAG.getNode(ISD::AND, DL, InVT, In, DAG.getConstant(Mask, DL, InVT));
+  return truncateVectorWithNARROW(OutVT, In, DL, DAG);
+}
+
 SDValue
 WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
                                              DAGCombinerInfo &DCI) const {
@@ -2625,5 +2735,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
   case ISD::FP_ROUND:
   case ISD::CONCAT_VECTORS:
     return performVectorTruncZeroCombine(N, DCI);
+  case ISD::TRUNCATE:
+    return performTruncateCombine(N, DCI);
   }
 }
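
The shape of the new combine, as a scalar sketch (an illustration in plain C++, not the LLVM code above; names and lane values are ours): performTruncateCombine first ANDs the input with the low destination bits, then truncateVectorWithNARROW halves the lane width once per NARROW_U round. The real helper splits the vector with extractSubVector so that each NARROW_U sees two 128-bit operands, but the per-lane arithmetic matches this flat loop:

  #include <cstdint>
  #include <cstdio>
  #include <vector>

  // One lane of NARROW_U: unsigned saturation to the narrower lane width.
  static uint64_t narrowLaneU(uint64_t V, unsigned OutBits) {
    uint64_t Max = (1ULL << OutBits) - 1;
    return V > Max ? Max : V;
  }

  // Model of truncateVectorWithNARROW: halve the lane width each round until
  // DstBits is reached; recursion depth grows with the source lane width.
  static std::vector<uint64_t> truncateWithNarrow(std::vector<uint64_t> In,
                                                  unsigned SrcBits,
                                                  unsigned DstBits) {
    if (SrcBits == DstBits) // base case: no truncation required
      return In;
    for (uint64_t &V : In)
      V = narrowLaneU(V, SrcBits / 2);
    return truncateWithNarrow(In, SrcBits / 2, DstBits);
  }

  int main() {
    // Model of performTruncateCombine for: trunc <8 x i32> %v to <8 x i16>.
    std::vector<uint64_t> In = {0x12345678, 0x9ABCDEF0, 3, 0xFFFF0001, 5, 6, 7, 8};
    // Step 1: AND every lane with the low destination bits, so the saturating
    // narrows below can never clamp a lane.
    for (uint64_t &V : In)
      V &= 0xFFFF;
    // Step 2: narrow 32 -> 16 (one round here; i64 lanes would take two).
    for (uint64_t V : truncateWithNarrow(In, 32, 16))
      printf("0x%llx ", (unsigned long long)V);
    printf("\n"); // 0x5678 0xdef0 0x3 0x1 0x5 0x6 0x7 0x8
  }

This also shows why the combine is restricted to v16i8 and v8i16 results: those are the only lane shapes the two Wasm narrow instructions can produce.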
llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td
@@ -1278,6 +1278,14 @@ multiclass SIMDNarrow<Vec vec, bits<32> baseInst> {
 defm "" : SIMDNarrow<I16x8, 101>;
 defm "" : SIMDNarrow<I32x4, 133>;
 
+// WebAssemblyISD::NARROW_U
+def wasm_narrow_t : SDTypeProfile<1, 2, []>;
+def wasm_narrow_u : SDNode<"WebAssemblyISD::NARROW_U", wasm_narrow_t>;
+def : Pat<(v16i8 (wasm_narrow_u (v8i16 V128:$left), (v8i16 V128:$right))),
+          (NARROW_U_I8x16 $left, $right)>;
+def : Pat<(v8i16 (wasm_narrow_u (v4i32 V128:$left), (v4i32 V128:$right))),
+          (NARROW_U_I16x8 $left, $right)>;
+
 // Bitcasts are nops
 // Matching bitcast t1 to t1 causes strange errors, so avoid repeating types
 foreach t1 = AllVecs in
llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll
@@ -532,7 +532,7 @@ entry:
 define <8 x i16> @stest_f16i16(<8 x half> %x) {
 ; CHECK-LABEL: stest_f16i16:
 ; CHECK: .functype stest_f16i16 (f32, f32, f32, f32, f32, f32, f32, f32) -> (v128)
-; CHECK-NEXT: .local v128, v128
+; CHECK-NEXT: .local v128, v128, v128
 ; CHECK-NEXT: # %bb.0: # %entry
 ; CHECK-NEXT: local.get 5
 ; CHECK-NEXT: call __truncsfhf2
@@ -578,6 +578,9 @@ define <8 x i16> @stest_f16i16(<8 x half> %x) {
 ; CHECK-NEXT: v128.const -32768, -32768, -32768, -32768
 ; CHECK-NEXT: local.tee 9
 ; CHECK-NEXT: i32x4.max_s
+; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535
+; CHECK-NEXT: local.tee 10
+; CHECK-NEXT: v128.and
 ; CHECK-NEXT: local.get 4
 ; CHECK-NEXT: i32.trunc_sat_f32_s
 ; CHECK-NEXT: i32x4.splat
@@ -594,7 +597,9 @@ define <8 x i16> @stest_f16i16(<8 x half> %x) {
 ; CHECK-NEXT: i32x4.min_s
 ; CHECK-NEXT: local.get 9
 ; CHECK-NEXT: i32x4.max_s
-; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT: local.get 10
+; CHECK-NEXT: v128.and
+; CHECK-NEXT: i16x8.narrow_i32x4_u
 ; CHECK-NEXT: # fallthrough-return
 entry:
   %conv = fptosi <8 x half> %x to <8 x i32>
@@ -666,7 +671,7 @@ define <8 x i16> @utesth_f16i16(<8 x half> %x) {
 ; CHECK-NEXT: i32x4.replace_lane 3
 ; CHECK-NEXT: local.get 8
 ; CHECK-NEXT: i32x4.min_u
-; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT: i16x8.narrow_i32x4_u
 ; CHECK-NEXT: # fallthrough-return
 entry:
   %conv = fptoui <8 x half> %x to <8 x i32>
@@ -741,7 +746,7 @@ define <8 x i16> @ustest_f16i16(<8 x half> %x) {
 ; CHECK-NEXT: i32x4.min_s
 ; CHECK-NEXT: local.get 9
 ; CHECK-NEXT: i32x4.max_s
-; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT: i16x8.narrow_i32x4_u
 ; CHECK-NEXT: # fallthrough-return
 entry:
   %conv = fptosi <8 x half> %x to <8 x i32>
@@ -2106,7 +2111,7 @@ entry:
 define <8 x i16> @stest_f16i16_mm(<8 x half> %x) {
 ; CHECK-LABEL: stest_f16i16_mm:
 ; CHECK: .functype stest_f16i16_mm (f32, f32, f32, f32, f32, f32, f32, f32) -> (v128)
-; CHECK-NEXT: .local v128, v128
+; CHECK-NEXT: .local v128, v128, v128
 ; CHECK-NEXT: # %bb.0: # %entry
 ; CHECK-NEXT: local.get 5
 ; CHECK-NEXT: call __truncsfhf2
@@ -2152,6 +2157,9 @@ define <8 x i16> @stest_f16i16_mm(<8 x half> %x) {
 ; CHECK-NEXT: v128.const -32768, -32768, -32768, -32768
 ; CHECK-NEXT: local.tee 9
 ; CHECK-NEXT: i32x4.max_s
+; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535
+; CHECK-NEXT: local.tee 10
+; CHECK-NEXT: v128.and
 ; CHECK-NEXT: local.get 4
 ; CHECK-NEXT: i32.trunc_sat_f32_s
 ; CHECK-NEXT: i32x4.splat
@@ -2168,7 +2176,9 @@ define <8 x i16> @stest_f16i16_mm(<8 x half> %x) {
 ; CHECK-NEXT: i32x4.min_s
 ; CHECK-NEXT: local.get 9
 ; CHECK-NEXT: i32x4.max_s
-; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT: local.get 10
+; CHECK-NEXT: v128.and
+; CHECK-NEXT: i16x8.narrow_i32x4_u
 ; CHECK-NEXT: # fallthrough-return
 entry:
   %conv = fptosi <8 x half> %x to <8 x i32>
@@ -2238,7 +2248,7 @@ define <8 x i16> @utesth_f16i16_mm(<8 x half> %x) {
 ; CHECK-NEXT: i32x4.replace_lane 3
 ; CHECK-NEXT: local.get 8
 ; CHECK-NEXT: i32x4.min_u
-; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT: i16x8.narrow_i32x4_u
 ; CHECK-NEXT: # fallthrough-return
 entry:
   %conv = fptoui <8 x half> %x to <8 x i32>
@@ -2312,7 +2322,7 @@ define <8 x i16> @ustest_f16i16_mm(<8 x half> %x) {
 ; CHECK-NEXT: i32x4.min_s
 ; CHECK-NEXT: local.get 9
 ; CHECK-NEXT: i32x4.max_s
-; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
+; CHECK-NEXT: i16x8.narrow_i32x4_u
 ; CHECK-NEXT: # fallthrough-return
 entry:
   %conv = fptosi <8 x half> %x to <8 x i32>