[TableGen] Support multiple variadic operands/results

Certain ops can have multiple variadic operands/results, e.g., `tf.DynamicStitch`.
    Even if an op has only one variadic operand/result, it is not necessarily the
    very last one, e.g., `tf.RaggedGather`. This CL enhances TableGen subsystem to be
    able to represent such cases.

    In order to deduce the operand/result value range for each variadic operand,
    currently we only support variadic operands/results all of the same size.
    So two new traits, `SameVariadicOperandSize` and `SameVariadicResultSize` are
    introduced.

--

PiperOrigin-RevId: 245310628
This commit is contained in:
Lei Zhang 2019-04-25 14:45:37 -07:00 committed by Mehdi Amini
parent 22ad45a7aa
commit 6749c21d6e
11 changed files with 409 additions and 182 deletions

View File

@ -759,6 +759,17 @@ def Terminator : NativeOpTrait<"IsTerminator">;
def FirstAttrDerivedResultType :
GenInternalOpTrait<"FirstAttrDerivedResultType">;
// All variadic operands of the op have the same number of values.
// A variadic operand contains an array of values whose array size is only
// known at runtime. This trait requires all variadic operands of an op
// to have the same array size.
def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">;
// All variadic results of the op have the same number of values.
// A variadic result contains an array of values whose array size is only
// known at runtime. This trait requires all variadic results of an op
// to have the same array size.
def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
//===----------------------------------------------------------------------===//
// Op definitions
//===----------------------------------------------------------------------===//

View File

@ -48,10 +48,12 @@ struct NamedAttribute {
Attribute attr;
};
// A struct wrapping an op operand/result and its name together
// A struct wrapping an op operand/result's constraint and its name together
struct NamedTypeConstraint {
// Returns true if this operand has constraint that need to be satisfied.
// Returns true if this operand/result has constraint to be satisfied.
bool hasPredicate() const;
// Returns true if this operand/result is variadic.
bool isVariadic() const;
llvm::StringRef name;
TypeConstraint constraint;

View File

@ -82,8 +82,8 @@ public:
// Returns the `index`-th result's name.
StringRef getResultName(int index) const;
// Returns true if this operation has a variadic result.
bool hasVariadicResult() const;
// Returns the number of variadic results in this operation.
unsigned getNumVariadicResults() const;
// Op attribute interators.
using attribute_iterator = const NamedAttribute *;
@ -112,8 +112,8 @@ public:
return operands[index];
}
// Returns true if this operation has a variadic operand.
bool hasVariadicOperand() const;
// Returns the number of variadic operands in this operation.
unsigned getNumVariadicOperands() const;
// Returns the total number of arguments.
int getNumArgs() const { return arguments.size(); }

View File

@ -23,3 +23,7 @@ using namespace mlir;
bool tblgen::NamedTypeConstraint::hasPredicate() const {
return !constraint.getPredicate().isNull();
}
bool tblgen::NamedTypeConstraint::isVariadic() const {
return constraint.isVariadic();
}

View File

@ -82,8 +82,10 @@ StringRef tblgen::Operator::getResultName(int index) const {
return results->getArgNameStr(index);
}
bool tblgen::Operator::hasVariadicResult() const {
return !results.empty() && results.back().constraint.isVariadic();
unsigned tblgen::Operator::getNumVariadicResults() const {
return std::count_if(
results.begin(), results.end(),
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
}
int tblgen::Operator::getNumNativeAttributes() const {
@ -98,8 +100,10 @@ const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const {
return attributes[index];
}
bool tblgen::Operator::hasVariadicOperand() const {
return !operands.empty() && operands.back().constraint.isVariadic();
unsigned tblgen::Operator::getNumVariadicOperands() const {
return std::count_if(
operands.begin(), operands.end(),
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
}
StringRef tblgen::Operator::getArgName(int index) const {
@ -222,13 +226,6 @@ void tblgen::Operator::populateOpStructure() {
}
}
// Verify that only the last operand can be variadic.
for (int i = 0, e = operands.size() - 1; i < e; ++i) {
if (operands[i].constraint.isVariadic())
PrintFatalError(def.getLoc(),
"only the last operand allowed to be variadic");
}
auto *resultsDag = def.getValueAsDag("results");
auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
if (!outsOp || outsOp->getDef()->getName() != "outs") {
@ -246,13 +243,6 @@ void tblgen::Operator::populateOpStructure() {
results.push_back({name, TypeConstraint(resultDef)});
}
// Verify that only the last result can be variadic.
for (int i = 0, e = results.size() - 1; i < e; ++i) {
if (results[i].constraint.isVariadic())
PrintFatalError(def.getLoc(),
"only the last result allowed to be variadic");
}
auto traitListInit = def.getValueAsListInit("traits");
if (!traitListInit)
return;

View File

@ -1,28 +0,0 @@
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
include "mlir/IR/OpBase.td"
def NS_OpA : Op<"op_same_value_type", [SameValueType]> {
let arguments = (ins Tensor:$input);
let results = (outs Tensor:$result);
}
// Test that with SameValueType trait we can generate a builder without
// requiring result type
// ---
// CHECK-LABEL: OpA::build(Builder *, OperationState *tblgen_state, Value *input)
// CHECK: tblgen_state->addTypes({input->getType()});
def NS_OpB : Op<"op_same_value_type_variadic_input", [SameValueType]> {
let arguments = (ins Variadic<Tensor>:$input);
let results = (outs Tensor:$result);
}
// Test that if the only operand is variadic, we acess the first value in the
// pack to set result type
// ---
// CHECK-LABEL: OpB::build(Builder *, OperationState *tblgen_state, ArrayRef<Value *> input)
// CHECK: tblgen_state->addTypes({input.front()->getType()});

View File

@ -2,7 +2,7 @@
include "mlir/IR/OpBase.td"
def OpA : Op<"one_operand_op", []> {
def OpA : Op<"one_normal_operand_op", []> {
let arguments = (ins I32:$input);
}
@ -10,7 +10,7 @@ def OpA : Op<"one_operand_op", []> {
// CHECK: void OpA::build
// CHECK-SAME: Value *input
// CHECK: tblgen_state->addOperands({input});
// CHECK: tblgen_state->operands.push_back(input);
// CHECK: void OpA::build
// CHECK-SAME: ArrayRef<Value *> operands
@ -21,11 +21,72 @@ def OpA : Op<"one_operand_op", []> {
// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32))))
// CHECK-NEXT: return emitOpError("operand #0 must be 32-bit integer");
def OpB : Op<"variadic_operand_op", []> {
def OpB : Op<"one_variadic_operand_op", []> {
let arguments = (ins Variadic<I32>:$input);
}
// CHECK-LABEL: OpB::build
// CHECK-SAME: ArrayRef<Value *> input
// CHECK-NOT: assert
// CHECK: tblgen_state->addOperands(input);
// CHECK-SAME: ArrayRef<Value *> input
// CHECK-NOT: assert
// CHECK: tblgen_state->addOperands(input);
def OpC : Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> {
let arguments = (ins Variadic<Tensor>:$input1, Variadic<Tensor>:$input2);
}
// CHECK-LABEL: Operation::operand_range OpC::input1()
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 0) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 0;
// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
// CHECK-LABEL: Operation::operand_range OpC::input2()
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 0) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 1;
// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
// CHECK-LABEL: OpC::build
// CHECK-NEXT: tblgen_state->addOperands(input1);
// CHECK-NEXT: tblgen_state->addOperands(input2);
def OpD : Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> {
let arguments = (ins Variadic<Tensor>:$input1, Tensor:$input2, Variadic<Tensor>:$input3);
}
// CHECK-LABEL: Operation::operand_range OpD::input1()
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 0;
// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
// CHECK-LABEL: Value *OpD::input2()
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 1;
// CHECK-NEXT: return this->getOperand(offset);
// CHECK-LABEL: Operation::operand_range OpD::input3()
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2;
// CHECK-NEXT: unsigned offset = 1 + variadicOperandSize * 1;
// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
// CHECK-LABEL: OpD::build
// CHECK-NEXT: tblgen_state->addOperands(input1);
// CHECK-NEXT: tblgen_state->operands.push_back(input2);
// CHECK-NEXT: tblgen_state->addOperands(input3);
def OpE : Op<"one_variadic_among_multi_normal_inputs_op", []> {
let arguments = (ins Tensor:$input1, Tensor:$input2, Variadic<Tensor>:$input3, Tensor:$input4, Tensor:$input5);
}
// CHECK-LABEL: Value *OpE::input1()
// CHECK-NEXT: return this->getOperation()->getOperand(0);
// CHECK-LABEL: Value *OpE::input2()
// CHECK-NEXT: return this->getOperation()->getOperand(1);
// CHECK-LABEL: Operation::operand_range OpE::input3()
// CHECK-NEXT: return {std::next(operand_begin(), 2), std::next(operand_begin(), 2 + this->getNumOperands() - 4)};
// CHECK-LABEL: Value *OpE::input4()
// CHECK-NEXT: return this->getOperation()->getOperand(this->getNumOperands() - 2);
// CHECK-LABEL: Value *OpE::input5()
// CHECK-NEXT: return this->getOperation()->getOperand(this->getNumOperands() - 1);

View File

@ -2,82 +2,160 @@
include "mlir/IR/OpBase.td"
def OneResultOp : Op<"one_result_op", []> {
def OpA : Op<"one_normal_result_op", []> {
let results = (outs I32:$result);
}
// CHECK-LABEL: OneResultOp definitions
// CHECK-LABEL: Value *OpA::result()
// CHECK-NEXT: return this->getOperation()->getResult(0)
// CHECK: void OneResultOp::build
// CHECK-LABEL: void OpA::build
// CHECK: ArrayRef<Type> resultTypes, ArrayRef<Value *> operands
// CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types");
// CHECK-NEXT: tblgen_state->addTypes(resultTypes);
// CHECK: LogicalResult OneResultOp::verify() {
// CHECK-LABEL: LogicalResult OpA::verify()
// CHECK: if (!((this->getOperation()->getResult(0)->getType().isInteger(32))))
// CHECK-NEXT: return emitOpError("result #0 must be 32-bit integer");
def SameTypeOp : Op<"same_type_op", [SameValueType]> {
def OpB : Op<"same_input_output_type_op", [SameValueType]> {
let arguments = (ins I32:$x);
let results = (outs I32:$y);
}
// CHECK-LABEL: SameTypeOp definitions
// CHECK: void SameTypeOp::build(Builder *, OperationState *tblgen_state, Type y, Value *x)
// CHECK: tblgen_state->addTypes({y});
// CHECK: void SameTypeOp::build(Builder *, OperationState *tblgen_state, Value *x)
// CHECK-LABEL: OpB definitions
// CHECK: void OpB::build(Builder *, OperationState *tblgen_state, Type y, Value *x)
// CHECK: tblgen_state->types.push_back(y);
// CHECK: void OpB::build(Builder *, OperationState *tblgen_state, Value *x)
// CHECK: tblgen_state->addTypes({x->getType()});
def ThreeResultOp : Op<"three_result_op", []> {
def OpC : Op<"three_normal_result_op", []> {
let results = (outs I32:$x, /*unnamed*/I32, I32:$z);
}
// CHECK-LABEL: ThreeResultOp definitions
// CHECK: void ThreeResultOp::build(Builder *, OperationState *tblgen_state, Type x, Type resultType1, Type z)
// CHECK: tblgen_state->addTypes({x, resultType1, z});
// CHECK-LABEL: OpC definitions
// CHECK: void OpC::build(Builder *, OperationState *tblgen_state, Type x, Type resultType1, Type z)
// CHECK-NEXT: tblgen_state->types.push_back(x)
// CHECK-NEXT: tblgen_state->types.push_back(resultType1)
// CHECK-NEXT: tblgen_state->types.push_back(z)
def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
def TypeAttrResultTypeOp : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
def OpD : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32);
let results = (outs Tensor:$y);
}
// CHECK-LABEL: TypeAttrResultTypeOp definitions
// CHECK: void TypeAttrResultTypeOp::build(Builder *, OperationState *tblgen_state, Value *x, TypeAttr attr, FloatAttr f32)
// CHECK-LABEL: OpD definitions
// CHECK: void OpD::build(Builder *, OperationState *tblgen_state, Value *x, TypeAttr attr, FloatAttr f32)
// CHECK: tblgen_state->addTypes({attr.getValue()});
def ValueAttrResultTypeOp : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
def OpE : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
let arguments = (ins I32:$x, F32Attr:$attr);
let results = (outs Tensor:$y);
}
// CHECK-LABEL: ValueAttrResultTypeOp definitions
// CHECK: void ValueAttrResultTypeOp::build(Builder *, OperationState *tblgen_state, Value *x, FloatAttr attr)
// CHECK-LABEL: OpE definitions
// CHECK: void OpE::build(Builder *, OperationState *tblgen_state, Value *x, FloatAttr attr)
// CHECK: tblgen_state->addTypes({attr.getType()});
def VariadicResultAloneOp : Op<"variadic_alone_op", []> {
def OpF : Op<"one_variadic_result_op", []> {
let results = (outs Variadic<I32>:$x);
}
// CHECK-LABEL: VariadicResultAloneOp definitions
// CHECK-LABEL: Operation::result_range OpF::x()
// CHECK-NEXT: return {std::next(result_begin(), 0), std::next(result_begin(), 0 + this->getNumResults() - 0)};
// CHECK-LABEL: void VariadicResultAloneOp::build
// CHECK-SAME: ArrayRef<Type> x
// CHECK-NOT: assert
// CHECK: tblgen_state->addTypes(x);
// CHECK-LABEL: void OpF::build
// CHECK-SAME: ArrayRef<Type> x
// CHECK-NOT: assert
// CHECK: tblgen_state->addTypes(x);
def OpG : Op<"one_normal_and_one_variadic_result_op", []> {
def VariadicResultOp : Op<"variadic_op", []> {
let results = (outs I32:$x, Variadic<I32>:$y);
}
// CHECK-LABEL: VariadicResultOp definitions
// CHECK-LABEL: OpG definitions
// CHECK: void VariadicResultOp::build(Builder *, OperationState *tblgen_state, Type x, ArrayRef<Type> y)
// CHECK: tblgen_state->addTypes({x});
// CHECK: tblgen_state->addTypes(y);
// CHECK: void OpG::build(Builder *, OperationState *tblgen_state, Type x, ArrayRef<Type> y)
// CHECK-NEXT: tblgen_state->types.push_back(x);
// CHECK-NEXT: tblgen_state->addTypes(y);
// CHECK: void VariadicResultOp::build
// CHECK: void OpG::build
// CHECK: ArrayRef<Type> resultTypes
// CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types");
// CHECK-NEXT: tblgen_state->addTypes(resultTypes);
def OpH : Op<"all_variadic_results_op", [SameVariadicResultSize]> {
let results = (outs Variadic<Tensor>:$output1, Variadic<Tensor>:$output2);
}
// CHECK-LABEL: Operation::result_range OpH::output1()
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 0) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 0;
// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
// CHECK-LABEL: Operation::result_range OpH::output2()
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 0) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 1;
// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
// CHECK-LABEL: OpH::build
// CHECK-NEXT: tblgen_state->addTypes(output1);
// CHECK-NEXT: tblgen_state->addTypes(output2);
def OpI : Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> {
let results = (outs Variadic<Tensor>:$output1, Tensor:$output2, Variadic<Tensor>:$output3);
}
// CHECK-LABEL: Operation::result_range OpI::output1()
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 0;
// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
// CHECK-LABEL: Value *OpI::output2()
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 1;
// CHECK-NEXT: return this->getResult(offset);
// CHECK-LABEL: Operation::result_range OpI::output3()
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2;
// CHECK-NEXT: unsigned offset = 1 + variadicResultSize * 1;
// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
// CHECK-LABEL: OpI::build
// CHECK-NEXT: tblgen_state->addTypes(output1);
// CHECK-NEXT: tblgen_state->types.push_back(output2);
// CHECK-NEXT: tblgen_state->addTypes(output3);
def OpJ : Op<"one_variadic_among_multi_normal_results_op", []> {
let results = (outs Tensor:$output1, Tensor:$output2, Variadic<Tensor>:$output3, Tensor:$output4, Tensor:$output5);
}
// CHECK-LABEL: Value *OpJ::output1()
// CHECK-NEXT: return this->getOperation()->getResult(0);
// CHECK-LABEL: Value *OpJ::output2()
// CHECK-NEXT: return this->getOperation()->getResult(1);
// CHECK-LABEL: Operation::result_range OpJ::output3()
// CHECK-NEXT: return {std::next(result_begin(), 2), std::next(result_begin(), 2 + this->getNumResults() - 4)};
// CHECK-LABEL: Value *OpJ::output4()
// CHECK-NEXT: return this->getOperation()->getResult(this->getNumResults() - 2);
// CHECK-LABEL: Value *OpJ::output5()
// CHECK-NEXT: return this->getOperation()->getResult(this->getNumResults() - 1);
// Test that if the only operand is variadic, we acess the first value in the
// pack to set result type
// ---
def OpK : Op<"only_input_is_variadic_with_same_value_type_op", [SameValueType]> {
let arguments = (ins Variadic<Tensor>:$input);
let results = (outs Tensor:$result);
}
// CHECK-LABEL: OpK::build(Builder *, OperationState *tblgen_state, ArrayRef<Value *> input)
// CHECK: tblgen_state->addTypes({input.front()->getType()});

View File

@ -75,10 +75,14 @@ static StringLoc findNextVariable(StringRef str) {
return {startPos, endPos - startPos};
}
// Check if `name` is the name of the variadic argument of `op`. The variadic
// argument can only appear at the last position in the list of arguments.
static bool isVariadicArgumentName(const tblgen::Operator &op, StringRef name) {
return op.hasVariadicOperand() && op.getArgName(op.getNumArgs() - 1) == name;
// Check if `name` is the name of the variadic operand of `op`. The variadic
// operand can only appear at the last position in the list of operands.
static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
unsigned numOperands = op.getNumOperands();
if (numOperands == 0)
return false;
const auto &operand = op.getOperand(numOperands - 1);
return operand.isVariadic() && operand.name == name;
}
// Check if `result` is a known name of a result of `op`.
@ -127,9 +131,9 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) {
// First, insert the non-matched part as is.
bs << builderStrRef.substr(0, loc.pos);
// Then, rewrite the name based on its kind.
bool isVariadicArg = isVariadicArgumentName(op, name);
bool isVariadicOperand = isVariadicOperandName(op, name);
if (isOperandName(op, name)) {
auto result = isVariadicArg
auto result = isVariadicOperand
? formatv("lookupValues(op.{0}())", name)
: formatv("valueMapping.lookup(op.{0}())", name);
bs << result;

View File

@ -251,8 +251,9 @@ OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
}
void OpMethodBody::writeTo(raw_ostream &os) const {
os << body;
if (body.empty() || body.back() != '\n')
auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
os << bodyRef;
if (bodyRef.empty() || bodyRef.back() != '\n')
os << "\n";
}
@ -455,35 +456,153 @@ void OpEmitter::genAttrGetters() {
}
void OpEmitter::genNamedOperandGetters() {
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
const unsigned numOperands = op.getNumOperands();
const unsigned numVariadicOperands = op.getNumVariadicOperands();
const unsigned numNormalOperands = numOperands - numVariadicOperands;
// Special case for ops without variadic operands: the i-th value is for the
// i-th operand defined in the op.
// Special case for ops with one variadic operand: the variadic operand can
// appear at any place, so the i-th value may not necessarily belong to the
// i-th operand definition. we need to calculate the index (range) for each
// operand.
if (numVariadicOperands <= 1) {
bool emittedVariadicOperand = false;
for (unsigned i = 0; i != numOperands; ++i) {
const auto &operand = op.getOperand(i);
if (operand.name.empty())
continue;
if (operand.isVariadic()) {
auto &m = opClass.newMethod("Operation::operand_range", operand.name);
m.body() << formatv(
" return {{std::next(operand_begin(), {0}), "
"std::next(operand_begin(), {0} + this->getNumOperands() - {1})};",
i, numNormalOperands);
emittedVariadicOperand = true;
} else {
auto &m = opClass.newMethod("Value *", operand.name);
m.body() << " return this->getOperation()->getOperand(";
if (emittedVariadicOperand)
m.body() << "this->getNumOperands() - " << numOperands - i;
else
m.body() << i;
m.body() << ");\n";
}
}
return;
}
// If we have more than one variadic operands, we need more complicated logic
// to calculate the value range for each operand.
if (!op.hasTrait("SameVariadicOperandSize")) {
PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
"specification over their sizes");
}
unsigned emittedNormalOperands = 0;
unsigned emittedVariadicOperands = 0;
for (unsigned i = 0; i != numOperands; ++i) {
const auto &operand = op.getOperand(i);
if (operand.name.empty())
continue;
if (!operand.constraint.isVariadic()) {
auto &m = opClass.newMethod("Value *", operand.name);
m.body() << " return this->getOperation()->getOperand(" << i << ");\n";
} else {
assert(i + 1 == e && "only the last operand can be variadic");
const char *code = R"(
unsigned variadicOperandSize = (this->getNumOperands() - {0}) / {1};
unsigned offset = {2} + variadicOperandSize * {3};
return )";
auto sizeAndOffset =
formatv(code, numNormalOperands, numVariadicOperands,
emittedNormalOperands, emittedVariadicOperands);
const char *const code = R"(
assert(getOperation()->getNumOperands() >= {0});
return {std::next(operand_begin(), {0}), operand_end()};
)";
if (operand.isVariadic()) {
auto &m = opClass.newMethod("Operation::operand_range", operand.name);
m.body() << formatv(code, i);
m.body() << sizeAndOffset
<< "{std::next(operand_begin(), offset), "
"std::next(operand_begin(), offset + variadicOperandSize)};";
++emittedVariadicOperands;
} else {
auto &m = opClass.newMethod("Value *", operand.name);
m.body() << sizeAndOffset << "this->getOperand(offset);";
++emittedNormalOperands;
}
}
}
void OpEmitter::genNamedResultGetters() {
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
const unsigned numResults = op.getNumResults();
const unsigned numVariadicResults = op.getNumVariadicResults();
const unsigned numNormalResults = numResults - numVariadicResults;
// Special case for ops without variadic results: the i-th value is for the
// i-th result defined in the op.
// Special case for ops with one variadic result: the variadic result can
// appear at any place, so the i-th value may not necessarily belong to the
// i-th result definition. we need to calculate the index (range) for each
// result.
if (numVariadicResults <= 1) {
bool emittedVariadicResult = false;
for (unsigned i = 0; i != numResults; ++i) {
const auto &result = op.getResult(i);
if (result.name.empty())
continue;
if (result.isVariadic()) {
auto &m = opClass.newMethod("Operation::result_range", result.name);
m.body() << formatv(
" return {{std::next(result_begin(), {0}), "
"std::next(result_begin(), {0} + this->getNumResults() - {1})};",
i, numNormalResults);
emittedVariadicResult = true;
} else {
auto &m = opClass.newMethod("Value *", result.name);
m.body() << " return this->getOperation()->getResult(";
if (emittedVariadicResult)
m.body() << "this->getNumResults() - " << numResults - i;
else
m.body() << i;
m.body() << ");\n";
}
}
return;
}
// If we have more than one variadic results, we need more complicated logic
// to calculate the value range for each result.
if (!op.hasTrait("SameVariadicResultSize")) {
PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
"specification over their sizes");
}
unsigned emittedNormalResults = 0;
unsigned emittedVariadicResults = 0;
for (unsigned i = 0; i != numResults; ++i) {
const auto &result = op.getResult(i);
if (result.constraint.isVariadic() || result.name.empty())
if (result.name.empty())
continue;
auto &m = opClass.newMethod("Value *", result.name);
m.body() << " return this->getOperation()->getResult(" << i << ");\n";
const char *code = R"(
unsigned variadicResultSize = (this->getNumResults() - {0}) / {1};
unsigned offset = {2} + variadicResultSize * {3};
return )";
auto sizeAndOffset = formatv(code, numNormalResults, numVariadicResults,
emittedNormalResults, emittedVariadicResults);
if (result.isVariadic()) {
auto &m = opClass.newMethod("Operation::result_range", result.name);
m.body() << sizeAndOffset
<< "{std::next(result_begin(), offset), "
"std::next(result_begin(), offset + variadicResultSize)};";
++emittedVariadicResults;
} else {
auto &m = opClass.newMethod("Value *", result.name);
m.body() << sizeAndOffset << "this->getResult(offset);";
++emittedNormalResults;
}
}
}
@ -505,12 +624,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
// Emit parameters for all return types
if (!useOperandType && !useAttrType) {
for (unsigned i = 0; i != numResults; ++i) {
std::string resultName = op.getResultName(i);
const auto &result = op.getResult(i);
std::string resultName = result.name;
if (resultName.empty())
resultName = formatv("resultType{0}", i);
bool isVariadic = op.getResultTypeConstraint(i).isVariadic();
paramList.append(isVariadic ? ", ArrayRef<Type> " : ", Type ");
paramList.append(result.isVariadic() ? ", ArrayRef<Type> " : ", Type ");
paramList.append(resultName);
resultNames.emplace_back(std::move(resultName));
@ -520,12 +639,13 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
// Emit parameters for all arguments (operands and attributes).
int numOperands = 0;
int numAttrs = 0;
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
if (argument.is<tblgen::NamedTypeConstraint *>()) {
auto &operand = op.getOperand(numOperands);
paramList.append(operand.constraint.isVariadic() ? ", ArrayRef<Value *> "
: ", Value *");
const auto &operand = op.getOperand(numOperands);
paramList.append(operand.isVariadic() ? ", ArrayRef<Value *> "
: ", Value *");
paramList.append(getArgumentName(op, numOperands));
++numOperands;
} else {
@ -542,33 +662,22 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
}
if (numOperands + numAttrs != op.getNumArgs())
return PrintFatalError(
"op arguments must be either operands or attributes");
PrintFatalError("op arguments must be either operands or attributes");
auto &method =
opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
bool hasVariadicOperand = op.hasVariadicOperand();
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
// Push all result types to the result
if (numResults > 0) {
if (!useOperandType && !useAttrType) {
bool hasVariadicResult = op.hasVariadicResult();
int numNonVariadicResults =
numResults - static_cast<int>(hasVariadicResult);
if (numNonVariadicResults > 0) {
method.body() << " " << builderOpState << "->addTypes({"
<< resultNames.front();
for (int i = 1; i < numNonVariadicResults; ++i) {
method.body() << ", " << resultNames[i];
for (unsigned i = 0; i < numResults; ++i) {
const auto &result = op.getResult(i);
m.body() << " " << builderOpState;
if (result.isVariadic()) {
m.body() << "->addTypes(";
} else {
m.body() << "->types.push_back(";
}
method.body() << "});\n";
}
if (hasVariadicResult) {
method.body() << " " << builderOpState << "->addTypes("
<< resultNames.back() << ");\n";
m.body() << resultNames[i] << ");\n";
}
} else {
std::string resultType;
@ -580,32 +689,27 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
resultType = formatv("{0}.getType()", namedAttr.name);
}
} else {
const char *index =
(numOperands == 1 && hasVariadicOperand) ? ".front()" : "";
const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
resultType =
formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str();
}
method.body() << " " << builderOpState << "->addTypes({" << resultType;
m.body() << " " << builderOpState << "->addTypes({" << resultType;
for (unsigned i = 1; i != numResults; ++i)
method.body() << ", " << resultType;
method.body() << "});\n\n";
m.body() << ", " << resultType;
m.body() << "});\n\n";
}
}
// Push all operands to the result
int numNonVariadicOperands =
numOperands - static_cast<int>(hasVariadicOperand);
if (numNonVariadicOperands > 0) {
method.body() << " " << builderOpState << "->addOperands({"
<< getArgumentName(op, 0);
for (int i = 1; i < numNonVariadicOperands; ++i) {
method.body() << ", " << getArgumentName(op, i);
for (unsigned i = 0; i < numOperands; ++i) {
const auto &operand = op.getOperand(i);
m.body() << " " << builderOpState;
if (operand.isVariadic()) {
m.body() << "->addOperands(";
} else {
m.body() << "->operands.push_back(";
}
method.body() << "});\n";
}
if (hasVariadicOperand) {
method.body() << " " << builderOpState << "->addOperands("
<< getArgumentName(op, numOperands - 1) << ");\n";
m.body() << getArgumentName(op, i) << ");\n";
}
// Push all attributes to the result
@ -613,12 +717,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
if (!namedAttr.attr.isDerivedAttr()) {
bool emitNotNullCheck = namedAttr.attr.isOptional();
if (emitNotNullCheck) {
method.body() << formatv(" if ({0}) ", namedAttr.name) << "{\n";
m.body() << formatv(" if ({0}) ", namedAttr.name) << "{\n";
}
method.body() << formatv(" {0}->addAttribute(\"{1}\", {1});\n",
builderOpState, namedAttr.name);
m.body() << formatv(" {0}->addAttribute(\"{1}\", {1});\n",
builderOpState, namedAttr.name);
if (emitNotNullCheck) {
method.body() << " }\n";
m.body() << " }\n";
}
}
}
@ -646,13 +750,13 @@ void OpEmitter::genBuilder() {
}
}
auto numResults = op.getNumResults();
bool hasVariadicResult = op.hasVariadicResult();
int numNonVariadicResults = numResults - int(hasVariadicResult);
unsigned numResults = op.getNumResults();
unsigned numVariadicResults = op.getNumVariadicResults();
unsigned numNonVariadicResults = numResults - numVariadicResults;
auto numOperands = op.getNumOperands();
bool hasVariadicOperand = op.hasVariadicOperand();
int numNonVariadicOperands = numOperands - int(hasVariadicOperand);
unsigned numOperands = op.getNumOperands();
unsigned numVariadicOperands = op.getNumVariadicOperands();
unsigned numNonVariadicOperands = numOperands - numVariadicOperands;
// Generate default builders that requires all result type, operands, and
// attributes as parameters.
@ -681,15 +785,16 @@ void OpEmitter::genBuilder() {
auto &body = m.body();
// Result types
if (!(hasVariadicResult && numNonVariadicResults == 0))
if (numVariadicResults == 0 || numNonVariadicResults != 0)
body << " assert(resultTypes.size()"
<< (hasVariadicResult ? " >= " : " == ") << numNonVariadicResults
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
<< "u && \"mismatched number of return types\");\n";
body << " " << builderOpState << "->addTypes(resultTypes);\n";
// Operands
if (!(hasVariadicOperand && numNonVariadicOperands == 0))
body << " assert(operands.size()" << (hasVariadicOperand ? " >= " : " == ")
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
body << " assert(operands.size()"
<< (numVariadicOperands != 0 ? " >= " : " == ")
<< numNonVariadicOperands
<< "u && \"mismatched number of parameters\");\n";
body << " " << builderOpState << "->addOperands(operands);\n\n";
@ -703,7 +808,7 @@ void OpEmitter::genBuilder() {
bool useOperandType = op.hasTrait("SameOperandsAndResultType");
bool useAttrType = op.hasTrait("FirstAttrDerivedResultType");
if (!op.hasVariadicResult() && (useOperandType || useAttrType))
if (numVariadicResults == 0 && (useOperandType || useAttrType))
genStandaloneParamBuilder(useOperandType, useAttrType);
}
@ -824,7 +929,7 @@ void OpEmitter::genVerifier() {
auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index,
bool isOperand) -> void {
// TODO: Handle variadic operand/result verification.
if (value.constraint.isVariadic())
if (value.isVariadic())
return;
// TODO: Commonality between matchers could be extracted to have a more
@ -869,12 +974,12 @@ void OpEmitter::genVerifier() {
}
void OpEmitter::genTraits() {
auto numResults = op.getNumResults();
bool hasVariadicResult = op.hasVariadicResult();
unsigned numResults = op.getNumResults();
unsigned numVariadicResults = op.getNumVariadicResults();
// Add return size trait.
if (hasVariadicResult) {
if (numResults == 1)
if (numVariadicResults != 0) {
if (numResults == numVariadicResults)
opClass.addTrait("VariadicResults");
else
opClass.addTrait("AtLeastNResults<" + Twine(numResults - 1) + ">::Impl");
@ -898,12 +1003,12 @@ void OpEmitter::genTraits() {
}
// Add variadic size trait and normal op traits.
auto numOperands = op.getNumOperands();
bool hasVariadicOperand = op.hasVariadicOperand();
unsigned numOperands = op.getNumOperands();
unsigned numVariadicOperands = op.getNumVariadicOperands();
// Add operand size trait.
if (hasVariadicOperand) {
if (numOperands == 1)
if (numVariadicOperands != 0) {
if (numOperands == numVariadicOperands)
opClass.addTrait("VariadicOperands");
else
opClass.addTrait("AtLeastNOperands<" + Twine(numOperands - 1) +

View File

@ -440,7 +440,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
const Operator &rootOp = pattern.getSourceRootOp();
auto rootName = rootOp.getOperationName();
if (rootOp.hasVariadicResult())
if (rootOp.getNumVariadicResults() != 0)
PrintFatalError(
loc, "replacing op with variadic results not supported right now");