[CodeExtractor] Enable partial aggregate arguments

Summary:
Enable CodeExtractor to construct output functions that partially
aggregate inputs/outputs in their argument list. A use case is the
OMPIRBuilder to create outlined functions for parallel regions that
aggregate in a struct the payload variables for the region while passing
as scalars thread and bound identifiers.

Differential Revision: https://reviews.llvm.org/D96854
This commit is contained in:
Giorgis Georgakoudis 2022-01-25 20:08:19 -05:00 committed by Joseph Huber
parent 510710d037
commit 95b981ca2a
3 changed files with 169 additions and 72 deletions

View File

@ -168,7 +168,7 @@ public:
///
/// Based on the blocks used when constructing the code extractor,
/// determine whether it is eligible for extraction.
///
///
/// Checks that varargs handling (with vastart and vaend) is only done in
/// the outlined blocks.
bool isEligible() const;
@ -214,6 +214,10 @@ public:
/// original block will be added to the outline region.
BasicBlock *findOrCreateBlockForHoisting(BasicBlock *CommonExitBlock);
/// Exclude a value from aggregate argument passing when extracting a code
/// region, passing it instead as a scalar.
void excludeArgFromAggregate(Value *Arg);
private:
struct LifetimeMarkerInfo {
bool SinkLifeStart = false;
@ -222,6 +226,8 @@ public:
Instruction *LifeEnd = nullptr;
};
ValueSet ExcludeArgsFromAggregate;
LifetimeMarkerInfo
getLifetimeMarkers(const CodeExtractorAnalysisCache &CEAC,
Instruction *Addr, BasicBlock *ExitBlock) const;

View File

@ -829,39 +829,54 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
default: RetTy = Type::getInt16Ty(header->getContext()); break;
}
std::vector<Type *> paramTy;
std::vector<Type *> ParamTy;
std::vector<Type *> AggParamTy;
ValueSet StructValues;
// Add the types of the input values to the function's argument list
for (Value *value : inputs) {
LLVM_DEBUG(dbgs() << "value used in func: " << *value << "\n");
paramTy.push_back(value->getType());
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(value)) {
AggParamTy.push_back(value->getType());
StructValues.insert(value);
} else
ParamTy.push_back(value->getType());
}
// Add the types of the output values to the function's argument list.
for (Value *output : outputs) {
LLVM_DEBUG(dbgs() << "instr used in func: " << *output << "\n");
if (AggregateArgs)
paramTy.push_back(output->getType());
else
paramTy.push_back(PointerType::getUnqual(output->getType()));
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
AggParamTy.push_back(output->getType());
StructValues.insert(output);
} else
ParamTy.push_back(PointerType::getUnqual(output->getType()));
}
assert(
(ParamTy.size() + AggParamTy.size()) ==
(inputs.size() + outputs.size()) &&
"Number of scalar and aggregate params does not match inputs, outputs");
assert(StructValues.empty() ||
AggregateArgs && "Expeced StructValues only with AggregateArgs set");
// Concatenate scalar and aggregate params in ParamTy.
size_t NumScalarParams = ParamTy.size();
StructType *StructTy = nullptr;
if (AggregateArgs && !AggParamTy.empty()) {
StructTy = StructType::get(M->getContext(), AggParamTy);
ParamTy.push_back(PointerType::getUnqual(StructTy));
}
LLVM_DEBUG({
dbgs() << "Function type: " << *RetTy << " f(";
for (Type *i : paramTy)
for (Type *i : ParamTy)
dbgs() << *i << ", ";
dbgs() << ")\n";
});
StructType *StructTy = nullptr;
if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
StructTy = StructType::get(M->getContext(), paramTy);
paramTy.clear();
paramTy.push_back(PointerType::getUnqual(StructTy));
}
FunctionType *funcType =
FunctionType::get(RetTy, paramTy,
AllowVarArgs && oldFunction->isVarArg());
FunctionType *funcType = FunctionType::get(
RetTy, ParamTy, AllowVarArgs && oldFunction->isVarArg());
std::string SuffixToUse =
Suffix.empty()
@ -981,24 +996,27 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
}
newFunction->getBasicBlockList().push_back(newRootNode);
// Create an iterator to name all of the arguments we inserted.
Function::arg_iterator AI = newFunction->arg_begin();
// Create scalar and aggregate iterators to name all of the arguments we
// inserted.
Function::arg_iterator ScalarAI = newFunction->arg_begin();
Function::arg_iterator AggAI = std::next(ScalarAI, NumScalarParams);
// Rewrite all users of the inputs in the extracted region to use the
// arguments (or appropriate addressing into struct) instead.
for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
for (unsigned i = 0, e = inputs.size(), aggIdx = 0; i != e; ++i) {
Value *RewriteVal;
if (AggregateArgs) {
if (AggregateArgs && StructValues.contains(inputs[i])) {
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(header->getContext()));
Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), i);
Idx[1] = ConstantInt::get(Type::getInt32Ty(header->getContext()), aggIdx);
Instruction *TI = newFunction->begin()->getTerminator();
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructTy, &*AI, Idx, "gep_" + inputs[i]->getName(), TI);
RewriteVal = new LoadInst(StructTy->getElementType(i), GEP,
StructTy, &*AggAI, Idx, "gep_" + inputs[i]->getName(), TI);
RewriteVal = new LoadInst(StructTy->getElementType(aggIdx), GEP,
"loadgep_" + inputs[i]->getName(), TI);
++aggIdx;
} else
RewriteVal = &*AI++;
RewriteVal = &*ScalarAI++;
std::vector<User *> Users(inputs[i]->user_begin(), inputs[i]->user_end());
for (User *use : Users)
@ -1008,12 +1026,14 @@ Function *CodeExtractor::constructFunction(const ValueSet &inputs,
}
// Set names for input and output arguments.
if (!AggregateArgs) {
AI = newFunction->arg_begin();
for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++AI)
AI->setName(inputs[i]->getName());
for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++AI)
AI->setName(outputs[i]->getName()+".out");
if (NumScalarParams) {
ScalarAI = newFunction->arg_begin();
for (unsigned i = 0, e = inputs.size(); i != e; ++i, ++ScalarAI)
if (!StructValues.contains(inputs[i]))
ScalarAI->setName(inputs[i]->getName());
for (unsigned i = 0, e = outputs.size(); i != e; ++i, ++ScalarAI)
if (!StructValues.contains(outputs[i]))
ScalarAI->setName(outputs[i]->getName() + ".out");
}
// Rewrite branches to basic blocks outside of the loop to new dummy blocks
@ -1121,7 +1141,8 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
ValueSet &outputs) {
// Emit a call to the new function, passing in: *pointer to struct (if
// aggregating parameters), or plan inputs and allocated memory for outputs
std::vector<Value *> params, StructValues, ReloadOutputs, Reloads;
std::vector<Value *> params, ReloadOutputs, Reloads;
ValueSet StructValues;
Module *M = newFunction->getParent();
LLVMContext &Context = M->getContext();
@ -1129,23 +1150,24 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
CallInst *call = nullptr;
// Add inputs as params, or to be filled into the struct
unsigned ArgNo = 0;
unsigned ScalarInputArgNo = 0;
SmallVector<unsigned, 1> SwiftErrorArgs;
for (Value *input : inputs) {
if (AggregateArgs)
StructValues.push_back(input);
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(input))
StructValues.insert(input);
else {
params.push_back(input);
if (input->isSwiftError())
SwiftErrorArgs.push_back(ArgNo);
SwiftErrorArgs.push_back(ScalarInputArgNo);
}
++ArgNo;
++ScalarInputArgNo;
}
// Create allocas for the outputs
unsigned ScalarOutputArgNo = 0;
for (Value *output : outputs) {
if (AggregateArgs) {
StructValues.push_back(output);
if (AggregateArgs && !ExcludeArgsFromAggregate.contains(output)) {
StructValues.insert(output);
} else {
AllocaInst *alloca =
new AllocaInst(output->getType(), DL.getAllocaAddrSpace(),
@ -1153,12 +1175,14 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
&codeReplacer->getParent()->front().front());
ReloadOutputs.push_back(alloca);
params.push_back(alloca);
++ScalarOutputArgNo;
}
}
StructType *StructArgTy = nullptr;
AllocaInst *Struct = nullptr;
if (AggregateArgs && (inputs.size() + outputs.size() > 0)) {
unsigned NumAggregatedInputs = 0;
if (AggregateArgs && !StructValues.empty()) {
std::vector<Type *> ArgTypes;
for (Value *V : StructValues)
ArgTypes.push_back(V->getType());
@ -1170,14 +1194,18 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
&codeReplacer->getParent()->front().front());
params.push_back(Struct);
for (unsigned i = 0, e = inputs.size(); i != e; ++i) {
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
codeReplacer->getInstList().push_back(GEP);
new StoreInst(StructValues[i], GEP, codeReplacer);
// Store aggregated inputs in the struct.
for (unsigned i = 0, e = StructValues.size(); i != e; ++i) {
if (inputs.contains(StructValues[i])) {
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), i);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructArgTy, Struct, Idx, "gep_" + StructValues[i]->getName());
codeReplacer->getInstList().push_back(GEP);
new StoreInst(StructValues[i], GEP, codeReplacer);
NumAggregatedInputs++;
}
}
}
@ -1200,24 +1228,24 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
newFunction->addParamAttr(SwiftErrArgNo, Attribute::SwiftError);
}
Function::arg_iterator OutputArgBegin = newFunction->arg_begin();
unsigned FirstOut = inputs.size();
if (!AggregateArgs)
std::advance(OutputArgBegin, inputs.size());
// Reload the outputs passed in by reference.
for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
// Reload the outputs passed in by reference, use the struct if output is in
// the aggregate or reload from the scalar argument.
for (unsigned i = 0, e = outputs.size(), scalarIdx = 0,
aggIdx = NumAggregatedInputs;
i != e; ++i) {
Value *Output = nullptr;
if (AggregateArgs) {
if (AggregateArgs && StructValues.contains(outputs[i])) {
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructArgTy, Struct, Idx, "gep_reload_" + outputs[i]->getName());
codeReplacer->getInstList().push_back(GEP);
Output = GEP;
++aggIdx;
} else {
Output = ReloadOutputs[i];
Output = ReloadOutputs[scalarIdx];
++scalarIdx;
}
LoadInst *load = new LoadInst(outputs[i]->getType(), Output,
outputs[i]->getName() + ".reload",
@ -1299,8 +1327,13 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
// Store the arguments right after the definition of output value.
// This should be proceeded after creating exit stubs to be ensure that invoke
// result restore will be placed in the outlined function.
Function::arg_iterator OAI = OutputArgBegin;
for (unsigned i = 0, e = outputs.size(); i != e; ++i) {
Function::arg_iterator ScalarOutputArgBegin = newFunction->arg_begin();
std::advance(ScalarOutputArgBegin, ScalarInputArgNo);
Function::arg_iterator AggOutputArgBegin = newFunction->arg_begin();
std::advance(AggOutputArgBegin, ScalarInputArgNo + ScalarOutputArgNo);
for (unsigned i = 0, e = outputs.size(), aggIdx = NumAggregatedInputs; i != e;
++i) {
auto *OutI = dyn_cast<Instruction>(outputs[i]);
if (!OutI)
continue;
@ -1320,23 +1353,27 @@ CallInst *CodeExtractor::emitCallAndSwitchStatement(Function *newFunction,
assert((InsertBefore->getFunction() == newFunction ||
Blocks.count(InsertBefore->getParent())) &&
"InsertPt should be in new function");
assert(OAI != newFunction->arg_end() &&
"Number of output arguments should match "
"the amount of defined values");
if (AggregateArgs) {
if (AggregateArgs && StructValues.contains(outputs[i])) {
assert(AggOutputArgBegin != newFunction->arg_end() &&
"Number of aggregate output arguments should match "
"the number of defined values");
Value *Idx[2];
Idx[0] = Constant::getNullValue(Type::getInt32Ty(Context));
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), FirstOut + i);
Idx[1] = ConstantInt::get(Type::getInt32Ty(Context), aggIdx);
GetElementPtrInst *GEP = GetElementPtrInst::Create(
StructArgTy, &*OAI, Idx, "gep_" + outputs[i]->getName(),
StructArgTy, &*AggOutputArgBegin, Idx, "gep_" + outputs[i]->getName(),
InsertBefore);
new StoreInst(outputs[i], GEP, InsertBefore);
++aggIdx;
// Since there should be only one struct argument aggregating
// all the output values, we shouldn't increment OAI, which always
// points to the struct argument, in this case.
// all the output values, we shouldn't increment AggOutputArgBegin, which
// always points to the struct argument, in this case.
} else {
new StoreInst(outputs[i], &*OAI, InsertBefore);
++OAI;
assert(ScalarOutputArgBegin != newFunction->arg_end() &&
"Number of scalar output arguments should match "
"the number of defined values");
new StoreInst(outputs[i], &*ScalarOutputArgBegin, InsertBefore);
++ScalarOutputArgBegin;
}
}
@ -1835,3 +1872,7 @@ bool CodeExtractor::verifyAssumptionCache(const Function &OldFunc,
}
return false;
}
void CodeExtractor::excludeArgFromAggregate(Value *Arg) {
ExcludeArgsFromAggregate.insert(Arg);
}

View File

@ -188,7 +188,7 @@ TEST(CodeExtractor, ExitBlockOrderingPhis) {
EXPECT_TRUE(NextReturn);
ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
EXPECT_TRUE(CINext->getLimitedValue() == 0u);
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}
@ -245,7 +245,7 @@ TEST(CodeExtractor, ExitBlockOrdering) {
EXPECT_TRUE(NextReturn);
ConstantInt *CINext = dyn_cast<ConstantInt>(NextReturn->getReturnValue());
EXPECT_TRUE(CINext->getLimitedValue() == 0u);
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}
@ -504,4 +504,54 @@ TEST(CodeExtractor, RemoveBitcastUsesFromOuterLifetimeMarkers) {
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}
TEST(CodeExtractor, PartialAggregateArgs) {
LLVMContext Ctx;
SMDiagnostic Err;
std::unique_ptr<Module> M(parseAssemblyString(R"ir(
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"
declare void @use(i32)
define void @foo(i32 %a, i32 %b, i32 %c) {
entry:
br label %extract
extract:
call void @use(i32 %a)
call void @use(i32 %b)
call void @use(i32 %c)
br label %exit
exit:
ret void
}
)ir",
Err, Ctx));
Function *Func = M->getFunction("foo");
SmallVector<BasicBlock *, 1> Blocks{getBlockByName(Func, "extract")};
// Create the CodeExtractor with arguments aggregation enabled.
CodeExtractor CE(Blocks, /* DominatorTree */ nullptr,
/* AggregateArgs */ true);
EXPECT_TRUE(CE.isEligible());
CodeExtractorAnalysisCache CEAC(*Func);
SetVector<Value *> Inputs, Outputs, SinkingCands, HoistingCands;
BasicBlock *CommonExit = nullptr;
CE.findAllocas(CEAC, SinkingCands, HoistingCands, CommonExit);
CE.findInputsOutputs(Inputs, Outputs, SinkingCands);
// Exclude the first input from the argument aggregate.
CE.excludeArgFromAggregate(Inputs[0]);
Function *Outlined = CE.extractCodeRegion(CEAC, Inputs, Outputs);
EXPECT_TRUE(Outlined);
// Expect 2 arguments in the outlined function: the excluded input and the
// struct aggregate for the remaining inputs.
EXPECT_EQ(Outlined->arg_size(), 2U);
EXPECT_FALSE(verifyFunction(*Outlined));
EXPECT_FALSE(verifyFunction(*Func));
}
} // end anonymous namespace