[X86] Refactor GetSSETypeAtOffset to fix pr51813

D105263 adds support for _Float16 type. It introduced a bug (pr51813) that generates a <4 x half> type instead the default double when passing blank structure by SSE registers.

Although I doubt it may expose a bug somewhere other than D105263, it's good to avoid return half type when no half type in arguments.

Reviewed By: LuoYuanke

Differential Revision: https://reviews.llvm.org/D109607
This commit is contained in:
Wang, Pengfei 2021-09-17 10:20:09 +08:00
parent aaf00f3f19
commit e9e1d4751b
2 changed files with 95 additions and 81 deletions

View File

@ -3407,52 +3407,18 @@ static bool BitsContainNoUserData(QualType Ty, unsigned StartBit,
return false;
}
/// ContainsFloatAtOffset - Return true if the specified LLVM IR type has a
/// float member at the specified offset. For example, {int,{float}} has a
/// float at offset 4. It is conservatively correct for this routine to return
/// false.
static bool ContainsFloatAtOffset(llvm::Type *IRType, unsigned IROffset,
const llvm::DataLayout &TD) {
// Base case if we find a float.
if (IROffset == 0 && IRType->isFloatTy())
return true;
/// getFPTypeAtOffset - Return a floating point type at the specified offset.
static llvm::Type *getFPTypeAtOffset(llvm::Type *IRType, unsigned IROffset,
const llvm::DataLayout &TD) {
if (IROffset == 0 && IRType->isFloatingPointTy())
return IRType;
// If this is a struct, recurse into the field at the specified offset.
if (llvm::StructType *STy = dyn_cast<llvm::StructType>(IRType)) {
const llvm::StructLayout *SL = TD.getStructLayout(STy);
unsigned Elt = SL->getElementContainingOffset(IROffset);
IROffset -= SL->getElementOffset(Elt);
return ContainsFloatAtOffset(STy->getElementType(Elt), IROffset, TD);
}
// If this is an array, recurse into the field at the specified offset.
if (llvm::ArrayType *ATy = dyn_cast<llvm::ArrayType>(IRType)) {
llvm::Type *EltTy = ATy->getElementType();
unsigned EltSize = TD.getTypeAllocSize(EltTy);
IROffset -= IROffset/EltSize*EltSize;
return ContainsFloatAtOffset(EltTy, IROffset, TD);
}
return false;
}
/// ContainsHalfAtOffset - Return true if the specified LLVM IR type has a
/// half member at the specified offset. For example, {int,{half}} has a
/// half at offset 4. It is conservatively correct for this routine to return
/// false.
/// FIXME: Merge with ContainsFloatAtOffset
static bool ContainsHalfAtOffset(llvm::Type *IRType, unsigned IROffset,
const llvm::DataLayout &TD) {
// Base case if we find a float.
if (IROffset == 0 && IRType->isHalfTy())
return true;
// If this is a struct, recurse into the field at the specified offset.
if (llvm::StructType *STy = dyn_cast<llvm::StructType>(IRType)) {
const llvm::StructLayout *SL = TD.getStructLayout(STy);
unsigned Elt = SL->getElementContainingOffset(IROffset);
IROffset -= SL->getElementOffset(Elt);
return ContainsHalfAtOffset(STy->getElementType(Elt), IROffset, TD);
return getFPTypeAtOffset(STy->getElementType(Elt), IROffset, TD);
}
// If this is an array, recurse into the field at the specified offset.
@ -3460,10 +3426,10 @@ static bool ContainsHalfAtOffset(llvm::Type *IRType, unsigned IROffset,
llvm::Type *EltTy = ATy->getElementType();
unsigned EltSize = TD.getTypeAllocSize(EltTy);
IROffset -= IROffset / EltSize * EltSize;
return ContainsHalfAtOffset(EltTy, IROffset, TD);
return getFPTypeAtOffset(EltTy, IROffset, TD);
}
return false;
return nullptr;
}
/// GetSSETypeAtOffset - Return a type that will be passed by the backend in the
@ -3471,39 +3437,37 @@ static bool ContainsHalfAtOffset(llvm::Type *IRType, unsigned IROffset,
llvm::Type *X86_64ABIInfo::
GetSSETypeAtOffset(llvm::Type *IRType, unsigned IROffset,
QualType SourceTy, unsigned SourceOffset) const {
// If the high 32 bits are not used, we have three choices. Single half,
// single float or two halfs.
if (BitsContainNoUserData(SourceTy, SourceOffset * 8 + 32,
SourceOffset * 8 + 64, getContext())) {
if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()))
return llvm::Type::getFloatTy(getVMContext());
if (ContainsHalfAtOffset(IRType, IROffset + 2, getDataLayout()))
return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()),
2);
const llvm::DataLayout &TD = getDataLayout();
llvm::Type *T0 = getFPTypeAtOffset(IRType, IROffset, TD);
if (!T0 || T0->isDoubleTy())
return llvm::Type::getDoubleTy(getVMContext());
return llvm::Type::getHalfTy(getVMContext());
// Get the adjacent FP type.
llvm::Type *T1 =
getFPTypeAtOffset(IRType, IROffset + TD.getTypeAllocSize(T0), TD);
if (T1 == nullptr) {
// Check if IRType is a half + float. float type will be in IROffset+4 due
// to its alignment.
if (T0->isHalfTy())
T1 = getFPTypeAtOffset(IRType, IROffset + 4, TD);
// If we can't get a second FP type, return a simple half or float.
// avx512fp16-abi.c:pr51813_2 shows it works to return float for
// {float, i8} too.
if (T1 == nullptr)
return T0;
}
// We want to pass as <2 x float> if the LLVM IR type contains a float at
// offset+0 and offset+4. Walk the LLVM IR type to find out if this is the
// case.
if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()) &&
ContainsFloatAtOffset(IRType, IROffset + 4, getDataLayout()))
return llvm::FixedVectorType::get(llvm::Type::getFloatTy(getVMContext()),
2);
if (T0->isFloatTy() && T1->isFloatTy())
return llvm::FixedVectorType::get(T0, 2);
// We want to pass as <4 x half> if the LLVM IR type contains a half at
// offset+0, +2, +4. Walk the LLVM IR type to find out if this is the case.
if (ContainsHalfAtOffset(IRType, IROffset, getDataLayout()) &&
ContainsHalfAtOffset(IRType, IROffset + 2, getDataLayout()) &&
ContainsHalfAtOffset(IRType, IROffset + 4, getDataLayout()))
return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()), 4);
if (T0->isHalfTy() && T1->isHalfTy()) {
llvm::Type *T2 = getFPTypeAtOffset(IRType, IROffset + 4, TD);
if (T2 == nullptr)
return llvm::FixedVectorType::get(T0, 2);
return llvm::FixedVectorType::get(T0, 4);
}
// We want to pass as <4 x half> if the LLVM IR type contains a mix of float
// and half.
// FIXME: Do we have a better representation for the mixed type?
if (ContainsFloatAtOffset(IRType, IROffset, getDataLayout()) ||
ContainsFloatAtOffset(IRType, IROffset + 4, getDataLayout()))
if (T0->isHalfTy() || T1->isHalfTy())
return llvm::FixedVectorType::get(llvm::Type::getHalfTy(getVMContext()), 4);
return llvm::Type::getDoubleTy(getVMContext());

View File

@ -1,11 +1,12 @@
// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm -target-feature +avx512fp16 < %s | FileCheck %s --check-prefixes=CHECK
// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm -target-feature +avx512fp16 < %s | FileCheck %s --check-prefixes=CHECK,CHECK-C
// RUN: %clang_cc1 -triple x86_64-linux -emit-llvm -target-feature +avx512fp16 -x c++ -std=c++11 < %s | FileCheck %s --check-prefixes=CHECK,CHECK-CPP
struct half1 {
_Float16 a;
};
struct half1 h1(_Float16 a) {
// CHECK: define{{.*}}half @h1
// CHECK: define{{.*}}half @
struct half1 x;
x.a = a;
return x;
@ -17,7 +18,7 @@ struct half2 {
};
struct half2 h2(_Float16 a, _Float16 b) {
// CHECK: define{{.*}}<2 x half> @h2
// CHECK: define{{.*}}<2 x half> @
struct half2 x;
x.a = a;
x.b = b;
@ -31,7 +32,7 @@ struct half3 {
};
struct half3 h3(_Float16 a, _Float16 b, _Float16 c) {
// CHECK: define{{.*}}<4 x half> @h3
// CHECK: define{{.*}}<4 x half> @
struct half3 x;
x.a = a;
x.b = b;
@ -47,7 +48,7 @@ struct half4 {
};
struct half4 h4(_Float16 a, _Float16 b, _Float16 c, _Float16 d) {
// CHECK: define{{.*}}<4 x half> @h4
// CHECK: define{{.*}}<4 x half> @
struct half4 x;
x.a = a;
x.b = b;
@ -62,7 +63,7 @@ struct floathalf {
};
struct floathalf fh(float a, _Float16 b) {
// CHECK: define{{.*}}<4 x half> @fh
// CHECK: define{{.*}}<4 x half> @
struct floathalf x;
x.a = a;
x.b = b;
@ -76,7 +77,7 @@ struct floathalf2 {
};
struct floathalf2 fh2(float a, _Float16 b, _Float16 c) {
// CHECK: define{{.*}}<4 x half> @fh2
// CHECK: define{{.*}}<4 x half> @
struct floathalf2 x;
x.a = a;
x.b = b;
@ -90,7 +91,7 @@ struct halffloat {
};
struct halffloat hf(_Float16 a, float b) {
// CHECK: define{{.*}}<4 x half> @hf
// CHECK: define{{.*}}<4 x half> @
struct halffloat x;
x.a = a;
x.b = b;
@ -104,7 +105,7 @@ struct half2float {
};
struct half2float h2f(_Float16 a, _Float16 b, float c) {
// CHECK: define{{.*}}<4 x half> @h2f
// CHECK: define{{.*}}<4 x half> @
struct half2float x;
x.a = a;
x.b = b;
@ -120,7 +121,7 @@ struct floathalf3 {
};
struct floathalf3 fh3(float a, _Float16 b, _Float16 c, _Float16 d) {
// CHECK: define{{.*}}{ <4 x half>, half } @fh3
// CHECK: define{{.*}}{ <4 x half>, half } @
struct floathalf3 x;
x.a = a;
x.b = b;
@ -138,7 +139,7 @@ struct half5 {
};
struct half5 h5(_Float16 a, _Float16 b, _Float16 c, _Float16 d, _Float16 e) {
// CHECK: define{{.*}}{ <4 x half>, half } @h5
// CHECK: define{{.*}}{ <4 x half>, half } @
struct half5 x;
x.a = a;
x.b = b;
@ -147,3 +148,52 @@ struct half5 h5(_Float16 a, _Float16 b, _Float16 c, _Float16 d, _Float16 e) {
x.e = e;
return x;
}
struct float2 {
struct {} s;
float a;
float b;
};
float pr51813(struct float2 s) {
// CHECK-C: define{{.*}} @pr51813(<2 x float>
// CHECK-CPP: define{{.*}} @_Z7pr518136float2(double {{.*}}, float
return s.a;
}
struct float3 {
float a;
struct {} s;
float b;
};
float pr51813_2(struct float3 s) {
// CHECK-C: define{{.*}} @pr51813_2(<2 x float>
// CHECK-CPP: define{{.*}} @_Z9pr51813_26float3(double {{.*}}, float
return s.a;
}
struct shalf2 {
struct {} s;
_Float16 a;
_Float16 b;
};
_Float16 sf2(struct shalf2 s) {
// CHECK-C: define{{.*}} @sf2(<2 x half>
// CHECK-CPP: define{{.*}} @_Z3sf26shalf2(double {{.*}}
return s.a;
};
struct halfs2 {
_Float16 a;
struct {} s1;
_Float16 b;
struct {} s2;
};
_Float16 fs2(struct shalf2 s) {
// CHECK-C: define{{.*}} @fs2(<2 x half>
// CHECK-CPP: define{{.*}} @_Z3fs26shalf2(double {{.*}}
return s.a;
};