forked from OSchip/llvm-project
[CodeExtractor] Enable partial aggregate arguments
Summary: Enable CodeExtractor to construct output functions that partially aggregate inputs/outputs in their argument list, i.e. some values are packed into the aggregate struct while others are passed as separate scalar arguments. A use case is the OMPIRBuilder, which creates outlined functions for parallel regions that aggregate the region's payload variables in a struct while passing the thread and bound identifiers as scalars. Differential Revision: https://reviews.llvm.org/D96854
This commit is contained in:
parent
510710d037
commit
95b981ca2a
|
@ -168,7 +168,7 @@ public:
|
|||
///
|
||||
/// Based on the blocks used when constructing the code extractor,
|
||||
/// determine whether it is eligible for extraction.
|
||||
///
|
||||
///
|
||||
/// Checks that varargs handling (with vastart and vaend) is only done in
|
||||
/// the outlined blocks.
|
||||
bool isEligible() const;
|
||||
|
@ -214,6 +214,10 @@ public:
|
|||
/// original block will be added to the outline region.
|
||||
BasicBlock *findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock);
|
||||
|
||||
/// Exclude a value from aggregate argument passing when extracting a code
|
||||
/// region, passing it instead as a scalar.
|
||||
void excludeArgFromAggregate(Value *Arg);
|
||||
|
||||
private:
|
||||
struct LifetimeMarkerInfo {
|
||||
bool SinkLifeStart = false;
|
||||
|
@ -222,6 +226,8 @@ public:
|
|||
Instruction *LifeEnd = nullptr;
|
||||
};
|
||||
|
||||
ValueSet ExcludeArgsFromAggregate;
|
||||
|
||||
LifetimeMarkerInfo
|
||||
getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC,
|
||||
Instruction *Addr, BasicBlock *ExitBlock) const;
|
||||
|
|
|
@ -829,39 +829,54 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
|
|||
default: RetTy = Type::getInt16Ty(header->getContext()); break;
|
||||
}
|
||||
|
||||
std::vector<Type *> paramTy;
|
||||
std::vector<Type *> ParamTy;
|
||||
std::vector<Type *> AggParamTy;
|
||||
ValueSet StructValues;
|
||||
|
||||
// Add the types of the input values to the function's argument list
|
||||
for (Value *value : inputs) {
|
||||
LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
|
||||
paramTy.push_back(value->getType());
|
||||
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
|
||||
AggParamTy.push_back(value->getType());
|
||||
StructValues.insert(value);
|
||||
} else
|
||||
ParamTy.push_back(value->getType());
|
||||
}
|
||||
|
||||
// Add the types of the output values to the function's argument list.
|
||||
for (Value *output : outputs) {
|
||||
LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
|
||||
if (AggregateArgs)
|
||||
paramTy.push_back(output->getType());
|
||||
else
|
||||
paramTy.push_back(PointerType::getUnqual(output->getType()));
|
||||
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
|
||||
AggParamTy.push_back(output->getType());
|
||||
StructValues.insert(output);
|
||||
} else
|
||||
ParamTy.push_back(PointerType::getUnqual(output->getType()));
|
||||
}
|
||||
|
||||
assert(
|
||||
(ParamTy.size() + AggParamTy.size()) ==
|
||||
(inputs.size() + outputs.size()) &&
|
||||
"Number of scalar and aggregate params does not match inputs, outputs");
|
||||
assert(StructValues.empty() ||
|
||||
AggregateArgs && "Expeced StructValues only with AggregateArgs set");
|
||||
|
||||
// Concatenate scalar and aggregate params in ParamTy.
|
||||
size_t NumScalarParams = ParamTy.size();
|
||||
StructType *StructTy = nullptr;
|
||||
if (AggregateArgs && !AggParamTy.empty()) {
|
||||
StructTy = StructType::get(M->getContext(), AggParamTy);
|
||||
ParamTy.push_back(PointerType::getUnqual(StructTy));
|
||||
}
|
||||
|
||||
LLVM_DEBUG({
|
||||
dbgs() << "Function type: " << *RetTy << " f(";
|
||||
for (Type *i : paramTy)
|
||||
for (Type *i : ParamTy)
|
||||
dbgs() << *i << ", ";
|
||||
dbgs() << ")\n";
|
||||
});
|
||||
|
||||
StructType *StructTy = nullptr;
|
||||
if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
|
||||
StructTy = StructType::get(M->getContext(), paramTy);
|
||||
paramTy.clear();
|
||||
paramTy.push_back(PointerType::getUnqual(StructTy));
|
||||
}
|
||||
FunctionType *funcType =
|
||||
FunctionType::get(RetTy, paramTy,
|
||||
AllowVarArgs && oldFunction->isVarArg());
|
||||
FunctionType *funcType = FunctionType::get(
|
||||
RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg());
|
||||
|
||||
std::string SuffixToUse =
|
||||
Suffix.empty()
|
||||
|
@ -981,24 +996,27 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
|
|||
}
|
||||
newFunction->getBasicBlockList().push_back(newRootNode);
|
||||
|
||||
// Create an iterator to name all of the arguments we inserted.
|
||||
Function::arg_iterator AI = newFunction->arg_begin();
|
||||
// Create scalar and aggregate iterators to name all of the arguments we
|
||||
// inserted.
|
||||
Function::arg_iterator ScalarAI = newFunction->arg_begin();
|
||||
Function::arg_iterator AggAI = std::next(ScalarAI, NumScalarParams);
|
||||
|
||||
// Rewrite all users of the inputs in the extracted region to use the
|
||||
// arguments (or appropriate addressing into struct) instead.
|
||||
for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
|
||||
for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
|
||||
Value *RewriteVal;
|
||||
if (AggregateArgs) {
|
||||
if (AggregateArgs && StructValues.contains(inputs[i])) {
|
||||
Value *Idx[2];
|
||||
Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
|
||||
Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
|
||||
Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
|
||||
Instruction *TI = newFunction->begin()->getTerminator();
|
||||
GetElementPtrInst *GEP = GetElementPtrInst::Create(
|
||||
StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
|
||||
RewriteVal = new LoadInst(StructTy->getElementType(i), GEP,
|
||||
StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI);
|
||||
RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP,
|
||||
"loadgep_" + inputs[i]->getName(), TI);
|
||||
++aggIdx;
|
||||
} else
|
||||
RewriteVal = &*AI++;
|
||||
RewriteVal = &*ScalarAI++;
|
||||
|
||||
std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
|
||||
for (User *use : Users)
|
||||
|
@ -1008,12 +1026,14 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
|
|||
}
|
||||
|
||||
// Set names for input and output arguments.
|
||||
if (!AggregateArgs) {
|
||||
AI = newFunction->arg_begin();
|
||||
for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
|
||||
AI->setName(inputs[i]->getName());
|
||||
for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
|
||||
AI->setName(outputs[i]->getName()+".out");
|
||||
if (NumScalarParams) {
|
||||
ScalarAI = newFunction->arg_begin();
|
||||
for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI)
|
||||
if (!StructValues.contains(inputs[i]))
|
||||
ScalarAI->setName(inputs[i]->getName());
|
||||
for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI)
|
||||
if (!StructValues.contains(outputs[i]))
|
||||
ScalarAI->setName(outputs[i]->getName() + ".out");
|
||||
}
|
||||
|
||||
// Rewrite branches to basic blocks outside of the loop to new dummy blocks
|
||||
|
@ -1121,7 +1141,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
|
|||
ValueSet &outputs) {
|
||||
// Emit a call to the new function, passing in: *pointer to struct (if
|
||||
// aggregating parameters), or plain inputs and allocated memory for outputs
|
||||
std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
|
||||
std::vector<Value *> params, ReloadOutputs, Reloads;
|
||||
ValueSet StructValues;
|
||||
|
||||
Module *M = newFunction->getParent();
|
||||
LLVMContext &Context = M->getContext();
|
||||
|
@ -1129,23 +1150,24 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
|
|||
CallInst *call = nullptr;
|
||||
|
||||
// Add inputs as params, or to be filled into the struct
|
||||
unsigned ArgNo = 0;
|
||||
unsigned ScalarInputArgNo = 0;
|
||||
SmallVector<unsigned, 1> SwiftErrorArgs;
|
||||
for (Value *input : inputs) {
|
||||
if (AggregateArgs)
|
||||
StructValues.push_back(input);
|
||||
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input))
|
||||
StructValues.insert(input);
|
||||
else {
|
||||
params.push_back(input);
|
||||
if (input->isSwiftError())
|
||||
SwiftErrorArgs.push_back(ArgNo);
|
||||
SwiftErrorArgs.push_back(ScalarInputArgNo);
|
||||
}
|
||||
++ArgNo;
|
||||
++ScalarInputArgNo;
|
||||
}
|
||||
|
||||
// Create allocas for the outputs
|
||||
unsigned ScalarOutputArgNo = 0;
|
||||
for (Value *output : outputs) {
|
||||
if (AggregateArgs) {
|
||||
StructValues.push_back(output);
|
||||
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
|
||||
StructValues.insert(output);
|
||||
} else {
|
||||
AllocaInst *alloca =
|
||||
new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
|
||||
|
@ -1153,12 +1175,14 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
|
|||
&codeReplacer->getParent()->front().front());
|
||||
ReloadOutputs.push_back(alloca);
|
||||
params.push_back(alloca);
|
||||
++ScalarOutputArgNo;
|
||||
}
|
||||
}
|
||||
|
||||
StructType *StructArgTy = nullptr;
|
||||
AllocaInst *Struct = nullptr;
|
||||
if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
|
||||
unsigned NumAggregatedInputs = 0;
|
||||
if (AggregateArgs && !StructValues.empty()) {
|
||||
std::vector<Type *> ArgTypes;
|
||||
for (Value *V : StructValues)
|
||||
ArgTypes.push_back(V->getType());
|
||||
|
@ -1170,14 +1194,18 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
|
|||
&codeReplacer->getParent()->front().front());
|
||||
params.push_back(Struct);
|
||||
|
||||
for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
|
||||
Value *Idx[2];
|
||||
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
|
||||
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
|
||||
GetElementPtrInst *GEP = GetElementPtrInst::Create(
|
||||
StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
|
||||
codeReplacer->getInstList().push_back(GEP);
|
||||
new StoreInst(StructValues[i], GEP, codeReplacer);
|
||||
// Store aggregated inputs in the struct.
|
||||
for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
|
||||
if (inputs.contains(StructValues[i])) {
|
||||
Value *Idx[2];
|
||||
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
|
||||
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
|
||||
GetElementPtrInst *GEP = GetElementPtrInst::Create(
|
||||
StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
|
||||
codeReplacer->getInstList().push_back(GEP);
|
||||
new StoreInst(StructValues[i], GEP, codeReplacer);
|
||||
NumAggregatedInputs++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1200,24 +1228,24 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
|
|||
newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
|
||||
}
|
||||
|
||||
Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
|
||||
unsigned FirstOut = inputs.size();
|
||||
if (!AggregateArgs)
|
||||
std::advance(OutputArgBegin, inputs.size());
|
||||
|
||||
// Reload the outputs passed in by reference.
|
||||
for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
|
||||
// Reload the outputs passed in by reference, use the struct if output is in
|
||||
// the aggregate or reload from the scalar argument.
|
||||
for (unsigned i = 0, e = outputs.size(), scalarIdx = 0,
|
||||
aggIdx = NumAggregatedInputs;
|
||||
i != e; ++i) {
|
||||
Value *Output = nullptr;
|
||||
if (AggregateArgs) {
|
||||
if (AggregateArgs && StructValues.contains(outputs[i])) {
|
||||
Value *Idx[2];
|
||||
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
|
||||
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
|
||||
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
|
||||
GetElementPtrInst *GEP = GetElementPtrInst::Create(
|
||||
StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
|
||||
codeReplacer->getInstList().push_back(GEP);
|
||||
Output = GEP;
|
||||
++aggIdx;
|
||||
} else {
|
||||
Output = ReloadOutputs[i];
|
||||
Output = ReloadOutputs[scalarIdx];
|
||||
++scalarIdx;
|
||||
}
|
||||
LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
|
||||
outputs[i]->getName() + ".reload",
|
||||
|
@ -1299,8 +1327,13 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
|
|||
// Store the arguments right after the definition of output value.
|
||||
// This should be done after creating the exit stubs to ensure that the invoke
|
||||
// result restore will be placed in the outlined function.
|
||||
Function::arg_iterator OAI = OutputArgBegin;
|
||||
for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
|
||||
Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin();
|
||||
std::advance(ScalarOutputArgBegin, ScalarInputArgNo);
|
||||
Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin();
|
||||
std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo);
|
||||
|
||||
for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e;
|
||||
++i) {
|
||||
auto *OutI = dyn_cast<Instruction>(outputs[i]);
|
||||
if (!OutI)
|
||||
continue;
|
||||
|
@ -1320,23 +1353,27 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
|
|||
assert((InsertBefore->getFunction() == newFunction ||
|
||||
Blocks.count(InsertBefore->getParent())) &&
|
||||
"InsertPt should be in new function");
|
||||
assert(OAI != newFunction->arg_end() &&
|
||||
"Number of output arguments should match "
|
||||
"the amount of defined values");
|
||||
if (AggregateArgs) {
|
||||
if (AggregateArgs && StructValues.contains(outputs[i])) {
|
||||
assert(AggOutputArgBegin != newFunction->arg_end() &&
|
||||
"Number of aggregate output arguments should match "
|
||||
"the number of defined values");
|
||||
Value *Idx[2];
|
||||
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
|
||||
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
|
||||
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
|
||||
GetElementPtrInst *GEP = GetElementPtrInst::Create(
|
||||
StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(),
|
||||
StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(),
|
||||
InsertBefore);
|
||||
new StoreInst(outputs[i], GEP, InsertBefore);
|
||||
++aggIdx;
|
||||
// Since there should be only one struct argument aggregating
|
||||
// all the output values, we shouldn't increment OAI, which always
|
||||
// points to the struct argument, in this case.
|
||||
// all the output values, we shouldn't increment AggOutputArgBegin, which
|
||||
// always points to the struct argument, in this case.
|
||||
} else {
|
||||
new StoreInst(outputs[i], &*OAI, InsertBefore);
|
||||
++OAI;
|
||||
assert(ScalarOutputArgBegin != newFunction->arg_end() &&
|
||||
"Number of scalar output arguments should match "
|
||||
"the number of defined values");
|
||||
new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertBefore);
|
||||
++ScalarOutputArgBegin;
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -1835,3 +1872,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
|
|||
}
|
||||
return false;
|
||||
}
|
||||
|
||||
/// Mark \p Arg as excluded from the aggregate struct argument.
///
/// The value is recorded in ExcludeArgsFromAggregate. When AggregateArgs is
/// set, constructFunction and emitCallAndSwitchStatement consult that set and
/// pass any input/output found in it as a separate parameter instead of
/// packing it into the struct.
void CodeExtractor::excludeArgFromAggregate(Value *Arg) {
|
||||
ExcludeArgsFromAggregate.insert(Arg);
|
||||
}
|
||||
|
|
|
@ -188,7 +188,7 @@ TEST(CodeExtractor, ExitBlockOrderingPhis) {
|
|||
EXPECT_TRUE(NextReturn);
|
||||
ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
|
||||
EXPECT_TRUE(CINext->getLimitedValue() == 0u);
|
||||
|
||||
|
||||
EXPECT_FALSE(verifyFunction(*Outlined));
|
||||
EXPECT_FALSE(verifyFunction(*Func));
|
||||
}
|
||||
|
@ -245,7 +245,7 @@ TEST(CodeExtractor, ExitBlockOrdering) {
|
|||
EXPECT_TRUE(NextReturn);
|
||||
ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
|
||||
EXPECT_TRUE(CINext->getLimitedValue() == 0u);
|
||||
|
||||
|
||||
EXPECT_FALSE(verifyFunction(*Outlined));
|
||||
EXPECT_FALSE(verifyFunction(*Func));
|
||||
}
|
||||
|
@ -504,4 +504,54 @@ TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
|
|||
EXPECT_FALSE(verifyFunction(*Outlined));
|
||||
EXPECT_FALSE(verifyFunction(*Func));
|
||||
}
|
||||
|
||||
// Checks partial argument aggregation: the `extract` block of @foo is
// outlined with AggregateArgs enabled while the first discovered input is
// excluded from the aggregate, so the outlined function is expected to take
// two parameters (the excluded scalar and the struct for the other inputs).
TEST(CodeExtractor, PartialAggregateArgs) {
|
||||
LLVMContext Ctx;
|
||||
SMDiagnostic Err;
|
||||
// Input IR: the `extract` block uses the three function arguments %a, %b and
// %c, which makes them inputs of the extracted region.
std::unique_ptr<Module> M(parseAssemblyString(R"ir(
|
||||
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
|
||||
target triple = "x86_64-unknown-linux-gnu"
|
||||

||||
declare void @use(i32)
|
||||

||||
define void @foo(i32 %a, i32 %b, i32 %c) {
|
||||
entry:
|
||||
br label %extract
|
||||

||||
extract:
|
||||
call void @use(i32 %a)
|
||||
call void @use(i32 %b)
|
||||
call void @use(i32 %c)
|
||||
br label %exit
|
||||

||||
exit:
|
||||
ret void
|
||||
}
|
||||
)ir",
|
||||
Err, Ctx));
|
||||

||||
Function *Func = M->getFunction("foo");
|
||||
SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
|
||||

||||
// Create the CodeExtractor with arguments aggregation enabled.
|
||||
CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
|
||||
/* AggregateArgs */ true);
|
||||
EXPECT_TRUE(CE.isEligible());
|
||||

||||
// Compute the region's inputs and outputs up front so that one of the inputs
// can be excluded from the aggregate before the region is extracted.
CodeExtractorAnalysisCache CEAC(*Func);
|
||||
SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
|
||||
BasicBlock *CommonExit = nullptr;
|
||||
CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
|
||||
CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
|
||||
// Exclude the first input from the argument aggregate.
|
||||
CE.excludeArgFromAggregate(Inputs[0]);
|
||||

||||
Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
|
||||
EXPECT_TRUE(Outlined);
|
||||
// Expect 2 arguments in the outlined function: the excluded input and the
|
||||
// struct aggregate for the remaining inputs.
|
||||
EXPECT_EQ(Outlined->arg_size(), 2U);
|
||||
// NOTE(review): EXPECT_FALSE assumes verifyFunction returns true for a
// malformed function, as in LLVM's Verifier API — confirm.
EXPECT_FALSE(verifyFunction(*Outlined));
|
||||
EXPECT_FALSE(verifyFunction(*Func));
|
||||
}
|
||||
} // end anonymous namespace
|
||||
|
|
Loading…
Reference in New Issue