[NVPTX] Clean up argument lowering code and properly handle alignment for structs and vectors

llvm-svn: 211938
This commit is contained in:
Justin Holewinski 2014-06-27 18:35:44 +00:00
parent d7d8fe0e9c
commit 6e40f63e41
2 changed files with 89 additions and 90 deletions

View File

@ -67,6 +67,17 @@ static bool IsPTXVectorType(MVT VT) {
} }
} }
static uint64_t GCD( int a, int b)
{
if (a < b) std::swap(a,b);
while (b > 0) {
uint64_t c = b;
b = a % b;
a = c;
}
return a;
}
/// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive /// ComputePTXValueVTs - For the given Type \p Ty, returns the set of primitive
/// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors /// EVTs that compose it. Unlike ComputeValueVTs, this will break apart vectors
/// into their primitive components. /// into their primitive components.
@ -518,26 +529,12 @@ NVPTXTargetLowering::getPrototype(Type *retTy, const ArgListTy &Args,
} else if (isa<PointerType>(retTy)) { } else if (isa<PointerType>(retTy)) {
O << ".param .b" << getPointerTy().getSizeInBits() << " _"; O << ".param .b" << getPointerTy().getSizeInBits() << " _";
} else { } else {
if ((retTy->getTypeID() == Type::StructTyID) || isa<VectorType>(retTy)) { if((retTy->getTypeID() == Type::StructTyID) ||
SmallVector<EVT, 16> vtparts; isa<VectorType>(retTy)) {
ComputeValueVTs(*this, retTy, vtparts); O << ".param .align "
unsigned totalsz = 0; << retAlignment
for (unsigned i = 0, e = vtparts.size(); i != e; ++i) { << " .b8 _["
unsigned elems = 1; << getDataLayout()->getTypeAllocSize(retTy) << "]";
EVT elemtype = vtparts[i];
if (vtparts[i].isVector()) {
elems = vtparts[i].getVectorNumElements();
elemtype = vtparts[i].getVectorElementType();
}
// TODO: no need to loop
for (unsigned j = 0, je = elems; j != je; ++j) {
unsigned sz = elemtype.getSizeInBits();
if (elemtype.isInteger() && (sz < 8))
sz = 8;
totalsz += sz / 8;
}
}
O << ".param .align " << retAlignment << " .b8 _[" << totalsz << "]";
} else { } else {
assert(false && "Unknown return type"); assert(false && "Unknown return type");
} }
@ -706,7 +703,8 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
if (Ty->isAggregateType()) { if (Ty->isAggregateType()) {
// aggregate // aggregate
SmallVector<EVT, 16> vtparts; SmallVector<EVT, 16> vtparts;
ComputeValueVTs(*this, Ty, vtparts); SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, Ty, vtparts, &Offsets, 0);
unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1); unsigned align = getArgumentAlignment(Callee, CS, Ty, paramCount + 1);
// declare .param .align <align> .b8 .param<n>[<size>]; // declare .param .align <align> .b8 .param<n>[<size>];
@ -718,16 +716,9 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
DeclareParamOps); DeclareParamOps);
InFlag = Chain.getValue(1); InFlag = Chain.getValue(1);
unsigned curOffset = 0;
for (unsigned j = 0, je = vtparts.size(); j != je; ++j) { for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
unsigned elems = 1;
EVT elemtype = vtparts[j]; EVT elemtype = vtparts[j];
if (vtparts[j].isVector()) { unsigned ArgAlign = GCD(align, Offsets[j]);
elems = vtparts[j].getVectorNumElements();
elemtype = vtparts[j].getVectorElementType();
}
for (unsigned k = 0, ke = elems; k != ke; ++k) {
unsigned sz = elemtype.getSizeInBits();
if (elemtype.isInteger() && (sz < 8)) if (elemtype.isInteger() && (sz < 8))
sz = 8; sz = 8;
SDValue StVal = OutVals[OIdx]; SDValue StVal = OutVals[OIdx];
@ -737,16 +728,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); SDVTList CopyParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
SDValue CopyParamOps[] = { Chain, SDValue CopyParamOps[] = { Chain,
DAG.getConstant(paramCount, MVT::i32), DAG.getConstant(paramCount, MVT::i32),
DAG.getConstant(curOffset, MVT::i32), DAG.getConstant(Offsets[j], MVT::i32),
StVal, InFlag }; StVal, InFlag };
Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl, Chain = DAG.getMemIntrinsicNode(NVPTXISD::StoreParam, dl,
CopyParamVTs, CopyParamOps, CopyParamVTs, CopyParamOps,
elemtype, MachinePointerInfo()); elemtype, MachinePointerInfo(),
ArgAlign);
InFlag = Chain.getValue(1); InFlag = Chain.getValue(1);
curOffset += sz / 8;
++OIdx; ++OIdx;
} }
}
if (vtparts.size() > 0) if (vtparts.size() > 0)
--OIdx; --OIdx;
++paramCount; ++paramCount;
@ -930,13 +920,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
} }
// struct or vector // struct or vector
SmallVector<EVT, 16> vtparts; SmallVector<EVT, 16> vtparts;
SmallVector<uint64_t, 16> Offsets;
const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty); const PointerType *PTy = dyn_cast<PointerType>(Args[i].Ty);
assert(PTy && "Type of a byval parameter should be pointer"); assert(PTy && "Type of a byval parameter should be pointer");
ComputeValueVTs(*this, PTy->getElementType(), vtparts); ComputePTXValueVTs(*this, PTy->getElementType(), vtparts, &Offsets, 0);
// declare .param .align <align> .b8 .param<n>[<size>]; // declare .param .align <align> .b8 .param<n>[<size>];
unsigned sz = Outs[OIdx].Flags.getByValSize(); unsigned sz = Outs[OIdx].Flags.getByValSize();
SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue); SDVTList DeclareParamVTs = DAG.getVTList(MVT::Other, MVT::Glue);
unsigned ArgAlign = Outs[OIdx].Flags.getByValAlign();
// The ByValAlign in the Outs[OIdx].Flags is alway set at this point, // The ByValAlign in the Outs[OIdx].Flags is alway set at this point,
// so we don't need to worry about natural alignment or not. // so we don't need to worry about natural alignment or not.
// See TargetLowering::LowerCallTo(). // See TargetLowering::LowerCallTo().
@ -948,24 +940,16 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs, Chain = DAG.getNode(NVPTXISD::DeclareParam, dl, DeclareParamVTs,
DeclareParamOps); DeclareParamOps);
InFlag = Chain.getValue(1); InFlag = Chain.getValue(1);
unsigned curOffset = 0;
for (unsigned j = 0, je = vtparts.size(); j != je; ++j) { for (unsigned j = 0, je = vtparts.size(); j != je; ++j) {
unsigned elems = 1;
EVT elemtype = vtparts[j]; EVT elemtype = vtparts[j];
if (vtparts[j].isVector()) { int curOffset = Offsets[j];
elems = vtparts[j].getVectorNumElements(); unsigned PartAlign = GCD(ArgAlign, curOffset);
elemtype = vtparts[j].getVectorElementType();
}
for (unsigned k = 0, ke = elems; k != ke; ++k) {
unsigned sz = elemtype.getSizeInBits();
if (elemtype.isInteger() && (sz < 8))
sz = 8;
SDValue srcAddr = SDValue srcAddr =
DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx], DAG.getNode(ISD::ADD, dl, getPointerTy(), OutVals[OIdx],
DAG.getConstant(curOffset, getPointerTy())); DAG.getConstant(curOffset, getPointerTy()));
SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr, SDValue theVal = DAG.getLoad(elemtype, dl, tempChain, srcAddr,
MachinePointerInfo(), false, false, false, MachinePointerInfo(), false, false, false,
0); PartAlign);
if (elemtype.getSizeInBits() < 16) { if (elemtype.getSizeInBits() < 16) {
theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal); theVal = DAG.getNode(ISD::ANY_EXTEND, dl, MVT::i16, theVal);
} }
@ -978,8 +962,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
MachinePointerInfo()); MachinePointerInfo());
InFlag = Chain.getValue(1); InFlag = Chain.getValue(1);
curOffset += sz / 8;
}
} }
++paramCount; ++paramCount;
} }
@ -1088,7 +1070,6 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
// Generate loads from param memory/moves from registers for result // Generate loads from param memory/moves from registers for result
if (Ins.size() > 0) { if (Ins.size() > 0) {
unsigned resoffset = 0;
if (retTy && retTy->isVectorTy()) { if (retTy && retTy->isVectorTy()) {
EVT ObjectVT = getValueType(retTy); EVT ObjectVT = getValueType(retTy);
unsigned NumElts = ObjectVT.getVectorNumElements(); unsigned NumElts = ObjectVT.getVectorNumElements();
@ -1097,14 +1078,15 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
ObjectVT) == NumElts && ObjectVT) == NumElts &&
"Vector was not scalarized"); "Vector was not scalarized");
unsigned sz = EltVT.getSizeInBits(); unsigned sz = EltVT.getSizeInBits();
bool needTruncate = sz < 16 ? true : false; bool needTruncate = sz < 8 ? true : false;
if (NumElts == 1) { if (NumElts == 1) {
// Just a simple load // Just a simple load
SmallVector<EVT, 4> LoadRetVTs; SmallVector<EVT, 4> LoadRetVTs;
if (needTruncate) { if (EltVT == MVT::i1 || EltVT == MVT::i8) {
// If loading i1 result, generate // If loading i1/i8 result, generate
// load i16 // load.b8 i16
// if i1
// trunc i16 to i1 // trunc i16 to i1
LoadRetVTs.push_back(MVT::i16); LoadRetVTs.push_back(MVT::i16);
} else } else
@ -1128,9 +1110,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
} else if (NumElts == 2) { } else if (NumElts == 2) {
// LoadV2 // LoadV2
SmallVector<EVT, 4> LoadRetVTs; SmallVector<EVT, 4> LoadRetVTs;
if (needTruncate) { if (EltVT == MVT::i1 || EltVT == MVT::i8) {
// If loading i1 result, generate // If loading i1/i8 result, generate
// load i16 // load.b8 i16
// if i1
// trunc i16 to i1 // trunc i16 to i1
LoadRetVTs.push_back(MVT::i16); LoadRetVTs.push_back(MVT::i16);
LoadRetVTs.push_back(MVT::i16); LoadRetVTs.push_back(MVT::i16);
@ -1173,9 +1156,10 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize); EVT VecVT = EVT::getVectorVT(F->getContext(), EltVT, VecSize);
for (unsigned i = 0; i < NumElts; i += VecSize) { for (unsigned i = 0; i < NumElts; i += VecSize) {
SmallVector<EVT, 8> LoadRetVTs; SmallVector<EVT, 8> LoadRetVTs;
if (needTruncate) { if (EltVT == MVT::i1 || EltVT == MVT::i8) {
// If loading i1 result, generate // If loading i1/i8 result, generate
// load i16 // load.b8 i16
// if i1
// trunc i16 to i1 // trunc i16 to i1
for (unsigned j = 0; j < VecSize; ++j) for (unsigned j = 0; j < VecSize; ++j)
LoadRetVTs.push_back(MVT::i16); LoadRetVTs.push_back(MVT::i16);
@ -1214,10 +1198,13 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
} }
} else { } else {
SmallVector<EVT, 16> VTs; SmallVector<EVT, 16> VTs;
ComputePTXValueVTs(*this, retTy, VTs); SmallVector<uint64_t, 16> Offsets;
ComputePTXValueVTs(*this, retTy, VTs, &Offsets, 0);
assert(VTs.size() == Ins.size() && "Bad value decomposition"); assert(VTs.size() == Ins.size() && "Bad value decomposition");
unsigned RetAlign = getArgumentAlignment(Callee, CS, retTy, 0);
for (unsigned i = 0, e = Ins.size(); i != e; ++i) { for (unsigned i = 0, e = Ins.size(); i != e; ++i) {
unsigned sz = VTs[i].getSizeInBits(); unsigned sz = VTs[i].getSizeInBits();
unsigned AlignI = GCD(RetAlign, Offsets[i]);
bool needTruncate = sz < 8 ? true : false; bool needTruncate = sz < 8 ? true : false;
if (VTs[i].isInteger() && (sz < 8)) if (VTs[i].isInteger() && (sz < 8))
sz = 8; sz = 8;
@ -1243,19 +1230,18 @@ SDValue NVPTXTargetLowering::LowerCall(TargetLowering::CallLoweringInfo &CLI,
SmallVector<SDValue, 4> LoadRetOps; SmallVector<SDValue, 4> LoadRetOps;
LoadRetOps.push_back(Chain); LoadRetOps.push_back(Chain);
LoadRetOps.push_back(DAG.getConstant(1, MVT::i32)); LoadRetOps.push_back(DAG.getConstant(1, MVT::i32));
LoadRetOps.push_back(DAG.getConstant(resoffset, MVT::i32)); LoadRetOps.push_back(DAG.getConstant(Offsets[i], MVT::i32));
LoadRetOps.push_back(InFlag); LoadRetOps.push_back(InFlag);
SDValue retval = DAG.getMemIntrinsicNode( SDValue retval = DAG.getMemIntrinsicNode(
NVPTXISD::LoadParam, dl, NVPTXISD::LoadParam, dl,
DAG.getVTList(LoadRetVTs), LoadRetOps, DAG.getVTList(LoadRetVTs), LoadRetOps,
TheLoadType, MachinePointerInfo()); TheLoadType, MachinePointerInfo(), AlignI);
Chain = retval.getValue(1); Chain = retval.getValue(1);
InFlag = retval.getValue(2); InFlag = retval.getValue(2);
SDValue Ret0 = retval.getValue(0); SDValue Ret0 = retval.getValue(0);
if (needTruncate) if (needTruncate)
Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0); Ret0 = DAG.getNode(ISD::TRUNCATE, dl, Ins[i].VT, Ret0);
InVals.push_back(Ret0); InVals.push_back(Ret0);
resoffset += sz / 8;
} }
} }
} }

View File

@ -0,0 +1,13 @@
; RUN: llc < %s -march=nvptx -mcpu=sm_20 | FileCheck %s
; CHECK: .visible .func (.param .align 16 .b8 func_retval0[16]) foo0(
; CHECK: .param .align 4 .b8 foo0_param_0[8]
define <4 x float> @foo0({float, float} %arg0) {
ret <4 x float> <float 1.0, float 1.0, float 1.0, float 1.0>
}
; CHECK: .visible .func (.param .align 8 .b8 func_retval0[8]) foo1(
; CHECK: .param .align 8 .b8 foo1_param_0[16]
define <2 x float> @foo1({float, float, i64} %arg0) {
ret <2 x float> <float 1.0, float 1.0>
}