[NFC][ARM][ParallelDSP] Refactor narrow sequence

Most of the code used for finding a 'narrow' sequence is not used,
so I've removed it and simplified the calls from the smlad matcher.

llvm-svn: 362104
This commit is contained in:
Sam Parker 2019-05-30 15:26:37 +00:00
parent 202c3ffcbf
commit 913604a637
1 changed files with 19 additions and 48 deletions

View File

@ -248,45 +248,6 @@ namespace {
};
}
// MaxBitwidth: the maximum supported bitwidth of the elements in the DSP
// instructions, which is set to 16. So here we should collect all i8 and i16
// narrow operations.
// TODO: we currently only collect i16, and will support i8 later, so that's
// why we check that types are equal to MaxBitWidth, and not <= MaxBitWidth.
template<unsigned MaxBitWidth>
static bool IsNarrowSequence(Value *V, ValueList &VL) {
ConstantInt *CInt;
if (match(V, m_ConstantInt(CInt))) {
// TODO: if a constant is used, it needs to fit within the bit width.
return false;
}
auto *I = dyn_cast<Instruction>(V);
if (!I)
return false;
Value *Val, *LHS, *RHS;
if (match(V, m_Trunc(m_Value(Val)))) {
if (cast<TruncInst>(I)->getDestTy()->getIntegerBitWidth() == MaxBitWidth)
return IsNarrowSequence<MaxBitWidth>(Val, VL);
} else if (match(V, m_Add(m_Value(LHS), m_Value(RHS)))) {
// TODO: we need to implement sadd16/sadd8 for this, which enables to
// also do the rewrite for smlad8.ll, but it is unsupported for now.
return false;
} else if (match(V, m_ZExtOrSExt(m_Value(Val)))) {
if (cast<CastInst>(I)->getSrcTy()->getIntegerBitWidth() != MaxBitWidth)
return false;
if (match(Val, m_Load(m_Value()))) {
VL.push_back(Val);
VL.push_back(I);
return true;
}
}
return false;
}
template<typename MemInst>
static bool AreSequentialAccesses(MemInst *MemOp0, MemInst *MemOp1,
const DataLayout &DL, ScalarEvolution &SE) {
@ -507,6 +468,18 @@ bool ARMParallelDSP::InsertParallelMACs(Reduction &Reduction) {
return false;
}
template<typename InstType, unsigned BitWidth>
bool IsExtendingLoad(Value *V) {
auto *I = dyn_cast<InstType>(V);
if (!I)
return false;
if (I->getSrcTy()->getIntegerBitWidth() != BitWidth)
return false;
return isa<LoadInst>(I->getOperand(0));
}
static void MatchParallelMACSequences(Reduction &R,
OpChainList &Candidates) {
Instruction *Acc = R.AccIntAdd;
@ -526,15 +499,13 @@ static void MatchParallelMACSequences(Reduction &R,
return true;
break;
case Instruction::Mul: {
Value *MulOp0 = I->getOperand(0);
Value *MulOp1 = I->getOperand(1);
if (isa<SExtInst>(MulOp0) && isa<SExtInst>(MulOp1)) {
ValueList LHS;
ValueList RHS;
if (IsNarrowSequence<16>(MulOp0, LHS) &&
IsNarrowSequence<16>(MulOp1, RHS)) {
Candidates.push_back(make_unique<BinOpChain>(I, LHS, RHS));
}
Value *Op0 = I->getOperand(0);
Value *Op1 = I->getOperand(1);
if (IsExtendingLoad<SExtInst, 16>(Op0) &&
IsExtendingLoad<SExtInst, 16>(Op1)) {
ValueList LHS = { cast<SExtInst>(Op0)->getOperand(0), Op0 };
ValueList RHS = { cast<SExtInst>(Op1)->getOperand(0), Op1 };
Candidates.push_back(make_unique<BinOpChain>(I, LHS, RHS));
}
return false;
}