Clean up the op builder APIs, and simplify the implementation of ops by making

OperationState contain a context and have the generic builder mechanics handle
the job of initializing the OperationState and setting the op name.  NFC.

PiperOrigin-RevId: 209869948
This commit is contained in:
Chris Lattner 2018-08-22 19:25:49 -07:00 committed by jpienaar
parent 84259c7def
commit d42ecea381
5 changed files with 79 additions and 76 deletions

View File

@ -175,7 +175,9 @@ public:
/// Create operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Args... args) {
auto *inst = createOperation(OpTy::build(this, args...));
OperationState state(getContext(), OpTy::getOperationName());
OpTy::build(this, &state, args...);
auto *inst = createOperation(state);
auto result = inst->template getAs<OpTy>();
assert(result && "Builder didn't return the right type");
return result;
@ -279,7 +281,9 @@ public:
/// Create operation of specific op type at the current insertion point.
template <typename OpTy, typename... Args>
OpPointer<OpTy> create(Args... args) {
auto stmt = createOperation(OpTy::build(this, args...));
OperationState state(getContext(), OpTy::getOperationName());
OpTy::build(this, &state, args...);
auto *stmt = createOperation(state);
auto result = stmt->template getAs<OpTy>();
assert(result && "Builder didn't return the right type");
return result;

View File

@ -45,6 +45,7 @@ typedef std::pair<Identifier, Attribute*> NamedAttribute;
/// be used as a temporary object on the stack. It is generally unwise to put
/// this in a collection.
struct OperationState {
MLIRContext *const context;
Identifier name;
SmallVector<SSAValue *, 4> operands;
/// Types of the results of this operation.
@ -52,14 +53,31 @@ struct OperationState {
SmallVector<NamedAttribute, 4> attributes;
public:
OperationState(Identifier name) : name(name) {}
OperationState(MLIRContext *context, StringRef name)
: context(context), name(Identifier::get(name, context)) {}
OperationState(Identifier name, ArrayRef<SSAValue *> operands,
ArrayRef<Type *> types,
OperationState(MLIRContext *context, Identifier name)
: context(context), name(name) {}
OperationState(MLIRContext *context, StringRef name,
ArrayRef<SSAValue *> operands, ArrayRef<Type *> types,
ArrayRef<NamedAttribute> attributes = {})
: name(name), operands(operands.begin(), operands.end()),
: context(context), name(Identifier::get(name, context)),
operands(operands.begin(), operands.end()),
types(types.begin(), types.end()),
attributes(attributes.begin(), attributes.end()) {}
void addOperands(ArrayRef<SSAValue *> newOperands) {
operands.append(newOperands.begin(), newOperands.end());
}
void addTypes(ArrayRef<Type *> newTypes) {
types.append(newTypes.begin(), newTypes.end());
}
void addAttribute(StringRef name, Attribute *attr) {
attributes.push_back({Identifier::get(name, context), attr});
}
};
/// Operations represent all of the arithmetic and other basic computation in

View File

@ -43,13 +43,8 @@ class AddFOp
public:
static StringRef getOperationName() { return "addf"; }
template <class Builder, class Value>
static OpPointer<AddFOp> build(Builder *builder, Value *lhs, Value *rhs) {
// The resultant type of a addf is the same as both the lhs and rhs.
return OpPointer<AddFOp>(AddFOp(builder->createOperation(
builder->getIdentifier("addf"), {lhs, rhs}, {lhs->getType()}, {})));
}
static void build(Builder *builder, OperationState *result, SSAValue *lhs,
SSAValue *rhs);
const char *verify() const;
static bool parse(OpAsmParser *parser, OperationState *result);
void print(OpAsmPrinter *p) const;
@ -77,8 +72,8 @@ class AffineApplyOp : public OpBase<AffineApplyOp, OpTrait::VariadicOperands,
OpTrait::VariadicResults> {
public:
/// Builds an affine apply op with the specified map and operands.
static OperationState build(Builder *builder, AffineMap *map,
ArrayRef<SSAValue *> operands);
static void build(Builder *builder, OperationState *result, AffineMap *map,
ArrayRef<SSAValue *> operands);
// Returns the affine map to be applied by this operation.
AffineMap *getAffineMap() const {
@ -145,8 +140,8 @@ class CallOp : public OpBase<CallOp, OpTrait::VariadicOperands,
public:
static StringRef getOperationName() { return "call"; }
static OperationState build(Builder *builder, Function *callee,
ArrayRef<SSAValue *> operands);
static void build(Builder *builder, OperationState *result, Function *callee,
ArrayRef<SSAValue *> operands);
Function *getCallee() const {
return getAttrOfType<FunctionAttr>("callee")->getValue();
@ -175,8 +170,8 @@ class CallIndirectOp : public OpBase<CallIndirectOp, OpTrait::VariadicOperands,
public:
static StringRef getOperationName() { return "call_indirect"; }
static OperationState build(Builder *builder, SSAValue *callee,
ArrayRef<SSAValue *> operands);
static void build(Builder *builder, OperationState *result, SSAValue *callee,
ArrayRef<SSAValue *> operands);
const SSAValue *getCallee() const { return getOperand(0); }
SSAValue *getCallee() { return getOperand(0); }
@ -222,7 +217,8 @@ protected:
class ConstantFloatOp : public ConstantOp {
public:
/// Builds a constant float op producing a float of the specified type.
static OperationState build(Builder *builder, double value, FloatType *type);
static void build(Builder *builder, OperationState *result, double value,
FloatType *type);
double getValue() const {
return getAttrOfType<FloatAttr>("value")->getValue();
@ -243,7 +239,8 @@ private:
class ConstantIntOp : public ConstantOp {
public:
/// Build a constant int op producing an integer of the specified width.
static OperationState build(Builder *builder, int64_t value, unsigned width);
static void build(Builder *builder, OperationState *result, int64_t value,
unsigned width);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value")->getValue();
@ -264,7 +261,7 @@ private:
class ConstantAffineIntOp : public ConstantOp {
public:
/// Build a constant int op producing an affineint.
static OperationState build(Builder *builder, int64_t value);
static void build(Builder *builder, OperationState *result, int64_t value);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value")->getValue();

View File

@ -66,6 +66,13 @@ parseDimAndSymbolList(OpAsmParser *parser,
// AddFOp
//===----------------------------------------------------------------------===//
void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs,
SSAValue *rhs) {
assert(lhs->getType() == rhs->getType());
result->addOperands({lhs, rhs});
result->types.push_back(lhs->getType());
}
bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
SmallVector<OpAsmParser::OperandType, 2> ops;
Type *type;
@ -201,15 +208,11 @@ const char *AllocOp::verify() const {
// CallOp
//===----------------------------------------------------------------------===//
OperationState CallOp::build(Builder *builder, Function *callee,
ArrayRef<SSAValue *> operands) {
OperationState result(builder->getIdentifier("call"));
result.operands.append(operands.begin(), operands.end());
result.attributes.push_back(
{builder->getIdentifier("callee"), builder->getFunctionAttr(callee)});
result.types.append(callee->getType()->getResults().begin(),
callee->getType()->getResults().end());
return result;
void CallOp::build(Builder *builder, OperationState *result, Function *callee,
ArrayRef<SSAValue *> operands) {
result->addOperands(operands);
result->addAttribute("callee", builder->getFunctionAttr(callee));
result->addTypes(callee->getType()->getResults());
}
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
@ -229,10 +232,7 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
result->operands))
return true;
auto &builder = parser->getBuilder();
result->attributes.push_back(
{builder.getIdentifier("callee"), builder.getFunctionAttr(callee)});
result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
return false;
}
@ -277,15 +277,12 @@ const char *CallOp::verify() const {
// CallIndirectOp
//===----------------------------------------------------------------------===//
OperationState CallIndirectOp::build(Builder *builder, SSAValue *callee,
ArrayRef<SSAValue *> operands) {
void CallIndirectOp::build(Builder *builder, OperationState *result,
SSAValue *callee, ArrayRef<SSAValue *> operands) {
auto *fnType = cast<FunctionType>(callee->getType());
OperationState result(builder->getIdentifier("call_indirect"));
result.operands.push_back(callee);
result.operands.append(operands.begin(), operands.end());
result.types.append(fnType->getResults().begin(), fnType->getResults().end());
return result;
result->operands.push_back(callee);
result->addOperands(operands);
result->addTypes(fnType->getResults());
}
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
@ -406,13 +403,10 @@ const char *ConstantOp::verify() const {
return "requires a result type that aligns with the 'value' attribute";
}
OperationState ConstantFloatOp::build(Builder *builder, double value,
FloatType *type) {
OperationState result(builder->getIdentifier("constant"));
result.attributes.push_back(
{builder->getIdentifier("value"), builder->getFloatAttr(value)});
result.types.push_back(type);
return result;
void ConstantFloatOp::build(Builder *builder, OperationState *result,
double value, FloatType *type) {
result->addAttribute("value", builder->getFloatAttr(value));
result->types.push_back(type);
}
bool ConstantFloatOp::isClassFor(const Operation *op) {
@ -426,13 +420,10 @@ bool ConstantIntOp::isClassFor(const Operation *op) {
isa<IntegerType>(op->getResult(0)->getType());
}
OperationState ConstantIntOp::build(Builder *builder, int64_t value,
unsigned width) {
OperationState result(builder->getIdentifier("constant"));
result.attributes.push_back(
{builder->getIdentifier("value"), builder->getIntegerAttr(value)});
result.types.push_back(builder->getIntegerType(width));
return result;
void ConstantIntOp::build(Builder *builder, OperationState *result,
int64_t value, unsigned width) {
result->addAttribute("value", builder->getIntegerAttr(value));
result->types.push_back(builder->getIntegerType(width));
}
/// ConstantAffineIntOp only matches values whose result type is AffineInt.
@ -441,28 +432,21 @@ bool ConstantAffineIntOp::isClassFor(const Operation *op) {
op->getResult(0)->getType()->isAffineInt();
}
OperationState ConstantAffineIntOp::build(Builder *builder, int64_t value) {
OperationState result(builder->getIdentifier("constant"));
result.attributes.push_back(
{builder->getIdentifier("value"), builder->getIntegerAttr(value)});
result.types.push_back(builder->getAffineIntType());
return result;
void ConstantAffineIntOp::build(Builder *builder, OperationState *result,
int64_t value) {
result->addAttribute("value", builder->getIntegerAttr(value));
result->types.push_back(builder->getAffineIntType());
}
//===----------------------------------------------------------------------===//
// AffineApplyOp
//===----------------------------------------------------------------------===//
OperationState AffineApplyOp::build(Builder *builder, AffineMap *map,
ArrayRef<SSAValue *> operands) {
SmallVector<Type *, 4> resultTypes(map->getNumResults(),
builder->getAffineIntType());
OperationState result(
builder->getIdentifier("affine_apply"), operands, resultTypes,
{{builder->getIdentifier("map"), builder->getAffineMapAttr(map)}});
return result;
void AffineApplyOp::build(Builder *builder, OperationState *result,
AffineMap *map, ArrayRef<SSAValue *> operands) {
result->addOperands(operands);
result->types.append(map->getNumResults(), builder->getAffineIntType());
result->addAttribute("map", builder->getAffineMapAttr(map));
}
//===----------------------------------------------------------------------===//

View File

@ -1648,7 +1648,7 @@ Operation *FunctionParser::parseVerboseOperation(
consumeToken(Token::string);
OperationState result(builder.getIdentifier(name));
OperationState result(builder.getContext(), name);
// Parse the operand list.
SmallVector<SSAUseInfo, 8> operandInfos;
@ -1675,7 +1675,7 @@ Operation *FunctionParser::parseVerboseOperation(
if (!fnType)
return (emitError(typeLoc, "expected function type"), nullptr);
result.types.append(fnType->getResults().begin(), fnType->getResults().end());
result.addTypes(fnType->getResults());
// Check that we have the right number of types for the operands.
auto operandTypes = fnType->getInputs();
@ -1916,7 +1916,7 @@ Operation *FunctionParser::parseCustomOperation(
opNameStr.c_str());
// Have the op implementation take a crack and parsing this.
OperationState opState(builder.getIdentifier(opName));
OperationState opState(builder.getContext(), opName);
if (opDefinition->parseAssembly(&opAsmParser, &opState))
return nullptr;