[WebAssembly] Custom optimization for truncate

When possible, optimize TRUNCATE to generate Wasm SIMD narrow
instructions (i16x8.narrow_i32x4_u, i8x16.narrow_i16x8_u) rather than
generating long sequences of extract_lane and replace_lane (sketched below).

Closes #50350.
Jing Bao 2021-12-14 08:42:38 -08:00 committed by Thomas Lively
parent 0b9b1c8c49
commit 2a4a229d6d
4 changed files with 139 additions and 8 deletions
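
As an illustration (a minimal sketch; the function below is hypothetical and not taken from this commit's tests), a plain vector truncate such as

define <8 x i16> @trunc_v8i32(<8 x i32> %v) {
  %r = trunc <8 x i32> %v to <8 x i16>
  ret <8 x i16> %r
}

now lowers by masking each 128-bit half of the input with 65535, so the unsigned-saturating narrow cannot saturate, and then combining both halves with a single i16x8.narrow_i32x4_u rather than extracting and reinserting lanes one at a time.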

llvm/lib/Target/WebAssembly/WebAssemblyISD.def

@@ -31,6 +31,7 @@ HANDLE_NODETYPE(SWIZZLE)
HANDLE_NODETYPE(VEC_SHL)
HANDLE_NODETYPE(VEC_SHR_S)
HANDLE_NODETYPE(VEC_SHR_U)
HANDLE_NODETYPE(NARROW_U)
HANDLE_NODETYPE(EXTEND_LOW_S)
HANDLE_NODETYPE(EXTEND_LOW_U)
HANDLE_NODETYPE(EXTEND_HIGH_S)

llvm/lib/Target/WebAssembly/WebAssemblyISelLowering.cpp

@@ -176,6 +176,8 @@ WebAssemblyTargetLowering::WebAssemblyTargetLowering(
setTargetDAGCombine(ISD::FP_ROUND);
setTargetDAGCombine(ISD::CONCAT_VECTORS);
setTargetDAGCombine(ISD::TRUNCATE);

// Support saturating add for i8x16 and i16x8
for (auto Op : {ISD::SADDSAT, ISD::UADDSAT})
for (auto T : {MVT::v16i8, MVT::v8i16})
@@ -2609,6 +2611,114 @@ performVectorTruncZeroCombine(SDNode *N, TargetLowering::DAGCombinerInfo &DCI) {
return DAG.getNode(Op, SDLoc(N), ResVT, Source);
}

// Helper to extract VectorWidth bits from Vec, starting from IdxVal.
static SDValue extractSubVector(SDValue Vec, unsigned IdxVal, SelectionDAG &DAG,
const SDLoc &DL, unsigned VectorWidth) {
EVT VT = Vec.getValueType();
EVT ElVT = VT.getVectorElementType();
unsigned Factor = VT.getSizeInBits() / VectorWidth;
EVT ResultVT = EVT::getVectorVT(*DAG.getContext(), ElVT,
VT.getVectorNumElements() / Factor);
// Extract the relevant VectorWidth bits. Generate an EXTRACT_SUBVECTOR
unsigned ElemsPerChunk = VectorWidth / ElVT.getSizeInBits();
assert(isPowerOf2_32(ElemsPerChunk) && "Elements per chunk not power of 2");
// This is the index of the first element of the VectorWidth-bit chunk
// we want. Since ElemsPerChunk is a power of 2 just need to clear bits.
IdxVal &= ~(ElemsPerChunk - 1);
// If the input is a buildvector just emit a smaller one.
if (Vec.getOpcode() == ISD::BUILD_VECTOR)
return DAG.getBuildVector(ResultVT, DL,
Vec->ops().slice(IdxVal, ElemsPerChunk));
SDValue VecIdx = DAG.getIntPtrConstant(IdxVal, DL);
return DAG.getNode(ISD::EXTRACT_SUBVECTOR, DL, ResultVT, Vec, VecIdx);
}

// Helper to recursively truncate vector elements in half with NARROW_U. DstVT
// is the expected destination value type after recursion. In is the initial
// input. Note that the input should have enough leading zero bits to prevent
// NARROW_U from saturating results.
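// For example (shapes chosen purely for illustration): truncating v16i32 to
// v16i8 splits the 512-bit input into two v8i32 halves, narrows each half to
// v8i16, concatenates the results into a v16i16, and narrows once more to
// reach v16i8.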
static SDValue truncateVectorWithNARROW(EVT DstVT, SDValue In, const SDLoc &DL,
SelectionDAG &DAG) {
EVT SrcVT = In.getValueType();
// No truncation required, we might get here due to recursive calls.
if (SrcVT == DstVT)
return In;
unsigned SrcSizeInBits = SrcVT.getSizeInBits();
unsigned NumElems = SrcVT.getVectorNumElements();
if (!isPowerOf2_32(NumElems))
return SDValue();
assert(DstVT.getVectorNumElements() == NumElems && "Illegal truncation");
assert(SrcSizeInBits > DstVT.getSizeInBits() && "Illegal truncation");
LLVMContext &Ctx = *DAG.getContext();
EVT PackedSVT = EVT::getIntegerVT(Ctx, SrcVT.getScalarSizeInBits() / 2);
// Narrow to the largest type possible:
// vXi64/vXi32 -> i16x8.narrow_i32x4_u and vXi16 -> i8x16.narrow_i16x8_u.
EVT InVT = MVT::i16, OutVT = MVT::i8;
if (SrcVT.getScalarSizeInBits() > 16) {
InVT = MVT::i32;
OutVT = MVT::i16;
}
unsigned SubSizeInBits = SrcSizeInBits / 2;
InVT = EVT::getVectorVT(Ctx, InVT, SubSizeInBits / InVT.getSizeInBits());
OutVT = EVT::getVectorVT(Ctx, OutVT, SubSizeInBits / OutVT.getSizeInBits());
// Split lower/upper subvectors.
SDValue Lo = extractSubVector(In, 0, DAG, DL, SubSizeInBits);
SDValue Hi = extractSubVector(In, NumElems / 2, DAG, DL, SubSizeInBits);
// 256bit -> 128bit truncate - Narrow lower/upper 128-bit subvectors.
if (SrcVT.is256BitVector() && DstVT.is128BitVector()) {
Lo = DAG.getBitcast(InVT, Lo);
Hi = DAG.getBitcast(InVT, Hi);
SDValue Res = DAG.getNode(WebAssemblyISD::NARROW_U, DL, OutVT, Lo, Hi);
return DAG.getBitcast(DstVT, Res);
}
// Recursively narrow lower/upper subvectors, concat result and narrow again.
EVT PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems / 2);
Lo = truncateVectorWithNARROW(PackedVT, Lo, DL, DAG);
Hi = truncateVectorWithNARROW(PackedVT, Hi, DL, DAG);
PackedVT = EVT::getVectorVT(Ctx, PackedSVT, NumElems);
SDValue Res = DAG.getNode(ISD::CONCAT_VECTORS, DL, PackedVT, Lo, Hi);
return truncateVectorWithNARROW(DstVT, Res, DL, DAG);
}

static SDValue performTruncateCombine(SDNode *N,
TargetLowering::DAGCombinerInfo &DCI) {
auto &DAG = DCI.DAG;
SDValue In = N->getOperand(0);
EVT InVT = In.getValueType();
if (!InVT.isSimple())
return SDValue();
EVT OutVT = N->getValueType(0);
if (!OutVT.isVector())
return SDValue();
EVT OutSVT = OutVT.getVectorElementType();
EVT InSVT = InVT.getVectorElementType();
// Currently we only cover truncates to v16i8 or v8i16.
if (!((InSVT == MVT::i16 || InSVT == MVT::i32 || InSVT == MVT::i64) &&
(OutSVT == MVT::i8 || OutSVT == MVT::i16) && OutVT.is128BitVector()))
return SDValue();
SDLoc DL(N);
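// Mask off the upper bits of every lane so each value already fits in the
// destination lane type; the saturating NARROW_U then acts as a plain
// truncate.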
APInt Mask = APInt::getLowBitsSet(InVT.getScalarSizeInBits(),
OutVT.getScalarSizeInBits());
In = DAG.getNode(ISD::AND, DL, InVT, In, DAG.getConstant(Mask, DL, InVT));
return truncateVectorWithNARROW(OutVT, In, DL, DAG);
}

SDValue
WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
DAGCombinerInfo &DCI) const {
@@ -2625,5 +2735,7 @@ WebAssemblyTargetLowering::PerformDAGCombine(SDNode *N,
case ISD::FP_ROUND:
case ISD::CONCAT_VECTORS:
return performVectorTruncZeroCombine(N, DCI);
case ISD::TRUNCATE:
return performTruncateCombine(N, DCI);
}
}

llvm/lib/Target/WebAssembly/WebAssemblyInstrSIMD.td

@@ -1278,6 +1278,14 @@ multiclass SIMDNarrow<Vec vec, bits<32> baseInst> {
defm "" : SIMDNarrow<I16x8, 101>;
defm "" : SIMDNarrow<I32x4, 133>;

// WebAssemblyISD::NARROW_U
def wasm_narrow_t : SDTypeProfile<1, 2, []>;
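// The profile admits one result and two operands with no type constraints;
// the patterns below pin the concrete v128 lane types.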
def wasm_narrow_u : SDNode<"WebAssemblyISD::NARROW_U", wasm_narrow_t>;
def : Pat<(v16i8 (wasm_narrow_u (v8i16 V128:$left), (v8i16 V128:$right))),
(NARROW_U_I8x16 $left, $right)>;
def : Pat<(v8i16 (wasm_narrow_u (v4i32 V128:$left), (v4i32 V128:$right))),
(NARROW_U_I16x8 $left, $right)>;

// Bitcasts are nops
// Matching bitcast t1 to t1 causes strange errors, so avoid repeating types
foreach t1 = AllVecs in

llvm/test/CodeGen/WebAssembly/fpclamptosat_vec.ll

@@ -532,7 +532,7 @@ entry:
define <8 x i16> @stest_f16i16(<8 x half> %x) {
; CHECK-LABEL: stest_f16i16:
; CHECK: .functype stest_f16i16 (f32, f32, f32, f32, f32, f32, f32, f32) -> (v128)
; CHECK-NEXT: .local v128, v128
; CHECK-NEXT: .local v128, v128, v128
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: local.get 5
; CHECK-NEXT: call __truncsfhf2
@@ -578,6 +578,9 @@ define <8 x i16> @stest_f16i16(<8 x half> %x) {
; CHECK-NEXT: v128.const -32768, -32768, -32768, -32768
; CHECK-NEXT: local.tee 9
; CHECK-NEXT: i32x4.max_s
; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535
; CHECK-NEXT: local.tee 10
; CHECK-NEXT: v128.and
; CHECK-NEXT: local.get 4
; CHECK-NEXT: i32.trunc_sat_f32_s
; CHECK-NEXT: i32x4.splat
@@ -594,7 +597,9 @@ define <8 x i16> @stest_f16i16(<8 x half> %x) {
; CHECK-NEXT: i32x4.min_s
; CHECK-NEXT: local.get 9
; CHECK-NEXT: i32x4.max_s
; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
; CHECK-NEXT: local.get 10
; CHECK-NEXT: v128.and
; CHECK-NEXT: i16x8.narrow_i32x4_u
; CHECK-NEXT: # fallthrough-return
entry:
%conv = fptosi <8 x half> %x to <8 x i32>
@@ -666,7 +671,7 @@ define <8 x i16> @utesth_f16i16(<8 x half> %x) {
; CHECK-NEXT: i32x4.replace_lane 3
; CHECK-NEXT: local.get 8
; CHECK-NEXT: i32x4.min_u
; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
; CHECK-NEXT: i16x8.narrow_i32x4_u
; CHECK-NEXT: # fallthrough-return
entry:
%conv = fptoui <8 x half> %x to <8 x i32>
@@ -741,7 +746,7 @@ define <8 x i16> @ustest_f16i16(<8 x half> %x) {
; CHECK-NEXT: i32x4.min_s
; CHECK-NEXT: local.get 9
; CHECK-NEXT: i32x4.max_s
; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
; CHECK-NEXT: i16x8.narrow_i32x4_u
; CHECK-NEXT: # fallthrough-return
entry:
%conv = fptosi <8 x half> %x to <8 x i32>
@@ -2106,7 +2111,7 @@ entry:
define <8 x i16> @stest_f16i16_mm(<8 x half> %x) {
; CHECK-LABEL: stest_f16i16_mm:
; CHECK: .functype stest_f16i16_mm (f32, f32, f32, f32, f32, f32, f32, f32) -> (v128)
; CHECK-NEXT: .local v128, v128
; CHECK-NEXT: .local v128, v128, v128
; CHECK-NEXT: # %bb.0: # %entry
; CHECK-NEXT: local.get 5
; CHECK-NEXT: call __truncsfhf2
@@ -2152,6 +2157,9 @@ define <8 x i16> @stest_f16i16_mm(<8 x half> %x) {
; CHECK-NEXT: v128.const -32768, -32768, -32768, -32768
; CHECK-NEXT: local.tee 9
; CHECK-NEXT: i32x4.max_s
; CHECK-NEXT: v128.const 65535, 65535, 65535, 65535
; CHECK-NEXT: local.tee 10
; CHECK-NEXT: v128.and
; CHECK-NEXT: local.get 4
; CHECK-NEXT: i32.trunc_sat_f32_s
; CHECK-NEXT: i32x4.splat
@@ -2168,7 +2176,9 @@ define <8 x i16> @stest_f16i16_mm(<8 x half> %x) {
; CHECK-NEXT: i32x4.min_s
; CHECK-NEXT: local.get 9
; CHECK-NEXT: i32x4.max_s
; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
; CHECK-NEXT: local.get 10
; CHECK-NEXT: v128.and
; CHECK-NEXT: i16x8.narrow_i32x4_u
; CHECK-NEXT: # fallthrough-return
entry:
%conv = fptosi <8 x half> %x to <8 x i32>
@@ -2238,7 +2248,7 @@ define <8 x i16> @utesth_f16i16_mm(<8 x half> %x) {
; CHECK-NEXT: i32x4.replace_lane 3
; CHECK-NEXT: local.get 8
; CHECK-NEXT: i32x4.min_u
; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
; CHECK-NEXT: i16x8.narrow_i32x4_u
; CHECK-NEXT: # fallthrough-return
entry:
%conv = fptoui <8 x half> %x to <8 x i32>
@@ -2312,7 +2322,7 @@ define <8 x i16> @ustest_f16i16_mm(<8 x half> %x) {
; CHECK-NEXT: i32x4.min_s
; CHECK-NEXT: local.get 9
; CHECK-NEXT: i32x4.max_s
; CHECK-NEXT: i8x16.shuffle 0, 1, 4, 5, 8, 9, 12, 13, 16, 17, 20, 21, 24, 25, 28, 29
; CHECK-NEXT: i16x8.narrow_i32x4_u
; CHECK-NEXT: # fallthrough-return
entry:
%conv = fptosi <8 x half> %x to <8 x i32>