forked from OSchip/llvm-project
Eliminate the MLFuncArgument class representing arguments to MLFunctions: use the
BlockArgument arguments of the entry block instead. This makes MLFunctions and CFGFunctions work more similarly. This is step 7/n towards merging instructions and statements, NFC. PiperOrigin-RevId: 226966975
This commit is contained in:
parent
5ff0001dc7
commit
3bd8ff6699
|
@ -50,7 +50,6 @@ public:
|
|||
case SSAValueKind::InstResult:
|
||||
return true;
|
||||
|
||||
case SSAValueKind::MLFuncArgument:
|
||||
case SSAValueKind::BlockArgument:
|
||||
case SSAValueKind::StmtResult:
|
||||
case SSAValueKind::ForStmt:
|
||||
|
|
|
@ -34,24 +34,24 @@ template <typename ObjectType, typename ElementType> class ArgumentIterator;
|
|||
|
||||
// MLFunction is defined as a sequence of statements that may
|
||||
// include nested affine for loops, conditionals and operations.
|
||||
class MLFunction final
|
||||
: public Function,
|
||||
private llvm::TrailingObjects<MLFunction, MLFuncArgument> {
|
||||
class MLFunction final : public Function {
|
||||
public:
|
||||
/// Creates a new MLFunction with the specific type.
|
||||
MLFunction(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
|
||||
// TODO(clattner): drop this, it is redundant.
|
||||
static MLFunction *create(Location location, StringRef name,
|
||||
FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
ArrayRef<NamedAttribute> attrs = {}) {
|
||||
return new MLFunction(location, name, type, attrs);
|
||||
}
|
||||
|
||||
StmtBlockList &getStatementList() { return body; }
|
||||
const StmtBlockList &getStatementList() const { return body; }
|
||||
StmtBlockList &getBlockList() { return body; }
|
||||
const StmtBlockList &getBlockList() const { return body; }
|
||||
|
||||
StmtBlock *getBody() { return &body.front(); }
|
||||
const StmtBlock *getBody() const { return &body.front(); }
|
||||
|
||||
/// Destroys this statement and its subclass data.
|
||||
void destroy();
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Arguments
|
||||
//===--------------------------------------------------------------------===//
|
||||
|
@ -60,23 +60,23 @@ public:
|
|||
unsigned getNumArguments() const { return getType().getInputs().size(); }
|
||||
|
||||
/// Gets argument.
|
||||
MLFuncArgument *getArgument(unsigned idx) {
|
||||
return &getArgumentsInternal()[idx];
|
||||
BlockArgument *getArgument(unsigned idx) {
|
||||
return getBlockList().front().getArgument(idx);
|
||||
}
|
||||
|
||||
const MLFuncArgument *getArgument(unsigned idx) const {
|
||||
return &getArgumentsInternal()[idx];
|
||||
const BlockArgument *getArgument(unsigned idx) const {
|
||||
return getBlockList().front().getArgument(idx);
|
||||
}
|
||||
|
||||
// Supports non-const operand iteration.
|
||||
using args_iterator = ArgumentIterator<MLFunction, MLFuncArgument>;
|
||||
using args_iterator = ArgumentIterator<MLFunction, BlockArgument>;
|
||||
args_iterator args_begin();
|
||||
args_iterator args_end();
|
||||
llvm::iterator_range<args_iterator> getArguments();
|
||||
|
||||
// Supports const operand iteration.
|
||||
using const_args_iterator =
|
||||
ArgumentIterator<const MLFunction, const MLFuncArgument>;
|
||||
ArgumentIterator<const MLFunction, const BlockArgument>;
|
||||
const_args_iterator args_begin() const;
|
||||
const_args_iterator args_end() const;
|
||||
llvm::iterator_range<const_args_iterator> getArguments() const;
|
||||
|
@ -105,23 +105,6 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
MLFunction(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs = {});
|
||||
|
||||
// This stuff is used by the TrailingObjects template.
|
||||
friend llvm::TrailingObjects<MLFunction, MLFuncArgument>;
|
||||
size_t numTrailingObjects(OverloadToken<MLFuncArgument>) const {
|
||||
return getType().getInputs().size();
|
||||
}
|
||||
|
||||
// Internal functions to get argument list used by getArgument() methods.
|
||||
ArrayRef<MLFuncArgument> getArgumentsInternal() const {
|
||||
return {getTrailingObjects<MLFuncArgument>(), getNumArguments()};
|
||||
}
|
||||
MutableArrayRef<MLFuncArgument> getArgumentsInternal() {
|
||||
return {getTrailingObjects<MLFuncArgument>(), getNumArguments()};
|
||||
}
|
||||
|
||||
StmtBlockList body;
|
||||
};
|
||||
|
||||
|
|
|
@ -35,7 +35,6 @@ class StmtBlock;
|
|||
/// function. This should be kept as a proper subtype of SSAValueKind,
|
||||
/// including having all of the values of the enumerators align.
|
||||
enum class MLValueKind {
|
||||
MLFuncArgument = (int)SSAValueKind::MLFuncArgument,
|
||||
BlockArgument = (int)SSAValueKind::BlockArgument,
|
||||
StmtResult = (int)SSAValueKind::StmtResult,
|
||||
ForStmt = (int)SSAValueKind::ForStmt,
|
||||
|
@ -55,7 +54,6 @@ public:
|
|||
|
||||
static bool classof(const SSAValue *value) {
|
||||
switch (value->getKind()) {
|
||||
case SSAValueKind::MLFuncArgument:
|
||||
case SSAValueKind::BlockArgument:
|
||||
case SSAValueKind::StmtResult:
|
||||
case SSAValueKind::ForStmt:
|
||||
|
@ -79,32 +77,6 @@ protected:
|
|||
MLValue(MLValueKind kind, Type type) : SSAValueImpl(kind, type) {}
|
||||
};
|
||||
|
||||
/// This is the value defined by an argument of an ML function.
|
||||
class MLFuncArgument : public MLValue {
|
||||
public:
|
||||
static bool classof(const SSAValue *value) {
|
||||
return value->getKind() == SSAValueKind::MLFuncArgument;
|
||||
}
|
||||
|
||||
MLFunction *getOwner() { return owner; }
|
||||
const MLFunction *getOwner() const { return owner; }
|
||||
|
||||
/// Return the function that this MLFuncArgument is defined in.
|
||||
const MLFunction *getFunction() const { return getOwner(); }
|
||||
|
||||
MLFunction *getFunction() { return getOwner(); }
|
||||
|
||||
private:
|
||||
friend class MLFunction; // For access to private constructor.
|
||||
MLFuncArgument(Type type, MLFunction *owner)
|
||||
: MLValue(MLValueKind::MLFuncArgument, type), owner(owner) {}
|
||||
|
||||
/// The owner of this operand.
|
||||
/// TODO: can encode this more efficiently to avoid the space hit of this
|
||||
/// through bitpacking shenanigans.
|
||||
MLFunction *const owner;
|
||||
};
|
||||
|
||||
/// Block arguments are ML Values.
|
||||
class BlockArgument : public MLValue {
|
||||
public:
|
||||
|
|
|
@ -36,7 +36,6 @@ class Operation;
|
|||
enum class SSAValueKind {
|
||||
BBArgument, // basic block argument
|
||||
InstResult, // instruction result
|
||||
MLFuncArgument, // ML function argument
|
||||
BlockArgument, // Block argument
|
||||
StmtResult, // statement result
|
||||
ForStmt, // for statement induction variable
|
||||
|
|
|
@ -75,6 +75,8 @@ public:
|
|||
|
||||
// This is the list of arguments to the block.
|
||||
using BlockArgListType = ArrayRef<BlockArgument *>;
|
||||
|
||||
// FIXME: Not const correct.
|
||||
BlockArgListType getArguments() const { return arguments; }
|
||||
|
||||
using args_iterator = BlockArgListType::iterator;
|
||||
|
|
|
@ -1010,17 +1010,25 @@ protected:
|
|||
break;
|
||||
}
|
||||
// Otherwise number it normally.
|
||||
LLVM_FALLTHROUGH;
|
||||
valueIDs[value] = nextValueID++;
|
||||
return;
|
||||
case SSAValueKind::BlockArgument:
|
||||
// If this is an argument to the function, give it an 'arg' name.
|
||||
if (auto *block = cast<BlockArgument>(value)->getOwner())
|
||||
if (auto *fn = block->findFunction())
|
||||
if (&fn->getBlockList().front() == block) {
|
||||
specialName << "arg" << nextArgumentID++;
|
||||
break;
|
||||
}
|
||||
// Otherwise number it normally.
|
||||
valueIDs[value] = nextValueID++;
|
||||
return;
|
||||
case SSAValueKind::InstResult:
|
||||
case SSAValueKind::StmtResult:
|
||||
// This is an uninteresting result, give it a boring number and be
|
||||
// done with it.
|
||||
valueIDs[value] = nextValueID++;
|
||||
return;
|
||||
case SSAValueKind::MLFuncArgument:
|
||||
specialName << "arg" << nextArgumentID++;
|
||||
break;
|
||||
case SSAValueKind::ForStmt:
|
||||
specialName << 'i' << nextLoopID++;
|
||||
break;
|
||||
|
@ -1583,10 +1591,6 @@ void SSAValue::print(raw_ostream &os) const {
|
|||
return;
|
||||
case SSAValueKind::InstResult:
|
||||
return getDefiningInst()->print(os);
|
||||
case SSAValueKind::MLFuncArgument:
|
||||
// TODO: Improve this.
|
||||
os << "<function argument>\n";
|
||||
return;
|
||||
case SSAValueKind::StmtResult:
|
||||
return getDefiningStmt()->print(os);
|
||||
case SSAValueKind::ForStmt:
|
||||
|
|
|
@ -55,7 +55,7 @@ void Function::destroy() {
|
|||
delete cast<ExtFunction>(this);
|
||||
break;
|
||||
case Kind::MLFunc:
|
||||
cast<MLFunction>(this)->destroy();
|
||||
delete cast<MLFunction>(this);
|
||||
break;
|
||||
case Kind::CFGFunc:
|
||||
delete cast<CFGFunction>(this);
|
||||
|
@ -182,29 +182,16 @@ CFGFunction::~CFGFunction() {
|
|||
// MLFunction implementation.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Create a new MLFunction with the specific fields.
|
||||
MLFunction *MLFunction::create(Location location, StringRef name,
|
||||
FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs) {
|
||||
const auto &argTypes = type.getInputs();
|
||||
auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
|
||||
void *rawMem = malloc(byteSize);
|
||||
|
||||
// Initialize the MLFunction part of the function object.
|
||||
auto function = ::new (rawMem) MLFunction(location, name, type, attrs);
|
||||
|
||||
// Initialize the arguments.
|
||||
auto arguments = function->getArgumentsInternal();
|
||||
for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
|
||||
new (&arguments[i]) MLFuncArgument(argTypes[i], function);
|
||||
return function;
|
||||
}
|
||||
|
||||
MLFunction::MLFunction(Location location, StringRef name, FunctionType type,
|
||||
ArrayRef<NamedAttribute> attrs)
|
||||
: Function(Kind::MLFunc, location, name, type, attrs), body(this) {
|
||||
|
||||
// The body of an MLFunction always has one block.
|
||||
body.push_back(new StmtBlock());
|
||||
auto *entry = new StmtBlock();
|
||||
body.push_back(entry);
|
||||
|
||||
// Initialize the arguments.
|
||||
entry->addArguments(type.getInputs());
|
||||
}
|
||||
|
||||
MLFunction::~MLFunction() {
|
||||
|
@ -212,15 +199,6 @@ MLFunction::~MLFunction() {
|
|||
// since child statements need to be destroyed before function arguments
|
||||
// are destroyed.
|
||||
getBody()->clear();
|
||||
|
||||
// Explicitly run the destructors for the function arguments.
|
||||
for (auto &arg : getArgumentsInternal())
|
||||
arg.~MLFuncArgument();
|
||||
}
|
||||
|
||||
void MLFunction::destroy() {
|
||||
this->~MLFunction();
|
||||
free(this);
|
||||
}
|
||||
|
||||
const OperationStmt *MLFunction::getReturnStmt() const {
|
||||
|
|
|
@ -54,8 +54,6 @@ Function *SSAValue::getFunction() {
|
|||
return cast<BBArgument>(this)->getFunction();
|
||||
case SSAValueKind::InstResult:
|
||||
return getDefiningInst()->getFunction();
|
||||
case SSAValueKind::MLFuncArgument:
|
||||
return cast<MLFuncArgument>(this)->getFunction();
|
||||
case SSAValueKind::BlockArgument:
|
||||
return cast<BlockArgument>(this)->getFunction();
|
||||
case SSAValueKind::StmtResult:
|
||||
|
|
|
@ -126,7 +126,7 @@ bool MLValue::isValidSymbol() const {
|
|||
}
|
||||
// This value is either a function argument or an induction variable.
|
||||
// Function argument is ok, induction variable is not.
|
||||
return isa<MLFuncArgument>(this);
|
||||
return isa<BlockArgument>(this);
|
||||
}
|
||||
|
||||
void Statement::setOperand(unsigned idx, MLValue *value) {
|
||||
|
|
|
@ -3452,7 +3452,7 @@ ParseResult ModuleParser::parseMLFunc() {
|
|||
// Okay, the ML function signature was parsed correctly, create the
|
||||
// function.
|
||||
auto *function =
|
||||
MLFunction::create(getEncodedSourceLocation(loc), name, type, attrs);
|
||||
new MLFunction(getEncodedSourceLocation(loc), name, type, attrs);
|
||||
getModule()->getFunctions().push_back(function);
|
||||
|
||||
// Verify no name collision / redefinition.
|
||||
|
|
Loading…
Reference in New Issue