forked from OSchip/llvm-project
137 lines
4.7 KiB
C++
137 lines
4.7 KiB
C++
//===-- NVPTXLowerStructArgs.cpp - Copy struct args to local memory =====--===//
|
|
//
|
|
// The LLVM Compiler Infrastructure
|
|
//
|
|
// This file is distributed under the University of Illinois Open Source
|
|
// License. See LICENSE.TXT for details.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// Copy struct args to local memory. This is needed for kernel functions only.
|
|
// This is a preparation for handling cases like
|
|
//
|
|
// kernel void foo(struct A arg, ...)
|
|
// {
|
|
// struct A *p = &arg;
|
|
// ...
|
|
//      ... = p->field1 ...  (there is no generic address for .param)
|
|
//      p->field2 = ...      (there is no write access to .param)
|
|
// }
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "NVPTX.h"
|
|
#include "NVPTXUtilities.h"
|
|
#include "llvm/IR/Function.h"
|
|
#include "llvm/IR/Instructions.h"
|
|
#include "llvm/IR/IntrinsicInst.h"
|
|
#include "llvm/IR/Module.h"
|
|
#include "llvm/IR/Type.h"
|
|
#include "llvm/Pass.h"
|
|
|
|
using namespace llvm;
|
|
|
|
namespace llvm {
|
|
void initializeNVPTXLowerStructArgsPass(PassRegistry &);
|
|
}
|
|
|
|
namespace {
|
|
class NVPTXLowerStructArgs : public FunctionPass {
|
|
bool runOnFunction(Function &F) override;
|
|
|
|
void handleStructPtrArgs(Function &);
|
|
void handleParam(Argument *);
|
|
|
|
public:
|
|
static char ID; // Pass identification, replacement for typeid
|
|
NVPTXLowerStructArgs() : FunctionPass(ID) {}
|
|
const char *getPassName() const override {
|
|
return "Copy structure (byval *) arguments to stack";
|
|
}
|
|
};
|
|
} // namespace
|
|
|
|
char NVPTXLowerStructArgs::ID = 1;
|
|
|
|
INITIALIZE_PASS(NVPTXLowerStructArgs, "nvptx-lower-struct-args",
|
|
"Lower structure arguments (NVPTX)", false, false)
|
|
|
|
// Give the pointer argument `Arg` a private stack copy of its pointee.
//
// All new instructions are inserted, in creation order, in front of the
// instruction that was first in the entry block when this function was called:
//   1. an alloca of the pointee type, aligned like the parameter;
//   2. a bitcast of Arg to a generic-address-space i8*;
//   3. a call to the nvvm_ptr_gen_to_param intrinsic on that i8*;
//   4. a bitcast of the result to a param-address-space pointer to the pointee;
//   5. a load through that pointer and a store of the value into the alloca.
// Every pre-existing use of Arg is redirected to the alloca.
//
// Arg must have pointer type; this is asserted.
void NVPTXLowerStructArgs::handleParam(Argument *Arg) {
  Function *Kernel = Arg->getParent();
  Module *Mod = Kernel->getParent();
  LLVMContext &Ctx = Mod->getContext();
  Instruction *InsertPt = &Kernel->getEntryBlock().front();

  PointerType *ArgPtrTy = dyn_cast<PointerType>(Arg->getType());
  assert(ArgPtrTy && "Expecting pointer type in handleParam");
  Type *PointeeTy = ArgPtrTy->getElementType();

  // The stack slot inherits the parameter's alignment because the existing
  // loads/stores were written against that alignment and are about to be
  // pointed at this slot instead. (The +1 presumably accounts for attribute
  // index 0 being reserved for the return value -- confirm.)
  AllocaInst *LocalCopy = new AllocaInst(PointeeTy, Arg->getName(), InsertPt);
  LocalCopy->setAlignment(Kernel->getParamAlignment(Arg->getArgNo() + 1));

  // Redirect the existing uses now. This has to precede the creation of the
  // bitcast of Arg below; otherwise that new use would be rewritten to the
  // alloca as well.
  Arg->replaceAllUsesWith(LocalCopy);

  // Look up the generic -> param pointer conversion intrinsic.
  Type *ParamI8PtrTy = Type::getInt8PtrTy(Ctx, ADDRESS_SPACE_PARAM);
  Type *GenericI8PtrTy = Type::getInt8PtrTy(Ctx, ADDRESS_SPACE_GENERIC);
  Type *OverloadTys[] = {ParamI8PtrTy, GenericI8PtrTy};
  Function *GenToParam = Intrinsic::getDeclaration(
      Mod, Intrinsic::nvvm_ptr_gen_to_param, OverloadTys);

  // Arg -> i8* (generic) -> i8* (param) -> pointee* (param).
  Value *CallArgs[] = {
      new BitCastInst(Arg, GenericI8PtrTy, Arg->getName(), InsertPt)};
  CallInst *ParamI8Ptr =
      CallInst::Create(GenToParam, CallArgs, "cvt_to_param", InsertPt);
  BitCastInst *ParamStructPtr = new BitCastInst(
      ParamI8Ptr, PointerType::get(PointeeTy, ADDRESS_SPACE_PARAM),
      Arg->getName(), InsertPt);

  // Copy the whole aggregate from the param pointer into the stack slot.
  LoadInst *Loaded = new LoadInst(ParamStructPtr, Arg->getName(), InsertPt);
  new StoreInst(Loaded, LocalCopy, InsertPt);
}
|
|
|
|
// =============================================================================
|
|
// If the function had a struct ptr arg, say foo(%struct.x *byval %d), then
|
|
// add the following instructions to the first basic block :
|
|
//
|
|
// %temp = alloca %struct.x, align 8
|
|
// %tt1 = bitcast %struct.x * %d to i8 *
|
|
//    %tt2 = llvm.nvvm.cvt.gen.to.param %tt1
|
|
//    %tempd = bitcast i8 addrspace(101) * %tt2 to %struct.x addrspace(101) *
|
|
// %tv = load %struct.x addrspace(101) * %tempd
|
|
// store %struct.x %tv, %struct.x * %temp, align 8
|
|
//
|
|
// The above code allocates some space in the stack and copies the incoming
|
|
// struct from param space to local space.
|
|
// Then replace all occurrences of %d by %temp.
|
|
// =============================================================================
|
|
// Walk the formal arguments of F and run handleParam() on each one that is
// both pointer-typed and marked byval. All other arguments are left untouched.
void NVPTXLowerStructArgs::handleStructPtrArgs(Function &F) {
  for (Argument &Formal : F.args()) {
    if (!Formal.getType()->isPointerTy())
      continue;
    if (!Formal.hasByValAttr())
      continue;
    handleParam(&Formal);
  }
}
|
|
|
|
// =============================================================================
|
|
// Main function for this pass.
|
|
// =============================================================================
|
|
// Pass entry point. Only kernel functions are processed (see the comments at
// the top of this file); for those, every byval pointer argument is lowered
// and true is returned. Non-kernels are left alone and false is returned.
bool NVPTXLowerStructArgs::runOnFunction(Function &F) {
  const bool IsKernel = isKernelFunction(F);
  if (IsKernel)
    handleStructPtrArgs(F);
  // Reported as "changed" for every kernel, whether or not any argument
  // actually qualified.
  return IsKernel;
}
|
|
|
|
// Factory for the pass: returns a newly heap-allocated NVPTXLowerStructArgs.
// Ownership presumably transfers to whichever pass manager the caller adds it
// to -- confirm at the call site.
FunctionPass *llvm::createNVPTXLowerStructArgsPass() {
  FunctionPass *NewPass = new NVPTXLowerStructArgs();
  return NewPass;
}
|