Eliminate the MLFuncArgument class representing arguments to MLFunctions: use the

BlockArgument arguments of the entry block instead.  This makes MLFunctions and
CFGFunctions work more similarly.

This is step 7/n towards merging instructions and statements, NFC.

PiperOrigin-RevId: 226966975
This commit is contained in:
Chris Lattner 2018-12-26 16:51:31 -08:00 committed by jpienaar
parent 5ff0001dc7
commit 3bd8ff6699
10 changed files with 39 additions and 104 deletions

View File

@ -50,7 +50,6 @@ public:
case SSAValueKind::InstResult:
return true;
case SSAValueKind::MLFuncArgument:
case SSAValueKind::BlockArgument:
case SSAValueKind::StmtResult:
case SSAValueKind::ForStmt:

View File

@ -34,24 +34,24 @@ template <typename ObjectType, typename ElementType> class ArgumentIterator;
// MLFunction is defined as a sequence of statements that may
// include nested affine for loops, conditionals and operations.
class MLFunction final
: public Function,
private llvm::TrailingObjects<MLFunction, MLFuncArgument> {
class MLFunction final : public Function {
public:
/// Creates a new MLFunction with the specific type.
MLFunction(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
// TODO(clattner): drop this, it is redundant.
static MLFunction *create(Location location, StringRef name,
FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
ArrayRef<NamedAttribute> attrs = {}) {
return new MLFunction(location, name, type, attrs);
}
StmtBlockList &getStatementList() { return body; }
const StmtBlockList &getStatementList() const { return body; }
StmtBlockList &getBlockList() { return body; }
const StmtBlockList &getBlockList() const { return body; }
StmtBlock *getBody() { return &body.front(); }
const StmtBlock *getBody() const { return &body.front(); }
/// Destroys this statement and its subclass data.
void destroy();
//===--------------------------------------------------------------------===//
// Arguments
//===--------------------------------------------------------------------===//
@ -60,23 +60,23 @@ public:
unsigned getNumArguments() const { return getType().getInputs().size(); }
/// Gets argument.
MLFuncArgument *getArgument(unsigned idx) {
return &getArgumentsInternal()[idx];
BlockArgument *getArgument(unsigned idx) {
return getBlockList().front().getArgument(idx);
}
const MLFuncArgument *getArgument(unsigned idx) const {
return &getArgumentsInternal()[idx];
const BlockArgument *getArgument(unsigned idx) const {
return getBlockList().front().getArgument(idx);
}
// Supports non-const operand iteration.
using args_iterator = ArgumentIterator<MLFunction, MLFuncArgument>;
using args_iterator = ArgumentIterator<MLFunction, BlockArgument>;
args_iterator args_begin();
args_iterator args_end();
llvm::iterator_range<args_iterator> getArguments();
// Supports const operand iteration.
using const_args_iterator =
ArgumentIterator<const MLFunction, const MLFuncArgument>;
ArgumentIterator<const MLFunction, const BlockArgument>;
const_args_iterator args_begin() const;
const_args_iterator args_end() const;
llvm::iterator_range<const_args_iterator> getArguments() const;
@ -105,23 +105,6 @@ public:
}
private:
MLFunction(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs = {});
// This stuff is used by the TrailingObjects template.
friend llvm::TrailingObjects<MLFunction, MLFuncArgument>;
size_t numTrailingObjects(OverloadToken<MLFuncArgument>) const {
return getType().getInputs().size();
}
// Internal functions to get argument list used by getArgument() methods.
ArrayRef<MLFuncArgument> getArgumentsInternal() const {
return {getTrailingObjects<MLFuncArgument>(), getNumArguments()};
}
MutableArrayRef<MLFuncArgument> getArgumentsInternal() {
return {getTrailingObjects<MLFuncArgument>(), getNumArguments()};
}
StmtBlockList body;
};

View File

@ -35,7 +35,6 @@ class StmtBlock;
/// function. This should be kept as a proper subtype of SSAValueKind,
/// including having all of the values of the enumerators align.
enum class MLValueKind {
MLFuncArgument = (int)SSAValueKind::MLFuncArgument,
BlockArgument = (int)SSAValueKind::BlockArgument,
StmtResult = (int)SSAValueKind::StmtResult,
ForStmt = (int)SSAValueKind::ForStmt,
@ -55,7 +54,6 @@ public:
static bool classof(const SSAValue *value) {
switch (value->getKind()) {
case SSAValueKind::MLFuncArgument:
case SSAValueKind::BlockArgument:
case SSAValueKind::StmtResult:
case SSAValueKind::ForStmt:
@ -79,32 +77,6 @@ protected:
MLValue(MLValueKind kind, Type type) : SSAValueImpl(kind, type) {}
};
/// This is the value defined by an argument of an ML function.
class MLFuncArgument : public MLValue {
public:
static bool classof(const SSAValue *value) {
return value->getKind() == SSAValueKind::MLFuncArgument;
}
MLFunction *getOwner() { return owner; }
const MLFunction *getOwner() const { return owner; }
/// Return the function that this MLFuncArgument is defined in.
const MLFunction *getFunction() const { return getOwner(); }
MLFunction *getFunction() { return getOwner(); }
private:
friend class MLFunction; // For access to private constructor.
MLFuncArgument(Type type, MLFunction *owner)
: MLValue(MLValueKind::MLFuncArgument, type), owner(owner) {}
/// The owner of this operand.
/// TODO: can encode this more efficiently to avoid the space hit of this
/// through bitpacking shenanigans.
MLFunction *const owner;
};
/// Block arguments are ML Values.
class BlockArgument : public MLValue {
public:

View File

@ -36,7 +36,6 @@ class Operation;
enum class SSAValueKind {
BBArgument, // basic block argument
InstResult, // instruction result
MLFuncArgument, // ML function argument
BlockArgument, // Block argument
StmtResult, // statement result
ForStmt, // for statement induction variable

View File

@ -75,6 +75,8 @@ public:
// This is the list of arguments to the block.
using BlockArgListType = ArrayRef<BlockArgument *>;
// FIXME: Not const correct.
BlockArgListType getArguments() const { return arguments; }
using args_iterator = BlockArgListType::iterator;

View File

@ -1010,17 +1010,25 @@ protected:
break;
}
// Otherwise number it normally.
LLVM_FALLTHROUGH;
valueIDs[value] = nextValueID++;
return;
case SSAValueKind::BlockArgument:
// If this is an argument to the function, give it an 'arg' name.
if (auto *block = cast<BlockArgument>(value)->getOwner())
if (auto *fn = block->findFunction())
if (&fn->getBlockList().front() == block) {
specialName << "arg" << nextArgumentID++;
break;
}
// Otherwise number it normally.
valueIDs[value] = nextValueID++;
return;
case SSAValueKind::InstResult:
case SSAValueKind::StmtResult:
// This is an uninteresting result, give it a boring number and be
// done with it.
valueIDs[value] = nextValueID++;
return;
case SSAValueKind::MLFuncArgument:
specialName << "arg" << nextArgumentID++;
break;
case SSAValueKind::ForStmt:
specialName << 'i' << nextLoopID++;
break;
@ -1583,10 +1591,6 @@ void SSAValue::print(raw_ostream &os) const {
return;
case SSAValueKind::InstResult:
return getDefiningInst()->print(os);
case SSAValueKind::MLFuncArgument:
// TODO: Improve this.
os << "<function argument>\n";
return;
case SSAValueKind::StmtResult:
return getDefiningStmt()->print(os);
case SSAValueKind::ForStmt:

View File

@ -55,7 +55,7 @@ void Function::destroy() {
delete cast<ExtFunction>(this);
break;
case Kind::MLFunc:
cast<MLFunction>(this)->destroy();
delete cast<MLFunction>(this);
break;
case Kind::CFGFunc:
delete cast<CFGFunction>(this);
@ -182,29 +182,16 @@ CFGFunction::~CFGFunction() {
// MLFunction implementation.
//===----------------------------------------------------------------------===//
/// Create a new MLFunction with the specific fields.
MLFunction *MLFunction::create(Location location, StringRef name,
FunctionType type,
ArrayRef<NamedAttribute> attrs) {
const auto &argTypes = type.getInputs();
auto byteSize = totalSizeToAlloc<MLFuncArgument>(argTypes.size());
void *rawMem = malloc(byteSize);
// Initialize the MLFunction part of the function object.
auto function = ::new (rawMem) MLFunction(location, name, type, attrs);
// Initialize the arguments.
auto arguments = function->getArgumentsInternal();
for (unsigned i = 0, e = argTypes.size(); i != e; ++i)
new (&arguments[i]) MLFuncArgument(argTypes[i], function);
return function;
}
MLFunction::MLFunction(Location location, StringRef name, FunctionType type,
ArrayRef<NamedAttribute> attrs)
: Function(Kind::MLFunc, location, name, type, attrs), body(this) {
// The body of an MLFunction always has one block.
body.push_back(new StmtBlock());
auto *entry = new StmtBlock();
body.push_back(entry);
// Initialize the arguments.
entry->addArguments(type.getInputs());
}
MLFunction::~MLFunction() {
@ -212,15 +199,6 @@ MLFunction::~MLFunction() {
// since child statements need to be destroyed before function arguments
// are destroyed.
getBody()->clear();
// Explicitly run the destructors for the function arguments.
for (auto &arg : getArgumentsInternal())
arg.~MLFuncArgument();
}
void MLFunction::destroy() {
this->~MLFunction();
free(this);
}
const OperationStmt *MLFunction::getReturnStmt() const {

View File

@ -54,8 +54,6 @@ Function *SSAValue::getFunction() {
return cast<BBArgument>(this)->getFunction();
case SSAValueKind::InstResult:
return getDefiningInst()->getFunction();
case SSAValueKind::MLFuncArgument:
return cast<MLFuncArgument>(this)->getFunction();
case SSAValueKind::BlockArgument:
return cast<BlockArgument>(this)->getFunction();
case SSAValueKind::StmtResult:

View File

@ -126,7 +126,7 @@ bool MLValue::isValidSymbol() const {
}
// This value is either a function argument or an induction variable.
// Function argument is ok, induction variable is not.
return isa<MLFuncArgument>(this);
return isa<BlockArgument>(this);
}
void Statement::setOperand(unsigned idx, MLValue *value) {

View File

@ -3452,7 +3452,7 @@ ParseResult ModuleParser::parseMLFunc() {
// Okay, the ML function signature was parsed correctly, create the
// function.
auto *function =
MLFunction::create(getEncodedSourceLocation(loc), name, type, attrs);
new MLFunction(getEncodedSourceLocation(loc), name, type, attrs);
getModule()->getFunctions().push_back(function);
// Verify no name collision / redefinition.