[TableGen] Support multiple variadic operands/results

Certain ops can have multiple variadic operands/results, e.g., `tf.DynamicStitch`.
    Even if an op has only one variadic operand/result, it is not necessarily the
    very last one, e.g., `tf.RaggedGather`. This CL enhances TableGen subsystem to be
    able to represent such cases.

    In order to deduce the operand/result value range for each variadic operand,
    currently we only support variadic operands/results all of the same size.
    So two new traits, `SameVariadicOperandSize` and `SameVariadicResultSize` are
    introduced.

--

PiperOrigin-RevId: 245310628
This commit is contained in:
Lei Zhang 2019-04-25 14:45:37 -07:00 committed by Mehdi Amini
parent 22ad45a7aa
commit 6749c21d6e
11 changed files with 409 additions and 182 deletions

View File

@ -759,6 +759,17 @@ def Terminator : NativeOpTrait<"IsTerminator">;
def FirstAttrDerivedResultType : def FirstAttrDerivedResultType :
GenInternalOpTrait<"FirstAttrDerivedResultType">; GenInternalOpTrait<"FirstAttrDerivedResultType">;
// All variadic operands of the op have the same number of values.
// A variadic operand contains an array of values whose array size is only
// known at runtime. This trait requires all variadic operands of an op
// to have the same array size.
def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">;
// All variadic results of the op have the same number of values.
// A variadic result contains an array of values whose array size is only
// known at runtime. This trait requires all variadic results of an op
// to have the same array size.
def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// Op definitions // Op definitions
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -48,10 +48,12 @@ struct NamedAttribute {
Attribute attr; Attribute attr;
}; };
// A struct wrapping an op operand/result and its name together // A struct wrapping an op operand/result's constraint and its name together
struct NamedTypeConstraint { struct NamedTypeConstraint {
// Returns true if this operand has constraint that need to be satisfied. // Returns true if this operand/result has constraint to be satisfied.
bool hasPredicate() const; bool hasPredicate() const;
// Returns true if this operand/result is variadic.
bool isVariadic() const;
llvm::StringRef name; llvm::StringRef name;
TypeConstraint constraint; TypeConstraint constraint;

View File

@ -82,8 +82,8 @@ public:
// Returns the `index`-th result's name. // Returns the `index`-th result's name.
StringRef getResultName(int index) const; StringRef getResultName(int index) const;
// Returns true if this operation has a variadic result. // Returns the number of variadic results in this operation.
bool hasVariadicResult() const; unsigned getNumVariadicResults() const;
// Op attribute interators. // Op attribute interators.
using attribute_iterator = const NamedAttribute *; using attribute_iterator = const NamedAttribute *;
@ -112,8 +112,8 @@ public:
return operands[index]; return operands[index];
} }
// Returns true if this operation has a variadic operand. // Returns the number of variadic operands in this operation.
bool hasVariadicOperand() const; unsigned getNumVariadicOperands() const;
// Returns the total number of arguments. // Returns the total number of arguments.
int getNumArgs() const { return arguments.size(); } int getNumArgs() const { return arguments.size(); }

View File

@ -23,3 +23,7 @@ using namespace mlir;
bool tblgen::NamedTypeConstraint::hasPredicate() const { bool tblgen::NamedTypeConstraint::hasPredicate() const {
return !constraint.getPredicate().isNull(); return !constraint.getPredicate().isNull();
} }
bool tblgen::NamedTypeConstraint::isVariadic() const {
return constraint.isVariadic();
}

View File

@ -82,8 +82,10 @@ StringRef tblgen::Operator::getResultName(int index) const {
return results->getArgNameStr(index); return results->getArgNameStr(index);
} }
bool tblgen::Operator::hasVariadicResult() const { unsigned tblgen::Operator::getNumVariadicResults() const {
return !results.empty() && results.back().constraint.isVariadic(); return std::count_if(
results.begin(), results.end(),
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
} }
int tblgen::Operator::getNumNativeAttributes() const { int tblgen::Operator::getNumNativeAttributes() const {
@ -98,8 +100,10 @@ const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const {
return attributes[index]; return attributes[index];
} }
bool tblgen::Operator::hasVariadicOperand() const { unsigned tblgen::Operator::getNumVariadicOperands() const {
return !operands.empty() && operands.back().constraint.isVariadic(); return std::count_if(
operands.begin(), operands.end(),
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
} }
StringRef tblgen::Operator::getArgName(int index) const { StringRef tblgen::Operator::getArgName(int index) const {
@ -222,13 +226,6 @@ void tblgen::Operator::populateOpStructure() {
} }
} }
// Verify that only the last operand can be variadic.
for (int i = 0, e = operands.size() - 1; i < e; ++i) {
if (operands[i].constraint.isVariadic())
PrintFatalError(def.getLoc(),
"only the last operand allowed to be variadic");
}
auto *resultsDag = def.getValueAsDag("results"); auto *resultsDag = def.getValueAsDag("results");
auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator()); auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
if (!outsOp || outsOp->getDef()->getName() != "outs") { if (!outsOp || outsOp->getDef()->getName() != "outs") {
@ -246,13 +243,6 @@ void tblgen::Operator::populateOpStructure() {
results.push_back({name, TypeConstraint(resultDef)}); results.push_back({name, TypeConstraint(resultDef)});
} }
// Verify that only the last result can be variadic.
for (int i = 0, e = results.size() - 1; i < e; ++i) {
if (results[i].constraint.isVariadic())
PrintFatalError(def.getLoc(),
"only the last result allowed to be variadic");
}
auto traitListInit = def.getValueAsListInit("traits"); auto traitListInit = def.getValueAsListInit("traits");
if (!traitListInit) if (!traitListInit)
return; return;

View File

@ -1,28 +0,0 @@
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
include "mlir/IR/OpBase.td"
def NS_OpA : Op<"op_same_value_type", [SameValueType]> {
let arguments = (ins Tensor:$input);
let results = (outs Tensor:$result);
}
// Test that with SameValueType trait we can generate a builder without
// requiring result type
// ---
// CHECK-LABEL: OpA::build(Builder *, OperationState *tblgen_state, Value *input)
// CHECK: tblgen_state->addTypes({input->getType()});
def NS_OpB : Op<"op_same_value_type_variadic_input", [SameValueType]> {
let arguments = (ins Variadic<Tensor>:$input);
let results = (outs Tensor:$result);
}
// Test that if the only operand is variadic, we acess the first value in the
// pack to set result type
// ---
// CHECK-LABEL: OpB::build(Builder *, OperationState *tblgen_state, ArrayRef<Value *> input)
// CHECK: tblgen_state->addTypes({input.front()->getType()});

View File

@ -2,7 +2,7 @@
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
def OpA : Op<"one_operand_op", []> { def OpA : Op<"one_normal_operand_op", []> {
let arguments = (ins I32:$input); let arguments = (ins I32:$input);
} }
@ -10,7 +10,7 @@ def OpA : Op<"one_operand_op", []> {
// CHECK: void OpA::build // CHECK: void OpA::build
// CHECK-SAME: Value *input // CHECK-SAME: Value *input
// CHECK: tblgen_state->addOperands({input}); // CHECK: tblgen_state->operands.push_back(input);
// CHECK: void OpA::build // CHECK: void OpA::build
// CHECK-SAME: ArrayRef<Value *> operands // CHECK-SAME: ArrayRef<Value *> operands
@ -21,11 +21,72 @@ def OpA : Op<"one_operand_op", []> {
// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32)))) // CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32))))
// CHECK-NEXT: return emitOpError("operand #0 must be 32-bit integer"); // CHECK-NEXT: return emitOpError("operand #0 must be 32-bit integer");
def OpB : Op<"variadic_operand_op", []> { def OpB : Op<"one_variadic_operand_op", []> {
let arguments = (ins Variadic<I32>:$input); let arguments = (ins Variadic<I32>:$input);
} }
// CHECK-LABEL: OpB::build // CHECK-LABEL: OpB::build
// CHECK-SAME: ArrayRef<Value *> input // CHECK-SAME: ArrayRef<Value *> input
// CHECK-NOT: assert // CHECK-NOT: assert
// CHECK: tblgen_state->addOperands(input); // CHECK: tblgen_state->addOperands(input);
def OpC : Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> {
let arguments = (ins Variadic<Tensor>:$input1, Variadic<Tensor>:$input2);
}
// CHECK-LABEL: Operation::operand_range OpC::input1()
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 0) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 0;
// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
// CHECK-LABEL: Operation::operand_range OpC::input2()
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 0) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 1;
// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
// CHECK-LABEL: OpC::build
// CHECK-NEXT: tblgen_state->addOperands(input1);
// CHECK-NEXT: tblgen_state->addOperands(input2);
def OpD : Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> {
let arguments = (ins Variadic<Tensor>:$input1, Tensor:$input2, Variadic<Tensor>:$input3);
}
// CHECK-LABEL: Operation::operand_range OpD::input1()
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 0;
// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
// CHECK-LABEL: Value *OpD::input2()
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 1;
// CHECK-NEXT: return this->getOperand(offset);
// CHECK-LABEL: Operation::operand_range OpD::input3()
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2;
// CHECK-NEXT: unsigned offset = 1 + variadicOperandSize * 1;
// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
// CHECK-LABEL: OpD::build
// CHECK-NEXT: tblgen_state->addOperands(input1);
// CHECK-NEXT: tblgen_state->operands.push_back(input2);
// CHECK-NEXT: tblgen_state->addOperands(input3);
def OpE : Op<"one_variadic_among_multi_normal_inputs_op", []> {
let arguments = (ins Tensor:$input1, Tensor:$input2, Variadic<Tensor>:$input3, Tensor:$input4, Tensor:$input5);
}
// CHECK-LABEL: Value *OpE::input1()
// CHECK-NEXT: return this->getOperation()->getOperand(0);
// CHECK-LABEL: Value *OpE::input2()
// CHECK-NEXT: return this->getOperation()->getOperand(1);
// CHECK-LABEL: Operation::operand_range OpE::input3()
// CHECK-NEXT: return {std::next(operand_begin(), 2), std::next(operand_begin(), 2 + this->getNumOperands() - 4)};
// CHECK-LABEL: Value *OpE::input4()
// CHECK-NEXT: return this->getOperation()->getOperand(this->getNumOperands() - 2);
// CHECK-LABEL: Value *OpE::input5()
// CHECK-NEXT: return this->getOperation()->getOperand(this->getNumOperands() - 1);

View File

@ -2,82 +2,160 @@
include "mlir/IR/OpBase.td" include "mlir/IR/OpBase.td"
def OneResultOp : Op<"one_result_op", []> { def OpA : Op<"one_normal_result_op", []> {
let results = (outs I32:$result); let results = (outs I32:$result);
} }
// CHECK-LABEL: OneResultOp definitions // CHECK-LABEL: Value *OpA::result()
// CHECK-NEXT: return this->getOperation()->getResult(0)
// CHECK: void OneResultOp::build // CHECK-LABEL: void OpA::build
// CHECK: ArrayRef<Type> resultTypes, ArrayRef<Value *> operands // CHECK: ArrayRef<Type> resultTypes, ArrayRef<Value *> operands
// CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types"); // CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types");
// CHECK-NEXT: tblgen_state->addTypes(resultTypes); // CHECK-NEXT: tblgen_state->addTypes(resultTypes);
// CHECK: LogicalResult OneResultOp::verify() { // CHECK-LABEL: LogicalResult OpA::verify()
// CHECK: if (!((this->getOperation()->getResult(0)->getType().isInteger(32)))) // CHECK: if (!((this->getOperation()->getResult(0)->getType().isInteger(32))))
// CHECK-NEXT: return emitOpError("result #0 must be 32-bit integer"); // CHECK-NEXT: return emitOpError("result #0 must be 32-bit integer");
def OpB : Op<"same_input_output_type_op", [SameValueType]> {
def SameTypeOp : Op<"same_type_op", [SameValueType]> {
let arguments = (ins I32:$x); let arguments = (ins I32:$x);
let results = (outs I32:$y); let results = (outs I32:$y);
} }
// CHECK-LABEL: SameTypeOp definitions // CHECK-LABEL: OpB definitions
// CHECK: void SameTypeOp::build(Builder *, OperationState *tblgen_state, Type y, Value *x) // CHECK: void OpB::build(Builder *, OperationState *tblgen_state, Type y, Value *x)
// CHECK: tblgen_state->addTypes({y}); // CHECK: tblgen_state->types.push_back(y);
// CHECK: void SameTypeOp::build(Builder *, OperationState *tblgen_state, Value *x) // CHECK: void OpB::build(Builder *, OperationState *tblgen_state, Value *x)
// CHECK: tblgen_state->addTypes({x->getType()}); // CHECK: tblgen_state->addTypes({x->getType()});
def ThreeResultOp : Op<"three_result_op", []> { def OpC : Op<"three_normal_result_op", []> {
let results = (outs I32:$x, /*unnamed*/I32, I32:$z); let results = (outs I32:$x, /*unnamed*/I32, I32:$z);
} }
// CHECK-LABEL: ThreeResultOp definitions // CHECK-LABEL: OpC definitions
// CHECK: void ThreeResultOp::build(Builder *, OperationState *tblgen_state, Type x, Type resultType1, Type z) // CHECK: void OpC::build(Builder *, OperationState *tblgen_state, Type x, Type resultType1, Type z)
// CHECK: tblgen_state->addTypes({x, resultType1, z}); // CHECK-NEXT: tblgen_state->types.push_back(x)
// CHECK-NEXT: tblgen_state->types.push_back(resultType1)
// CHECK-NEXT: tblgen_state->types.push_back(z)
def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">; def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
def TypeAttrResultTypeOp : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> { def OpD : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32); let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32);
let results = (outs Tensor:$y); let results = (outs Tensor:$y);
} }
// CHECK-LABEL: TypeAttrResultTypeOp definitions // CHECK-LABEL: OpD definitions
// CHECK: void TypeAttrResultTypeOp::build(Builder *, OperationState *tblgen_state, Value *x, TypeAttr attr, FloatAttr f32) // CHECK: void OpD::build(Builder *, OperationState *tblgen_state, Value *x, TypeAttr attr, FloatAttr f32)
// CHECK: tblgen_state->addTypes({attr.getValue()}); // CHECK: tblgen_state->addTypes({attr.getValue()});
def ValueAttrResultTypeOp : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> { def OpE : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
let arguments = (ins I32:$x, F32Attr:$attr); let arguments = (ins I32:$x, F32Attr:$attr);
let results = (outs Tensor:$y); let results = (outs Tensor:$y);
} }
// CHECK-LABEL: ValueAttrResultTypeOp definitions // CHECK-LABEL: OpE definitions
// CHECK: void ValueAttrResultTypeOp::build(Builder *, OperationState *tblgen_state, Value *x, FloatAttr attr) // CHECK: void OpE::build(Builder *, OperationState *tblgen_state, Value *x, FloatAttr attr)
// CHECK: tblgen_state->addTypes({attr.getType()}); // CHECK: tblgen_state->addTypes({attr.getType()});
def VariadicResultAloneOp : Op<"variadic_alone_op", []> { def OpF : Op<"one_variadic_result_op", []> {
let results = (outs Variadic<I32>:$x); let results = (outs Variadic<I32>:$x);
} }
// CHECK-LABEL: VariadicResultAloneOp definitions // CHECK-LABEL: Operation::result_range OpF::x()
// CHECK-NEXT: return {std::next(result_begin(), 0), std::next(result_begin(), 0 + this->getNumResults() - 0)};
// CHECK-LABEL: void VariadicResultAloneOp::build // CHECK-LABEL: void OpF::build
// CHECK-SAME: ArrayRef<Type> x // CHECK-SAME: ArrayRef<Type> x
// CHECK-NOT: assert // CHECK-NOT: assert
// CHECK: tblgen_state->addTypes(x); // CHECK: tblgen_state->addTypes(x);
def OpG : Op<"one_normal_and_one_variadic_result_op", []> {
def VariadicResultOp : Op<"variadic_op", []> {
let results = (outs I32:$x, Variadic<I32>:$y); let results = (outs I32:$x, Variadic<I32>:$y);
} }
// CHECK-LABEL: VariadicResultOp definitions // CHECK-LABEL: OpG definitions
// CHECK: void VariadicResultOp::build(Builder *, OperationState *tblgen_state, Type x, ArrayRef<Type> y) // CHECK: void OpG::build(Builder *, OperationState *tblgen_state, Type x, ArrayRef<Type> y)
// CHECK: tblgen_state->addTypes({x}); // CHECK-NEXT: tblgen_state->types.push_back(x);
// CHECK: tblgen_state->addTypes(y); // CHECK-NEXT: tblgen_state->addTypes(y);
// CHECK: void VariadicResultOp::build // CHECK: void OpG::build
// CHECK: ArrayRef<Type> resultTypes // CHECK: ArrayRef<Type> resultTypes
// CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types"); // CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types");
// CHECK-NEXT: tblgen_state->addTypes(resultTypes); // CHECK-NEXT: tblgen_state->addTypes(resultTypes);
def OpH : Op<"all_variadic_results_op", [SameVariadicResultSize]> {
let results = (outs Variadic<Tensor>:$output1, Variadic<Tensor>:$output2);
}
// CHECK-LABEL: Operation::result_range OpH::output1()
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 0) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 0;
// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
// CHECK-LABEL: Operation::result_range OpH::output2()
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 0) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 1;
// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
// CHECK-LABEL: OpH::build
// CHECK-NEXT: tblgen_state->addTypes(output1);
// CHECK-NEXT: tblgen_state->addTypes(output2);
def OpI : Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> {
let results = (outs Variadic<Tensor>:$output1, Tensor:$output2, Variadic<Tensor>:$output3);
}
// CHECK-LABEL: Operation::result_range OpI::output1()
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 0;
// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
// CHECK-LABEL: Value *OpI::output2()
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2;
// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 1;
// CHECK-NEXT: return this->getResult(offset);
// CHECK-LABEL: Operation::result_range OpI::output3()
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2;
// CHECK-NEXT: unsigned offset = 1 + variadicResultSize * 1;
// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
// CHECK-LABEL: OpI::build
// CHECK-NEXT: tblgen_state->addTypes(output1);
// CHECK-NEXT: tblgen_state->types.push_back(output2);
// CHECK-NEXT: tblgen_state->addTypes(output3);
def OpJ : Op<"one_variadic_among_multi_normal_results_op", []> {
let results = (outs Tensor:$output1, Tensor:$output2, Variadic<Tensor>:$output3, Tensor:$output4, Tensor:$output5);
}
// CHECK-LABEL: Value *OpJ::output1()
// CHECK-NEXT: return this->getOperation()->getResult(0);
// CHECK-LABEL: Value *OpJ::output2()
// CHECK-NEXT: return this->getOperation()->getResult(1);
// CHECK-LABEL: Operation::result_range OpJ::output3()
// CHECK-NEXT: return {std::next(result_begin(), 2), std::next(result_begin(), 2 + this->getNumResults() - 4)};
// CHECK-LABEL: Value *OpJ::output4()
// CHECK-NEXT: return this->getOperation()->getResult(this->getNumResults() - 2);
// CHECK-LABEL: Value *OpJ::output5()
// CHECK-NEXT: return this->getOperation()->getResult(this->getNumResults() - 1);
// Test that if the only operand is variadic, we acess the first value in the
// pack to set result type
// ---
def OpK : Op<"only_input_is_variadic_with_same_value_type_op", [SameValueType]> {
let arguments = (ins Variadic<Tensor>:$input);
let results = (outs Tensor:$result);
}
// CHECK-LABEL: OpK::build(Builder *, OperationState *tblgen_state, ArrayRef<Value *> input)
// CHECK: tblgen_state->addTypes({input.front()->getType()});

View File

@ -75,10 +75,14 @@ static StringLoc findNextVariable(StringRef str) {
return {startPos, endPos - startPos}; return {startPos, endPos - startPos};
} }
// Check if `name` is the name of the variadic argument of `op`. The variadic // Check if `name` is the name of the variadic operand of `op`. The variadic
// argument can only appear at the last position in the list of arguments. // operand can only appear at the last position in the list of operands.
static bool isVariadicArgumentName(const tblgen::Operator &op, StringRef name) { static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
return op.hasVariadicOperand() && op.getArgName(op.getNumArgs() - 1) == name; unsigned numOperands = op.getNumOperands();
if (numOperands == 0)
return false;
const auto &operand = op.getOperand(numOperands - 1);
return operand.isVariadic() && operand.name == name;
} }
// Check if `result` is a known name of a result of `op`. // Check if `result` is a known name of a result of `op`.
@ -127,9 +131,9 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) {
// First, insert the non-matched part as is. // First, insert the non-matched part as is.
bs << builderStrRef.substr(0, loc.pos); bs << builderStrRef.substr(0, loc.pos);
// Then, rewrite the name based on its kind. // Then, rewrite the name based on its kind.
bool isVariadicArg = isVariadicArgumentName(op, name); bool isVariadicOperand = isVariadicOperandName(op, name);
if (isOperandName(op, name)) { if (isOperandName(op, name)) {
auto result = isVariadicArg auto result = isVariadicOperand
? formatv("lookupValues(op.{0}())", name) ? formatv("lookupValues(op.{0}())", name)
: formatv("valueMapping.lookup(op.{0}())", name); : formatv("valueMapping.lookup(op.{0}())", name);
bs << result; bs << result;

View File

@ -251,8 +251,9 @@ OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
} }
void OpMethodBody::writeTo(raw_ostream &os) const { void OpMethodBody::writeTo(raw_ostream &os) const {
os << body; auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
if (body.empty() || body.back() != '\n') os << bodyRef;
if (bodyRef.empty() || bodyRef.back() != '\n')
os << "\n"; os << "\n";
} }
@ -455,35 +456,153 @@ void OpEmitter::genAttrGetters() {
} }
void OpEmitter::genNamedOperandGetters() { void OpEmitter::genNamedOperandGetters() {
for (int i = 0, e = op.getNumOperands(); i != e; ++i) { const unsigned numOperands = op.getNumOperands();
const unsigned numVariadicOperands = op.getNumVariadicOperands();
const unsigned numNormalOperands = numOperands - numVariadicOperands;
// Special case for ops without variadic operands: the i-th value is for the
// i-th operand defined in the op.
// Special case for ops with one variadic operand: the variadic operand can
// appear at any place, so the i-th value may not necessarily belong to the
// i-th operand definition. we need to calculate the index (range) for each
// operand.
if (numVariadicOperands <= 1) {
bool emittedVariadicOperand = false;
for (unsigned i = 0; i != numOperands; ++i) {
const auto &operand = op.getOperand(i);
if (operand.name.empty())
continue;
if (operand.isVariadic()) {
auto &m = opClass.newMethod("Operation::operand_range", operand.name);
m.body() << formatv(
" return {{std::next(operand_begin(), {0}), "
"std::next(operand_begin(), {0} + this->getNumOperands() - {1})};",
i, numNormalOperands);
emittedVariadicOperand = true;
} else {
auto &m = opClass.newMethod("Value *", operand.name);
m.body() << " return this->getOperation()->getOperand(";
if (emittedVariadicOperand)
m.body() << "this->getNumOperands() - " << numOperands - i;
else
m.body() << i;
m.body() << ");\n";
}
}
return;
}
// If we have more than one variadic operands, we need more complicated logic
// to calculate the value range for each operand.
if (!op.hasTrait("SameVariadicOperandSize")) {
PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
"specification over their sizes");
}
unsigned emittedNormalOperands = 0;
unsigned emittedVariadicOperands = 0;
for (unsigned i = 0; i != numOperands; ++i) {
const auto &operand = op.getOperand(i); const auto &operand = op.getOperand(i);
if (operand.name.empty()) if (operand.name.empty())
continue; continue;
if (!operand.constraint.isVariadic()) { const char *code = R"(
auto &m = opClass.newMethod("Value *", operand.name); unsigned variadicOperandSize = (this->getNumOperands() - {0}) / {1};
m.body() << " return this->getOperation()->getOperand(" << i << ");\n"; unsigned offset = {2} + variadicOperandSize * {3};
} else { return )";
assert(i + 1 == e && "only the last operand can be variadic"); auto sizeAndOffset =
formatv(code, numNormalOperands, numVariadicOperands,
emittedNormalOperands, emittedVariadicOperands);
const char *const code = R"( if (operand.isVariadic()) {
assert(getOperation()->getNumOperands() >= {0});
return {std::next(operand_begin(), {0}), operand_end()};
)";
auto &m = opClass.newMethod("Operation::operand_range", operand.name); auto &m = opClass.newMethod("Operation::operand_range", operand.name);
m.body() << formatv(code, i); m.body() << sizeAndOffset
<< "{std::next(operand_begin(), offset), "
"std::next(operand_begin(), offset + variadicOperandSize)};";
++emittedVariadicOperands;
} else {
auto &m = opClass.newMethod("Value *", operand.name);
m.body() << sizeAndOffset << "this->getOperand(offset);";
++emittedNormalOperands;
} }
} }
} }
void OpEmitter::genNamedResultGetters() { void OpEmitter::genNamedResultGetters() {
for (int i = 0, e = op.getNumResults(); i != e; ++i) { const unsigned numResults = op.getNumResults();
const unsigned numVariadicResults = op.getNumVariadicResults();
const unsigned numNormalResults = numResults - numVariadicResults;
// Special case for ops without variadic results: the i-th value is for the
// i-th result defined in the op.
// Special case for ops with one variadic result: the variadic result can
// appear at any place, so the i-th value may not necessarily belong to the
// i-th result definition. we need to calculate the index (range) for each
// result.
if (numVariadicResults <= 1) {
bool emittedVariadicResult = false;
for (unsigned i = 0; i != numResults; ++i) {
const auto &result = op.getResult(i);
if (result.name.empty())
continue;
if (result.isVariadic()) {
auto &m = opClass.newMethod("Operation::result_range", result.name);
m.body() << formatv(
" return {{std::next(result_begin(), {0}), "
"std::next(result_begin(), {0} + this->getNumResults() - {1})};",
i, numNormalResults);
emittedVariadicResult = true;
} else {
auto &m = opClass.newMethod("Value *", result.name);
m.body() << " return this->getOperation()->getResult(";
if (emittedVariadicResult)
m.body() << "this->getNumResults() - " << numResults - i;
else
m.body() << i;
m.body() << ");\n";
}
}
return;
}
// If we have more than one variadic results, we need more complicated logic
// to calculate the value range for each result.
if (!op.hasTrait("SameVariadicResultSize")) {
PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
"specification over their sizes");
}
unsigned emittedNormalResults = 0;
unsigned emittedVariadicResults = 0;
for (unsigned i = 0; i != numResults; ++i) {
const auto &result = op.getResult(i); const auto &result = op.getResult(i);
if (result.constraint.isVariadic() || result.name.empty()) if (result.name.empty())
continue; continue;
auto &m = opClass.newMethod("Value *", result.name); const char *code = R"(
m.body() << " return this->getOperation()->getResult(" << i << ");\n"; unsigned variadicResultSize = (this->getNumResults() - {0}) / {1};
unsigned offset = {2} + variadicResultSize * {3};
return )";
auto sizeAndOffset = formatv(code, numNormalResults, numVariadicResults,
emittedNormalResults, emittedVariadicResults);
if (result.isVariadic()) {
auto &m = opClass.newMethod("Operation::result_range", result.name);
m.body() << sizeAndOffset
<< "{std::next(result_begin(), offset), "
"std::next(result_begin(), offset + variadicResultSize)};";
++emittedVariadicResults;
} else {
auto &m = opClass.newMethod("Value *", result.name);
m.body() << sizeAndOffset << "this->getResult(offset);";
++emittedNormalResults;
}
} }
} }
@ -505,12 +624,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
// Emit parameters for all return types // Emit parameters for all return types
if (!useOperandType && !useAttrType) { if (!useOperandType && !useAttrType) {
for (unsigned i = 0; i != numResults; ++i) { for (unsigned i = 0; i != numResults; ++i) {
std::string resultName = op.getResultName(i); const auto &result = op.getResult(i);
std::string resultName = result.name;
if (resultName.empty()) if (resultName.empty())
resultName = formatv("resultType{0}", i); resultName = formatv("resultType{0}", i);
bool isVariadic = op.getResultTypeConstraint(i).isVariadic(); paramList.append(result.isVariadic() ? ", ArrayRef<Type> " : ", Type ");
paramList.append(isVariadic ? ", ArrayRef<Type> " : ", Type ");
paramList.append(resultName); paramList.append(resultName);
resultNames.emplace_back(std::move(resultName)); resultNames.emplace_back(std::move(resultName));
@ -520,12 +639,13 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
// Emit parameters for all arguments (operands and attributes). // Emit parameters for all arguments (operands and attributes).
int numOperands = 0; int numOperands = 0;
int numAttrs = 0; int numAttrs = 0;
for (int i = 0, e = op.getNumArgs(); i < e; ++i) { for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i); auto argument = op.getArg(i);
if (argument.is<tblgen::NamedTypeConstraint *>()) { if (argument.is<tblgen::NamedTypeConstraint *>()) {
auto &operand = op.getOperand(numOperands); const auto &operand = op.getOperand(numOperands);
paramList.append(operand.constraint.isVariadic() ? ", ArrayRef<Value *> " paramList.append(operand.isVariadic() ? ", ArrayRef<Value *> "
: ", Value *"); : ", Value *");
paramList.append(getArgumentName(op, numOperands)); paramList.append(getArgumentName(op, numOperands));
++numOperands; ++numOperands;
} else { } else {
@ -542,33 +662,22 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
} }
if (numOperands + numAttrs != op.getNumArgs()) if (numOperands + numAttrs != op.getNumArgs())
return PrintFatalError( PrintFatalError("op arguments must be either operands or attributes");
"op arguments must be either operands or attributes");
auto &method = auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
bool hasVariadicOperand = op.hasVariadicOperand();
// Push all result types to the result // Push all result types to the result
if (numResults > 0) { if (numResults > 0) {
if (!useOperandType && !useAttrType) { if (!useOperandType && !useAttrType) {
bool hasVariadicResult = op.hasVariadicResult(); for (unsigned i = 0; i < numResults; ++i) {
int numNonVariadicResults = const auto &result = op.getResult(i);
numResults - static_cast<int>(hasVariadicResult); m.body() << " " << builderOpState;
if (result.isVariadic()) {
if (numNonVariadicResults > 0) { m.body() << "->addTypes(";
method.body() << " " << builderOpState << "->addTypes({" } else {
<< resultNames.front(); m.body() << "->types.push_back(";
for (int i = 1; i < numNonVariadicResults; ++i) {
method.body() << ", " << resultNames[i];
} }
method.body() << "});\n"; m.body() << resultNames[i] << ");\n";
}
if (hasVariadicResult) {
method.body() << " " << builderOpState << "->addTypes("
<< resultNames.back() << ");\n";
} }
} else { } else {
std::string resultType; std::string resultType;
@ -580,32 +689,27 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
resultType = formatv("{0}.getType()", namedAttr.name); resultType = formatv("{0}.getType()", namedAttr.name);
} }
} else { } else {
const char *index = const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
(numOperands == 1 && hasVariadicOperand) ? ".front()" : "";
resultType = resultType =
formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str(); formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str();
} }
method.body() << " " << builderOpState << "->addTypes({" << resultType; m.body() << " " << builderOpState << "->addTypes({" << resultType;
for (unsigned i = 1; i != numResults; ++i) for (unsigned i = 1; i != numResults; ++i)
method.body() << ", " << resultType; m.body() << ", " << resultType;
method.body() << "});\n\n"; m.body() << "});\n\n";
} }
} }
// Push all operands to the result // Push all operands to the result
int numNonVariadicOperands = for (unsigned i = 0; i < numOperands; ++i) {
numOperands - static_cast<int>(hasVariadicOperand); const auto &operand = op.getOperand(i);
if (numNonVariadicOperands > 0) { m.body() << " " << builderOpState;
method.body() << " " << builderOpState << "->addOperands({" if (operand.isVariadic()) {
<< getArgumentName(op, 0); m.body() << "->addOperands(";
for (int i = 1; i < numNonVariadicOperands; ++i) { } else {
method.body() << ", " << getArgumentName(op, i); m.body() << "->operands.push_back(";
} }
method.body() << "});\n"; m.body() << getArgumentName(op, i) << ");\n";
}
if (hasVariadicOperand) {
method.body() << " " << builderOpState << "->addOperands("
<< getArgumentName(op, numOperands - 1) << ");\n";
} }
// Push all attributes to the result // Push all attributes to the result
@ -613,12 +717,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
if (!namedAttr.attr.isDerivedAttr()) { if (!namedAttr.attr.isDerivedAttr()) {
bool emitNotNullCheck = namedAttr.attr.isOptional(); bool emitNotNullCheck = namedAttr.attr.isOptional();
if (emitNotNullCheck) { if (emitNotNullCheck) {
method.body() << formatv(" if ({0}) ", namedAttr.name) << "{\n"; m.body() << formatv(" if ({0}) ", namedAttr.name) << "{\n";
} }
method.body() << formatv(" {0}->addAttribute(\"{1}\", {1});\n", m.body() << formatv(" {0}->addAttribute(\"{1}\", {1});\n",
builderOpState, namedAttr.name); builderOpState, namedAttr.name);
if (emitNotNullCheck) { if (emitNotNullCheck) {
method.body() << " }\n"; m.body() << " }\n";
} }
} }
} }
@ -646,13 +750,13 @@ void OpEmitter::genBuilder() {
} }
} }
auto numResults = op.getNumResults(); unsigned numResults = op.getNumResults();
bool hasVariadicResult = op.hasVariadicResult(); unsigned numVariadicResults = op.getNumVariadicResults();
int numNonVariadicResults = numResults - int(hasVariadicResult); unsigned numNonVariadicResults = numResults - numVariadicResults;
auto numOperands = op.getNumOperands(); unsigned numOperands = op.getNumOperands();
bool hasVariadicOperand = op.hasVariadicOperand(); unsigned numVariadicOperands = op.getNumVariadicOperands();
int numNonVariadicOperands = numOperands - int(hasVariadicOperand); unsigned numNonVariadicOperands = numOperands - numVariadicOperands;
// Generate default builders that requires all result type, operands, and // Generate default builders that requires all result type, operands, and
// attributes as parameters. // attributes as parameters.
@ -681,15 +785,16 @@ void OpEmitter::genBuilder() {
auto &body = m.body(); auto &body = m.body();
// Result types // Result types
if (!(hasVariadicResult && numNonVariadicResults == 0)) if (numVariadicResults == 0 || numNonVariadicResults != 0)
body << " assert(resultTypes.size()" body << " assert(resultTypes.size()"
<< (hasVariadicResult ? " >= " : " == ") << numNonVariadicResults << (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
<< "u && \"mismatched number of return types\");\n"; << "u && \"mismatched number of return types\");\n";
body << " " << builderOpState << "->addTypes(resultTypes);\n"; body << " " << builderOpState << "->addTypes(resultTypes);\n";
// Operands // Operands
if (!(hasVariadicOperand && numNonVariadicOperands == 0)) if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
body << " assert(operands.size()" << (hasVariadicOperand ? " >= " : " == ") body << " assert(operands.size()"
<< (numVariadicOperands != 0 ? " >= " : " == ")
<< numNonVariadicOperands << numNonVariadicOperands
<< "u && \"mismatched number of parameters\");\n"; << "u && \"mismatched number of parameters\");\n";
body << " " << builderOpState << "->addOperands(operands);\n\n"; body << " " << builderOpState << "->addOperands(operands);\n\n";
@ -703,7 +808,7 @@ void OpEmitter::genBuilder() {
bool useOperandType = op.hasTrait("SameOperandsAndResultType"); bool useOperandType = op.hasTrait("SameOperandsAndResultType");
bool useAttrType = op.hasTrait("FirstAttrDerivedResultType"); bool useAttrType = op.hasTrait("FirstAttrDerivedResultType");
if (!op.hasVariadicResult() && (useOperandType || useAttrType)) if (numVariadicResults == 0 && (useOperandType || useAttrType))
genStandaloneParamBuilder(useOperandType, useAttrType); genStandaloneParamBuilder(useOperandType, useAttrType);
} }
@ -824,7 +929,7 @@ void OpEmitter::genVerifier() {
auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index, auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index,
bool isOperand) -> void { bool isOperand) -> void {
// TODO: Handle variadic operand/result verification. // TODO: Handle variadic operand/result verification.
if (value.constraint.isVariadic()) if (value.isVariadic())
return; return;
// TODO: Commonality between matchers could be extracted to have a more // TODO: Commonality between matchers could be extracted to have a more
@ -869,12 +974,12 @@ void OpEmitter::genVerifier() {
} }
void OpEmitter::genTraits() { void OpEmitter::genTraits() {
auto numResults = op.getNumResults(); unsigned numResults = op.getNumResults();
bool hasVariadicResult = op.hasVariadicResult(); unsigned numVariadicResults = op.getNumVariadicResults();
// Add return size trait. // Add return size trait.
if (hasVariadicResult) { if (numVariadicResults != 0) {
if (numResults == 1) if (numResults == numVariadicResults)
opClass.addTrait("VariadicResults"); opClass.addTrait("VariadicResults");
else else
opClass.addTrait("AtLeastNResults<" + Twine(numResults - 1) + ">::Impl"); opClass.addTrait("AtLeastNResults<" + Twine(numResults - 1) + ">::Impl");
@ -898,12 +1003,12 @@ void OpEmitter::genTraits() {
} }
// Add variadic size trait and normal op traits. // Add variadic size trait and normal op traits.
auto numOperands = op.getNumOperands(); unsigned numOperands = op.getNumOperands();
bool hasVariadicOperand = op.hasVariadicOperand(); unsigned numVariadicOperands = op.getNumVariadicOperands();
// Add operand size trait. // Add operand size trait.
if (hasVariadicOperand) { if (numVariadicOperands != 0) {
if (numOperands == 1) if (numOperands == numVariadicOperands)
opClass.addTrait("VariadicOperands"); opClass.addTrait("VariadicOperands");
else else
opClass.addTrait("AtLeastNOperands<" + Twine(numOperands - 1) + opClass.addTrait("AtLeastNOperands<" + Twine(numOperands - 1) +

View File

@ -440,7 +440,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
const Operator &rootOp = pattern.getSourceRootOp(); const Operator &rootOp = pattern.getSourceRootOp();
auto rootName = rootOp.getOperationName(); auto rootName = rootOp.getOperationName();
if (rootOp.hasVariadicResult()) if (rootOp.getNumVariadicResults() != 0)
PrintFatalError( PrintFatalError(
loc, "replacing op with variadic results not supported right now"); loc, "replacing op with variadic results not supported right now");