forked from OSchip/llvm-project
Clean up the op builder APIs, and simplify the implementation of ops by making
OperationState contain a context and have the generic builder mechanics handle the job of initializing the OperationState and setting the op name. NFC. PiperOrigin-RevId: 209869948
This commit is contained in:
parent
84259c7def
commit
d42ecea381
|
@ -175,7 +175,9 @@ public:
|
|||
/// Create operation of specific op type at the current insertion point.
|
||||
template <typename OpTy, typename... Args>
|
||||
OpPointer<OpTy> create(Args... args) {
|
||||
auto *inst = createOperation(OpTy::build(this, args...));
|
||||
OperationState state(getContext(), OpTy::getOperationName());
|
||||
OpTy::build(this, &state, args...);
|
||||
auto *inst = createOperation(state);
|
||||
auto result = inst->template getAs<OpTy>();
|
||||
assert(result && "Builder didn't return the right type");
|
||||
return result;
|
||||
|
@ -279,7 +281,9 @@ public:
|
|||
/// Create operation of specific op type at the current insertion point.
|
||||
template <typename OpTy, typename... Args>
|
||||
OpPointer<OpTy> create(Args... args) {
|
||||
auto stmt = createOperation(OpTy::build(this, args...));
|
||||
OperationState state(getContext(), OpTy::getOperationName());
|
||||
OpTy::build(this, &state, args...);
|
||||
auto *stmt = createOperation(state);
|
||||
auto result = stmt->template getAs<OpTy>();
|
||||
assert(result && "Builder didn't return the right type");
|
||||
return result;
|
||||
|
|
|
@ -45,6 +45,7 @@ typedef std::pair<Identifier, Attribute*> NamedAttribute;
|
|||
/// be used as a temporary object on the stack. It is generally unwise to put
|
||||
/// this in a collection.
|
||||
struct OperationState {
|
||||
MLIRContext *const context;
|
||||
Identifier name;
|
||||
SmallVector<SSAValue *, 4> operands;
|
||||
/// Types of the results of this operation.
|
||||
|
@ -52,14 +53,31 @@ struct OperationState {
|
|||
SmallVector<NamedAttribute, 4> attributes;
|
||||
|
||||
public:
|
||||
OperationState(Identifier name) : name(name) {}
|
||||
OperationState(MLIRContext *context, StringRef name)
|
||||
: context(context), name(Identifier::get(name, context)) {}
|
||||
|
||||
OperationState(Identifier name, ArrayRef<SSAValue *> operands,
|
||||
ArrayRef<Type *> types,
|
||||
OperationState(MLIRContext *context, Identifier name)
|
||||
: context(context), name(name) {}
|
||||
|
||||
OperationState(MLIRContext *context, StringRef name,
|
||||
ArrayRef<SSAValue *> operands, ArrayRef<Type *> types,
|
||||
ArrayRef<NamedAttribute> attributes = {})
|
||||
: name(name), operands(operands.begin(), operands.end()),
|
||||
: context(context), name(Identifier::get(name, context)),
|
||||
operands(operands.begin(), operands.end()),
|
||||
types(types.begin(), types.end()),
|
||||
attributes(attributes.begin(), attributes.end()) {}
|
||||
|
||||
void addOperands(ArrayRef<SSAValue *> newOperands) {
|
||||
operands.append(newOperands.begin(), newOperands.end());
|
||||
}
|
||||
|
||||
void addTypes(ArrayRef<Type *> newTypes) {
|
||||
types.append(newTypes.begin(), newTypes.end());
|
||||
}
|
||||
|
||||
void addAttribute(StringRef name, Attribute *attr) {
|
||||
attributes.push_back({Identifier::get(name, context), attr});
|
||||
}
|
||||
};
|
||||
|
||||
/// Operations represent all of the arithmetic and other basic computation in
|
||||
|
|
|
@ -43,13 +43,8 @@ class AddFOp
|
|||
public:
|
||||
static StringRef getOperationName() { return "addf"; }
|
||||
|
||||
template <class Builder, class Value>
|
||||
static OpPointer<AddFOp> build(Builder *builder, Value *lhs, Value *rhs) {
|
||||
// The resultant type of a addf is the same as both the lhs and rhs.
|
||||
return OpPointer<AddFOp>(AddFOp(builder->createOperation(
|
||||
builder->getIdentifier("addf"), {lhs, rhs}, {lhs->getType()}, {})));
|
||||
}
|
||||
|
||||
static void build(Builder *builder, OperationState *result, SSAValue *lhs,
|
||||
SSAValue *rhs);
|
||||
const char *verify() const;
|
||||
static bool parse(OpAsmParser *parser, OperationState *result);
|
||||
void print(OpAsmPrinter *p) const;
|
||||
|
@ -77,8 +72,8 @@ class AffineApplyOp : public OpBase<AffineApplyOp, OpTrait::VariadicOperands,
|
|||
OpTrait::VariadicResults> {
|
||||
public:
|
||||
/// Builds an affine apply op with the specified map and operands.
|
||||
static OperationState build(Builder *builder, AffineMap *map,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
static void build(Builder *builder, OperationState *result, AffineMap *map,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
|
||||
// Returns the affine map to be applied by this operation.
|
||||
AffineMap *getAffineMap() const {
|
||||
|
@ -145,8 +140,8 @@ class CallOp : public OpBase<CallOp, OpTrait::VariadicOperands,
|
|||
public:
|
||||
static StringRef getOperationName() { return "call"; }
|
||||
|
||||
static OperationState build(Builder *builder, Function *callee,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
static void build(Builder *builder, OperationState *result, Function *callee,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
|
||||
Function *getCallee() const {
|
||||
return getAttrOfType<FunctionAttr>("callee")->getValue();
|
||||
|
@ -175,8 +170,8 @@ class CallIndirectOp : public OpBase<CallIndirectOp, OpTrait::VariadicOperands,
|
|||
public:
|
||||
static StringRef getOperationName() { return "call_indirect"; }
|
||||
|
||||
static OperationState build(Builder *builder, SSAValue *callee,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
static void build(Builder *builder, OperationState *result, SSAValue *callee,
|
||||
ArrayRef<SSAValue *> operands);
|
||||
|
||||
const SSAValue *getCallee() const { return getOperand(0); }
|
||||
SSAValue *getCallee() { return getOperand(0); }
|
||||
|
@ -222,7 +217,8 @@ protected:
|
|||
class ConstantFloatOp : public ConstantOp {
|
||||
public:
|
||||
/// Builds a constant float op producing a float of the specified type.
|
||||
static OperationState build(Builder *builder, double value, FloatType *type);
|
||||
static void build(Builder *builder, OperationState *result, double value,
|
||||
FloatType *type);
|
||||
|
||||
double getValue() const {
|
||||
return getAttrOfType<FloatAttr>("value")->getValue();
|
||||
|
@ -243,7 +239,8 @@ private:
|
|||
class ConstantIntOp : public ConstantOp {
|
||||
public:
|
||||
/// Build a constant int op producing an integer of the specified width.
|
||||
static OperationState build(Builder *builder, int64_t value, unsigned width);
|
||||
static void build(Builder *builder, OperationState *result, int64_t value,
|
||||
unsigned width);
|
||||
|
||||
int64_t getValue() const {
|
||||
return getAttrOfType<IntegerAttr>("value")->getValue();
|
||||
|
@ -264,7 +261,7 @@ private:
|
|||
class ConstantAffineIntOp : public ConstantOp {
|
||||
public:
|
||||
/// Build a constant int op producing an affineint.
|
||||
static OperationState build(Builder *builder, int64_t value);
|
||||
static void build(Builder *builder, OperationState *result, int64_t value);
|
||||
|
||||
int64_t getValue() const {
|
||||
return getAttrOfType<IntegerAttr>("value")->getValue();
|
||||
|
|
|
@ -66,6 +66,13 @@ parseDimAndSymbolList(OpAsmParser *parser,
|
|||
// AddFOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
void AddFOp::build(Builder *builder, OperationState *result, SSAValue *lhs,
|
||||
SSAValue *rhs) {
|
||||
assert(lhs->getType() == rhs->getType());
|
||||
result->addOperands({lhs, rhs});
|
||||
result->types.push_back(lhs->getType());
|
||||
}
|
||||
|
||||
bool AddFOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
SmallVector<OpAsmParser::OperandType, 2> ops;
|
||||
Type *type;
|
||||
|
@ -201,15 +208,11 @@ const char *AllocOp::verify() const {
|
|||
// CallOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OperationState CallOp::build(Builder *builder, Function *callee,
|
||||
ArrayRef<SSAValue *> operands) {
|
||||
OperationState result(builder->getIdentifier("call"));
|
||||
result.operands.append(operands.begin(), operands.end());
|
||||
result.attributes.push_back(
|
||||
{builder->getIdentifier("callee"), builder->getFunctionAttr(callee)});
|
||||
result.types.append(callee->getType()->getResults().begin(),
|
||||
callee->getType()->getResults().end());
|
||||
return result;
|
||||
void CallOp::build(Builder *builder, OperationState *result, Function *callee,
|
||||
ArrayRef<SSAValue *> operands) {
|
||||
result->addOperands(operands);
|
||||
result->addAttribute("callee", builder->getFunctionAttr(callee));
|
||||
result->addTypes(callee->getType()->getResults());
|
||||
}
|
||||
|
||||
bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
|
@ -229,10 +232,7 @@ bool CallOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
result->operands))
|
||||
return true;
|
||||
|
||||
auto &builder = parser->getBuilder();
|
||||
result->attributes.push_back(
|
||||
{builder.getIdentifier("callee"), builder.getFunctionAttr(callee)});
|
||||
|
||||
result->addAttribute("callee", parser->getBuilder().getFunctionAttr(callee));
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -277,15 +277,12 @@ const char *CallOp::verify() const {
|
|||
// CallIndirectOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OperationState CallIndirectOp::build(Builder *builder, SSAValue *callee,
|
||||
ArrayRef<SSAValue *> operands) {
|
||||
void CallIndirectOp::build(Builder *builder, OperationState *result,
|
||||
SSAValue *callee, ArrayRef<SSAValue *> operands) {
|
||||
auto *fnType = cast<FunctionType>(callee->getType());
|
||||
|
||||
OperationState result(builder->getIdentifier("call_indirect"));
|
||||
result.operands.push_back(callee);
|
||||
result.operands.append(operands.begin(), operands.end());
|
||||
result.types.append(fnType->getResults().begin(), fnType->getResults().end());
|
||||
return result;
|
||||
result->operands.push_back(callee);
|
||||
result->addOperands(operands);
|
||||
result->addTypes(fnType->getResults());
|
||||
}
|
||||
|
||||
bool CallIndirectOp::parse(OpAsmParser *parser, OperationState *result) {
|
||||
|
@ -406,13 +403,10 @@ const char *ConstantOp::verify() const {
|
|||
return "requires a result type that aligns with the 'value' attribute";
|
||||
}
|
||||
|
||||
OperationState ConstantFloatOp::build(Builder *builder, double value,
|
||||
FloatType *type) {
|
||||
OperationState result(builder->getIdentifier("constant"));
|
||||
result.attributes.push_back(
|
||||
{builder->getIdentifier("value"), builder->getFloatAttr(value)});
|
||||
result.types.push_back(type);
|
||||
return result;
|
||||
void ConstantFloatOp::build(Builder *builder, OperationState *result,
|
||||
double value, FloatType *type) {
|
||||
result->addAttribute("value", builder->getFloatAttr(value));
|
||||
result->types.push_back(type);
|
||||
}
|
||||
|
||||
bool ConstantFloatOp::isClassFor(const Operation *op) {
|
||||
|
@ -426,13 +420,10 @@ bool ConstantIntOp::isClassFor(const Operation *op) {
|
|||
isa<IntegerType>(op->getResult(0)->getType());
|
||||
}
|
||||
|
||||
OperationState ConstantIntOp::build(Builder *builder, int64_t value,
|
||||
unsigned width) {
|
||||
OperationState result(builder->getIdentifier("constant"));
|
||||
result.attributes.push_back(
|
||||
{builder->getIdentifier("value"), builder->getIntegerAttr(value)});
|
||||
result.types.push_back(builder->getIntegerType(width));
|
||||
return result;
|
||||
void ConstantIntOp::build(Builder *builder, OperationState *result,
|
||||
int64_t value, unsigned width) {
|
||||
result->addAttribute("value", builder->getIntegerAttr(value));
|
||||
result->types.push_back(builder->getIntegerType(width));
|
||||
}
|
||||
|
||||
/// ConstantAffineIntOp only matches values whose result type is AffineInt.
|
||||
|
@ -441,28 +432,21 @@ bool ConstantAffineIntOp::isClassFor(const Operation *op) {
|
|||
op->getResult(0)->getType()->isAffineInt();
|
||||
}
|
||||
|
||||
OperationState ConstantAffineIntOp::build(Builder *builder, int64_t value) {
|
||||
OperationState result(builder->getIdentifier("constant"));
|
||||
result.attributes.push_back(
|
||||
{builder->getIdentifier("value"), builder->getIntegerAttr(value)});
|
||||
result.types.push_back(builder->getAffineIntType());
|
||||
return result;
|
||||
void ConstantAffineIntOp::build(Builder *builder, OperationState *result,
|
||||
int64_t value) {
|
||||
result->addAttribute("value", builder->getIntegerAttr(value));
|
||||
result->types.push_back(builder->getAffineIntType());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AffineApplyOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OperationState AffineApplyOp::build(Builder *builder, AffineMap *map,
|
||||
ArrayRef<SSAValue *> operands) {
|
||||
SmallVector<Type *, 4> resultTypes(map->getNumResults(),
|
||||
builder->getAffineIntType());
|
||||
|
||||
OperationState result(
|
||||
builder->getIdentifier("affine_apply"), operands, resultTypes,
|
||||
{{builder->getIdentifier("map"), builder->getAffineMapAttr(map)}});
|
||||
|
||||
return result;
|
||||
void AffineApplyOp::build(Builder *builder, OperationState *result,
|
||||
AffineMap *map, ArrayRef<SSAValue *> operands) {
|
||||
result->addOperands(operands);
|
||||
result->types.append(map->getNumResults(), builder->getAffineIntType());
|
||||
result->addAttribute("map", builder->getAffineMapAttr(map));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -1648,7 +1648,7 @@ Operation *FunctionParser::parseVerboseOperation(
|
|||
|
||||
consumeToken(Token::string);
|
||||
|
||||
OperationState result(builder.getIdentifier(name));
|
||||
OperationState result(builder.getContext(), name);
|
||||
|
||||
// Parse the operand list.
|
||||
SmallVector<SSAUseInfo, 8> operandInfos;
|
||||
|
@ -1675,7 +1675,7 @@ Operation *FunctionParser::parseVerboseOperation(
|
|||
if (!fnType)
|
||||
return (emitError(typeLoc, "expected function type"), nullptr);
|
||||
|
||||
result.types.append(fnType->getResults().begin(), fnType->getResults().end());
|
||||
result.addTypes(fnType->getResults());
|
||||
|
||||
// Check that we have the right number of types for the operands.
|
||||
auto operandTypes = fnType->getInputs();
|
||||
|
@ -1916,7 +1916,7 @@ Operation *FunctionParser::parseCustomOperation(
|
|||
opNameStr.c_str());
|
||||
|
||||
// Have the op implementation take a crack and parsing this.
|
||||
OperationState opState(builder.getIdentifier(opName));
|
||||
OperationState opState(builder.getContext(), opName);
|
||||
if (opDefinition->parseAssembly(&opAsmParser, &opState))
|
||||
return nullptr;
|
||||
|
||||
|
|
Loading…
Reference in New Issue