GlobalISel: Handle more cases in getGCDType

Try harder to find a canonical unmerge type when trying to cover the
desired target type. Handle finding a compatible unmerge type for two
vectors with different element types. This will return the largest
multiple of the source vector element that will evenly divide the
target vector type.

Also make the handling mixing scalars and vectors, and prefer the
source element type as the unmerge target type.
This commit is contained in:
Matt Arsenault 2020-06-06 21:24:02 -04:00
parent 107c954c13
commit 12d5bec8c7
4 changed files with 134 additions and 19 deletions

View File

@ -194,12 +194,23 @@ Align inferAlignFromPtrInfo(MachineFunction &MF, const MachinePointerInfo &MPO);
/// the number of vector elements or scalar bitwidth. The intent is a
/// G_MERGE_VALUES can be constructed from \p Ty0 elements, and unmerged into
/// \p Ty1.
LLVM_READNONE
LLT getLCMType(LLT Ty0, LLT Ty1);
/// Return a type that is greatest common divisor of \p OrigTy and \p
/// TargetTy. This will either change the number of vector elements, or
/// bitwidth of scalars. The intent is the result type can be used as the
/// result of a G_UNMERGE_VALUES from \p OrigTy.
/// Return a type where the total size is the greatest common divisor of \p
/// OrigTy and \p TargetTy. This will try to either change the number of vector
/// elements, or bitwidth of scalars. The intent is the result type can be used
/// as the result of a G_UNMERGE_VALUES from \p OrigTy, and then some
/// combination of G_MERGE_VALUES, G_BUILD_VECTOR and G_CONCAT_VECTORS (possibly
/// with intermediate casts) can re-form \p TargetTy.
///
/// If these are vectors with different element types, this will try to produce
/// a vector with a compatible total size, but the element type of \p OrigTy. If
/// this can't be satisfied, this will produce a scalar smaller than the
/// original vector elements.
///
/// In the worst case, this returns LLT::scalar(1)
LLVM_READNONE
LLT getGCDType(LLT OrigTy, LLT TargetTy);
} // End namespace llvm.

View File

@ -252,7 +252,7 @@ LLT LegalizerHelper::extractGCDType(SmallVectorImpl<Register> &Parts, LLT DstTy,
LLT NarrowTy, Register SrcReg) {
LLT SrcTy = MRI.getType(SrcReg);
LLT GCDTy = getGCDType(DstTy, getGCDType(SrcTy, NarrowTy));
LLT GCDTy = getGCDType(getGCDType(SrcTy, NarrowTy), DstTy);
if (SrcTy == GCDTy) {
// If the source already evenly divides the result type, we don't need to do
// anything.

View File

@ -542,22 +542,45 @@ LLT llvm::getLCMType(LLT Ty0, LLT Ty1) {
}
LLT llvm::getGCDType(LLT OrigTy, LLT TargetTy) {
if (OrigTy.isVector() && TargetTy.isVector()) {
assert(OrigTy.getElementType() == TargetTy.getElementType());
int GCD = greatestCommonDivisor(OrigTy.getNumElements(),
TargetTy.getNumElements());
return LLT::scalarOrVector(GCD, OrigTy.getElementType());
const unsigned OrigSize = OrigTy.getSizeInBits();
const unsigned TargetSize = TargetTy.getSizeInBits();
if (OrigSize == TargetSize)
return OrigTy;
if (OrigTy.isVector()) {
LLT OrigElt = OrigTy.getElementType();
if (TargetTy.isVector()) {
LLT TargetElt = TargetTy.getElementType();
if (OrigElt.getSizeInBits() == TargetElt.getSizeInBits()) {
int GCD = greatestCommonDivisor(OrigTy.getNumElements(),
TargetTy.getNumElements());
return LLT::scalarOrVector(GCD, OrigElt);
}
} else {
// If the source is a vector of pointers, return a pointer element.
if (OrigElt.getSizeInBits() == TargetSize)
return OrigElt;
}
unsigned GCD = greatestCommonDivisor(OrigSize, TargetSize);
if (GCD == OrigElt.getSizeInBits())
return OrigElt;
// If we can't produce the original element type, we have to use a smaller
// scalar.
if (GCD < OrigElt.getSizeInBits())
return LLT::scalar(GCD);
return LLT::vector(GCD / OrigElt.getSizeInBits(), OrigElt);
}
if (OrigTy.isVector() && !TargetTy.isVector()) {
assert(OrigTy.getElementType() == TargetTy);
return TargetTy;
if (TargetTy.isVector()) {
// Try to preserve the original element type.
LLT TargetElt = TargetTy.getElementType();
if (TargetElt.getSizeInBits() == OrigSize)
return OrigTy;
}
assert(!OrigTy.isVector() && !TargetTy.isVector() &&
"GCD type of vector and scalar not implemented");
int GCD = greatestCommonDivisor(OrigTy.getSizeInBits(),
TargetTy.getSizeInBits());
unsigned GCD = greatestCommonDivisor(OrigSize, TargetSize);
return LLT::scalar(GCD);
}

View File

@ -13,13 +13,18 @@ using namespace llvm;
namespace {
static const LLT S1 = LLT::scalar(1);
static const LLT S8 = LLT::scalar(8);
static const LLT S16 = LLT::scalar(16);
static const LLT S32 = LLT::scalar(32);
static const LLT S64 = LLT::scalar(64);
static const LLT P0 = LLT::pointer(0, 64);
static const LLT P1 = LLT::pointer(1, 32);
static const LLT V2S8 = LLT::vector(2, 8);
static const LLT V4S8 = LLT::vector(4, 8);
static const LLT V2S16 = LLT::vector(2, 16);
static const LLT V3S16 = LLT::vector(3, 16);
static const LLT V4S16 = LLT::vector(4, 16);
static const LLT V2S32 = LLT::vector(2, 32);
@ -27,11 +32,17 @@ static const LLT V3S32 = LLT::vector(3, 32);
static const LLT V4S32 = LLT::vector(4, 32);
static const LLT V6S32 = LLT::vector(6, 32);
static const LLT V2S64 = LLT::vector(2, 64);
static const LLT V4S64 = LLT::vector(4, 64);
static const LLT V2P0 = LLT::vector(2, P0);
static const LLT V3P0 = LLT::vector(3, P0);
static const LLT V4P0 = LLT::vector(4, P0);
static const LLT V6P0 = LLT::vector(6, P0);
static const LLT V2P1 = LLT::vector(2, P1);
static const LLT V4P1 = LLT::vector(4, P1);
TEST(GISelUtilsTest, getGCDType) {
EXPECT_EQ(S1, getGCDType(S1, S1));
EXPECT_EQ(S32, getGCDType(S32, S32));
@ -56,7 +67,7 @@ TEST(GISelUtilsTest, getGCDType) {
EXPECT_EQ(S32, getGCDType(P0, S32));
EXPECT_EQ(S32, getGCDType(S32, P0));
EXPECT_EQ(S64, getGCDType(P0, S64));
EXPECT_EQ(P0, getGCDType(P0, S64));
EXPECT_EQ(S64, getGCDType(S64, P0));
EXPECT_EQ(S32, getGCDType(P0, P1));
@ -64,6 +75,76 @@ TEST(GISelUtilsTest, getGCDType) {
EXPECT_EQ(P0, getGCDType(V3P0, V2P0));
EXPECT_EQ(P0, getGCDType(V2P0, V3P0));
EXPECT_EQ(P0, getGCDType(P0, V2P0));
EXPECT_EQ(P0, getGCDType(V2P0, P0));
EXPECT_EQ(V2P0, getGCDType(V2P0, V2P0));
EXPECT_EQ(P0, getGCDType(V3P0, V2P0));
EXPECT_EQ(P0, getGCDType(V2P0, V3P0));
EXPECT_EQ(V2P0, getGCDType(V4P0, V2P0));
EXPECT_EQ(V2P0, getGCDType(V2P0, V4P1));
EXPECT_EQ(V4P1, getGCDType(V4P1, V2P0));
EXPECT_EQ(V2P0, getGCDType(V4P0, V4P1));
EXPECT_EQ(V4P1, getGCDType(V4P1, V4P0));
// Elements have same size, but have different pointeriness, so prefer the
// original element type.
EXPECT_EQ(V2P0, getGCDType(V2P0, V4S64));
EXPECT_EQ(V2S64, getGCDType(V4S64, V2P0));
EXPECT_EQ(V2S16, getGCDType(V2S16, V4P1));
EXPECT_EQ(P1, getGCDType(V4P1, V2S16));
EXPECT_EQ(V2P1, getGCDType(V4P1, V4S16));
EXPECT_EQ(V4S16, getGCDType(V4S16, V2P1));
EXPECT_EQ(P0, getGCDType(P0, V2S64));
EXPECT_EQ(S64, getGCDType(V2S64, P0));
EXPECT_EQ(S16, getGCDType(V2S16, V3S16));
EXPECT_EQ(S16, getGCDType(V3S16, V2S16));
EXPECT_EQ(S16, getGCDType(V3S16, S16));
EXPECT_EQ(S16, getGCDType(S16, V3S16));
EXPECT_EQ(V2S16, getGCDType(V2S16, V2S32));
EXPECT_EQ(S32, getGCDType(V2S32, V2S16));
EXPECT_EQ(V4S8, getGCDType(V4S8, V2S32));
EXPECT_EQ(S32, getGCDType(V2S32, V4S8));
// Test cases where neither element type nicely divides.
EXPECT_EQ(LLT::scalar(3), getGCDType(LLT::vector(3, 5), LLT::vector(2, 6)));
EXPECT_EQ(LLT::scalar(3), getGCDType(LLT::vector(2, 6), LLT::vector(3, 5)));
// Have to go smaller than a pointer element.
EXPECT_EQ(LLT::scalar(3), getGCDType(LLT::vector(2, LLT::pointer(3, 6)),
LLT::vector(3, 5)));
EXPECT_EQ(LLT::scalar(3), getGCDType(LLT::vector(3, 5),
LLT::vector(2, LLT::pointer(3, 6))));
EXPECT_EQ(V4S8, getGCDType(V4S8, S32));
EXPECT_EQ(S32, getGCDType(S32, V4S8));
EXPECT_EQ(V4S8, getGCDType(V4S8, P1));
EXPECT_EQ(P1, getGCDType(P1, V4S8));
EXPECT_EQ(V2S8, getGCDType(V2S8, V4S16));
EXPECT_EQ(S16, getGCDType(V4S16, V2S8));
EXPECT_EQ(S8, getGCDType(V2S8, LLT::vector(4, 2)));
EXPECT_EQ(LLT::vector(4, 2), getGCDType(LLT::vector(4, 2), S8));
EXPECT_EQ(LLT::pointer(4, 8), getGCDType(LLT::vector(2, LLT::pointer(4, 8)),
LLT::vector(4, 2)));
EXPECT_EQ(LLT::vector(4, 2), getGCDType(LLT::vector(4, 2),
LLT::vector(2, LLT::pointer(4, 8))));
EXPECT_EQ(LLT::scalar(4), getGCDType(LLT::vector(3, 4), S8));
EXPECT_EQ(LLT::scalar(4), getGCDType(S8, LLT::vector(3, 4)));
}
TEST(GISelUtilsTest, getLCMType) {