forked from OSchip/llvm-project
[TableGen] Support multiple variadic operands/results
Certain ops can have multiple variadic operands/results, e.g., `tf.DynamicStitch`. Even if an op has only one variadic operand/result, it is not necessarily the very last one, e.g., `tf.RaggedGather`. This CL enhances TableGen subsystem to be able to represent such cases. In order to deduce the operand/result value range for each variadic operand, currently we only support variadic operands/results all of the same size. So two new traits, `SameVariadicOperandSize` and `SameVariadicResultSize` are introduced. -- PiperOrigin-RevId: 245310628
This commit is contained in:
parent
22ad45a7aa
commit
6749c21d6e
|
@ -759,6 +759,17 @@ def Terminator : NativeOpTrait<"IsTerminator">;
|
|||
def FirstAttrDerivedResultType :
|
||||
GenInternalOpTrait<"FirstAttrDerivedResultType">;
|
||||
|
||||
// All variadic operands of the op have the same number of values.
|
||||
// A variadic operand contains an array of values whose array size is only
|
||||
// known at runtime. This trait requires all variadic operands of an op
|
||||
// to have the same array size.
|
||||
def SameVariadicOperandSize : GenInternalOpTrait<"SameVariadicOperandSize">;
|
||||
// All variadic results of the op have the same number of values.
|
||||
// A variadic result contains an array of values whose array size is only
|
||||
// known at runtime. This trait requires all variadic results of an op
|
||||
// to have the same array size.
|
||||
def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -48,10 +48,12 @@ struct NamedAttribute {
|
|||
Attribute attr;
|
||||
};
|
||||
|
||||
// A struct wrapping an op operand/result and its name together
|
||||
// A struct wrapping an op operand/result's constraint and its name together
|
||||
struct NamedTypeConstraint {
|
||||
// Returns true if this operand has constraint that need to be satisfied.
|
||||
// Returns true if this operand/result has constraint to be satisfied.
|
||||
bool hasPredicate() const;
|
||||
// Returns true if this operand/result is variadic.
|
||||
bool isVariadic() const;
|
||||
|
||||
llvm::StringRef name;
|
||||
TypeConstraint constraint;
|
||||
|
|
|
@ -82,8 +82,8 @@ public:
|
|||
// Returns the `index`-th result's name.
|
||||
StringRef getResultName(int index) const;
|
||||
|
||||
// Returns true if this operation has a variadic result.
|
||||
bool hasVariadicResult() const;
|
||||
// Returns the number of variadic results in this operation.
|
||||
unsigned getNumVariadicResults() const;
|
||||
|
||||
// Op attribute interators.
|
||||
using attribute_iterator = const NamedAttribute *;
|
||||
|
@ -112,8 +112,8 @@ public:
|
|||
return operands[index];
|
||||
}
|
||||
|
||||
// Returns true if this operation has a variadic operand.
|
||||
bool hasVariadicOperand() const;
|
||||
// Returns the number of variadic operands in this operation.
|
||||
unsigned getNumVariadicOperands() const;
|
||||
|
||||
// Returns the total number of arguments.
|
||||
int getNumArgs() const { return arguments.size(); }
|
||||
|
|
|
@ -23,3 +23,7 @@ using namespace mlir;
|
|||
bool tblgen::NamedTypeConstraint::hasPredicate() const {
|
||||
return !constraint.getPredicate().isNull();
|
||||
}
|
||||
|
||||
bool tblgen::NamedTypeConstraint::isVariadic() const {
|
||||
return constraint.isVariadic();
|
||||
}
|
||||
|
|
|
@ -82,8 +82,10 @@ StringRef tblgen::Operator::getResultName(int index) const {
|
|||
return results->getArgNameStr(index);
|
||||
}
|
||||
|
||||
bool tblgen::Operator::hasVariadicResult() const {
|
||||
return !results.empty() && results.back().constraint.isVariadic();
|
||||
unsigned tblgen::Operator::getNumVariadicResults() const {
|
||||
return std::count_if(
|
||||
results.begin(), results.end(),
|
||||
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
|
||||
}
|
||||
|
||||
int tblgen::Operator::getNumNativeAttributes() const {
|
||||
|
@ -98,8 +100,10 @@ const tblgen::NamedAttribute &tblgen::Operator::getAttribute(int index) const {
|
|||
return attributes[index];
|
||||
}
|
||||
|
||||
bool tblgen::Operator::hasVariadicOperand() const {
|
||||
return !operands.empty() && operands.back().constraint.isVariadic();
|
||||
unsigned tblgen::Operator::getNumVariadicOperands() const {
|
||||
return std::count_if(
|
||||
operands.begin(), operands.end(),
|
||||
[](const NamedTypeConstraint &c) { return c.constraint.isVariadic(); });
|
||||
}
|
||||
|
||||
StringRef tblgen::Operator::getArgName(int index) const {
|
||||
|
@ -222,13 +226,6 @@ void tblgen::Operator::populateOpStructure() {
|
|||
}
|
||||
}
|
||||
|
||||
// Verify that only the last operand can be variadic.
|
||||
for (int i = 0, e = operands.size() - 1; i < e; ++i) {
|
||||
if (operands[i].constraint.isVariadic())
|
||||
PrintFatalError(def.getLoc(),
|
||||
"only the last operand allowed to be variadic");
|
||||
}
|
||||
|
||||
auto *resultsDag = def.getValueAsDag("results");
|
||||
auto *outsOp = dyn_cast<DefInit>(resultsDag->getOperator());
|
||||
if (!outsOp || outsOp->getDef()->getName() != "outs") {
|
||||
|
@ -246,13 +243,6 @@ void tblgen::Operator::populateOpStructure() {
|
|||
results.push_back({name, TypeConstraint(resultDef)});
|
||||
}
|
||||
|
||||
// Verify that only the last result can be variadic.
|
||||
for (int i = 0, e = results.size() - 1; i < e; ++i) {
|
||||
if (results[i].constraint.isVariadic())
|
||||
PrintFatalError(def.getLoc(),
|
||||
"only the last result allowed to be variadic");
|
||||
}
|
||||
|
||||
auto traitListInit = def.getValueAsListInit("traits");
|
||||
if (!traitListInit)
|
||||
return;
|
||||
|
|
|
@ -1,28 +0,0 @@
|
|||
// RUN: mlir-tblgen -gen-op-defs -I %S/../../include %s | FileCheck %s
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
|
||||
def NS_OpA : Op<"op_same_value_type", [SameValueType]> {
|
||||
let arguments = (ins Tensor:$input);
|
||||
let results = (outs Tensor:$result);
|
||||
}
|
||||
|
||||
// Test that with SameValueType trait we can generate a builder without
|
||||
// requiring result type
|
||||
// ---
|
||||
|
||||
// CHECK-LABEL: OpA::build(Builder *, OperationState *tblgen_state, Value *input)
|
||||
// CHECK: tblgen_state->addTypes({input->getType()});
|
||||
|
||||
def NS_OpB : Op<"op_same_value_type_variadic_input", [SameValueType]> {
|
||||
let arguments = (ins Variadic<Tensor>:$input);
|
||||
let results = (outs Tensor:$result);
|
||||
}
|
||||
|
||||
// Test that if the only operand is variadic, we acess the first value in the
|
||||
// pack to set result type
|
||||
// ---
|
||||
|
||||
// CHECK-LABEL: OpB::build(Builder *, OperationState *tblgen_state, ArrayRef<Value *> input)
|
||||
// CHECK: tblgen_state->addTypes({input.front()->getType()});
|
|
@ -2,7 +2,7 @@
|
|||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def OpA : Op<"one_operand_op", []> {
|
||||
def OpA : Op<"one_normal_operand_op", []> {
|
||||
let arguments = (ins I32:$input);
|
||||
}
|
||||
|
||||
|
@ -10,7 +10,7 @@ def OpA : Op<"one_operand_op", []> {
|
|||
|
||||
// CHECK: void OpA::build
|
||||
// CHECK-SAME: Value *input
|
||||
// CHECK: tblgen_state->addOperands({input});
|
||||
// CHECK: tblgen_state->operands.push_back(input);
|
||||
|
||||
// CHECK: void OpA::build
|
||||
// CHECK-SAME: ArrayRef<Value *> operands
|
||||
|
@ -21,11 +21,72 @@ def OpA : Op<"one_operand_op", []> {
|
|||
// CHECK: if (!((this->getOperation()->getOperand(0)->getType().isInteger(32))))
|
||||
// CHECK-NEXT: return emitOpError("operand #0 must be 32-bit integer");
|
||||
|
||||
def OpB : Op<"variadic_operand_op", []> {
|
||||
def OpB : Op<"one_variadic_operand_op", []> {
|
||||
let arguments = (ins Variadic<I32>:$input);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: OpB::build
|
||||
// CHECK-SAME: ArrayRef<Value *> input
|
||||
// CHECK-NOT: assert
|
||||
// CHECK: tblgen_state->addOperands(input);
|
||||
// CHECK-SAME: ArrayRef<Value *> input
|
||||
// CHECK-NOT: assert
|
||||
// CHECK: tblgen_state->addOperands(input);
|
||||
|
||||
def OpC : Op<"all_variadic_inputs_op", [SameVariadicOperandSize]> {
|
||||
let arguments = (ins Variadic<Tensor>:$input1, Variadic<Tensor>:$input2);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: Operation::operand_range OpC::input1()
|
||||
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 0) / 2;
|
||||
// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 0;
|
||||
// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
|
||||
|
||||
// CHECK-LABEL: Operation::operand_range OpC::input2()
|
||||
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 0) / 2;
|
||||
// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 1;
|
||||
// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
|
||||
|
||||
// CHECK-LABEL: OpC::build
|
||||
// CHECK-NEXT: tblgen_state->addOperands(input1);
|
||||
// CHECK-NEXT: tblgen_state->addOperands(input2);
|
||||
|
||||
def OpD : Op<"mix_variadic_and_normal_inputs_op", [SameVariadicOperandSize]> {
|
||||
let arguments = (ins Variadic<Tensor>:$input1, Tensor:$input2, Variadic<Tensor>:$input3);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: Operation::operand_range OpD::input1()
|
||||
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2;
|
||||
// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 0;
|
||||
// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
|
||||
|
||||
// CHECK-LABEL: Value *OpD::input2()
|
||||
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2;
|
||||
// CHECK-NEXT: unsigned offset = 0 + variadicOperandSize * 1;
|
||||
// CHECK-NEXT: return this->getOperand(offset);
|
||||
|
||||
// CHECK-LABEL: Operation::operand_range OpD::input3()
|
||||
// CHECK-NEXT: unsigned variadicOperandSize = (this->getNumOperands() - 1) / 2;
|
||||
// CHECK-NEXT: unsigned offset = 1 + variadicOperandSize * 1;
|
||||
// CHECK-NEXT: return {std::next(operand_begin(), offset), std::next(operand_begin(), offset + variadicOperandSize)};
|
||||
|
||||
// CHECK-LABEL: OpD::build
|
||||
// CHECK-NEXT: tblgen_state->addOperands(input1);
|
||||
// CHECK-NEXT: tblgen_state->operands.push_back(input2);
|
||||
// CHECK-NEXT: tblgen_state->addOperands(input3);
|
||||
|
||||
def OpE : Op<"one_variadic_among_multi_normal_inputs_op", []> {
|
||||
let arguments = (ins Tensor:$input1, Tensor:$input2, Variadic<Tensor>:$input3, Tensor:$input4, Tensor:$input5);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: Value *OpE::input1()
|
||||
// CHECK-NEXT: return this->getOperation()->getOperand(0);
|
||||
|
||||
// CHECK-LABEL: Value *OpE::input2()
|
||||
// CHECK-NEXT: return this->getOperation()->getOperand(1);
|
||||
|
||||
// CHECK-LABEL: Operation::operand_range OpE::input3()
|
||||
// CHECK-NEXT: return {std::next(operand_begin(), 2), std::next(operand_begin(), 2 + this->getNumOperands() - 4)};
|
||||
|
||||
// CHECK-LABEL: Value *OpE::input4()
|
||||
// CHECK-NEXT: return this->getOperation()->getOperand(this->getNumOperands() - 2);
|
||||
|
||||
// CHECK-LABEL: Value *OpE::input5()
|
||||
// CHECK-NEXT: return this->getOperation()->getOperand(this->getNumOperands() - 1);
|
||||
|
|
|
@ -2,82 +2,160 @@
|
|||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
def OneResultOp : Op<"one_result_op", []> {
|
||||
def OpA : Op<"one_normal_result_op", []> {
|
||||
let results = (outs I32:$result);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: OneResultOp definitions
|
||||
// CHECK-LABEL: Value *OpA::result()
|
||||
// CHECK-NEXT: return this->getOperation()->getResult(0)
|
||||
|
||||
// CHECK: void OneResultOp::build
|
||||
// CHECK-LABEL: void OpA::build
|
||||
// CHECK: ArrayRef<Type> resultTypes, ArrayRef<Value *> operands
|
||||
// CHECK: assert(resultTypes.size() == 1u && "mismatched number of return types");
|
||||
// CHECK-NEXT: tblgen_state->addTypes(resultTypes);
|
||||
|
||||
// CHECK: LogicalResult OneResultOp::verify() {
|
||||
// CHECK-LABEL: LogicalResult OpA::verify()
|
||||
// CHECK: if (!((this->getOperation()->getResult(0)->getType().isInteger(32))))
|
||||
// CHECK-NEXT: return emitOpError("result #0 must be 32-bit integer");
|
||||
|
||||
|
||||
def SameTypeOp : Op<"same_type_op", [SameValueType]> {
|
||||
def OpB : Op<"same_input_output_type_op", [SameValueType]> {
|
||||
let arguments = (ins I32:$x);
|
||||
let results = (outs I32:$y);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: SameTypeOp definitions
|
||||
// CHECK: void SameTypeOp::build(Builder *, OperationState *tblgen_state, Type y, Value *x)
|
||||
// CHECK: tblgen_state->addTypes({y});
|
||||
// CHECK: void SameTypeOp::build(Builder *, OperationState *tblgen_state, Value *x)
|
||||
// CHECK-LABEL: OpB definitions
|
||||
// CHECK: void OpB::build(Builder *, OperationState *tblgen_state, Type y, Value *x)
|
||||
// CHECK: tblgen_state->types.push_back(y);
|
||||
// CHECK: void OpB::build(Builder *, OperationState *tblgen_state, Value *x)
|
||||
// CHECK: tblgen_state->addTypes({x->getType()});
|
||||
|
||||
def ThreeResultOp : Op<"three_result_op", []> {
|
||||
def OpC : Op<"three_normal_result_op", []> {
|
||||
let results = (outs I32:$x, /*unnamed*/I32, I32:$z);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ThreeResultOp definitions
|
||||
// CHECK: void ThreeResultOp::build(Builder *, OperationState *tblgen_state, Type x, Type resultType1, Type z)
|
||||
// CHECK: tblgen_state->addTypes({x, resultType1, z});
|
||||
// CHECK-LABEL: OpC definitions
|
||||
// CHECK: void OpC::build(Builder *, OperationState *tblgen_state, Type x, Type resultType1, Type z)
|
||||
// CHECK-NEXT: tblgen_state->types.push_back(x)
|
||||
// CHECK-NEXT: tblgen_state->types.push_back(resultType1)
|
||||
// CHECK-NEXT: tblgen_state->types.push_back(z)
|
||||
|
||||
def IntegerTypeAttr : TypeAttrBase<"IntegerType", "Integer type attribute">;
|
||||
def TypeAttrResultTypeOp : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
|
||||
def OpD : Op<"type_attr_as_result_type", [FirstAttrDerivedResultType]> {
|
||||
let arguments = (ins I32:$x, IntegerTypeAttr:$attr, F32Attr:$f32);
|
||||
let results = (outs Tensor:$y);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: TypeAttrResultTypeOp definitions
|
||||
// CHECK: void TypeAttrResultTypeOp::build(Builder *, OperationState *tblgen_state, Value *x, TypeAttr attr, FloatAttr f32)
|
||||
// CHECK-LABEL: OpD definitions
|
||||
// CHECK: void OpD::build(Builder *, OperationState *tblgen_state, Value *x, TypeAttr attr, FloatAttr f32)
|
||||
// CHECK: tblgen_state->addTypes({attr.getValue()});
|
||||
|
||||
def ValueAttrResultTypeOp : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
|
||||
def OpE : Op<"value_attr_as_result_type", [FirstAttrDerivedResultType]> {
|
||||
let arguments = (ins I32:$x, F32Attr:$attr);
|
||||
let results = (outs Tensor:$y);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: ValueAttrResultTypeOp definitions
|
||||
// CHECK: void ValueAttrResultTypeOp::build(Builder *, OperationState *tblgen_state, Value *x, FloatAttr attr)
|
||||
// CHECK-LABEL: OpE definitions
|
||||
// CHECK: void OpE::build(Builder *, OperationState *tblgen_state, Value *x, FloatAttr attr)
|
||||
// CHECK: tblgen_state->addTypes({attr.getType()});
|
||||
|
||||
def VariadicResultAloneOp : Op<"variadic_alone_op", []> {
|
||||
def OpF : Op<"one_variadic_result_op", []> {
|
||||
let results = (outs Variadic<I32>:$x);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: VariadicResultAloneOp definitions
|
||||
// CHECK-LABEL: Operation::result_range OpF::x()
|
||||
// CHECK-NEXT: return {std::next(result_begin(), 0), std::next(result_begin(), 0 + this->getNumResults() - 0)};
|
||||
|
||||
// CHECK-LABEL: void VariadicResultAloneOp::build
|
||||
// CHECK-SAME: ArrayRef<Type> x
|
||||
// CHECK-NOT: assert
|
||||
// CHECK: tblgen_state->addTypes(x);
|
||||
// CHECK-LABEL: void OpF::build
|
||||
// CHECK-SAME: ArrayRef<Type> x
|
||||
// CHECK-NOT: assert
|
||||
// CHECK: tblgen_state->addTypes(x);
|
||||
|
||||
def OpG : Op<"one_normal_and_one_variadic_result_op", []> {
|
||||
|
||||
def VariadicResultOp : Op<"variadic_op", []> {
|
||||
let results = (outs I32:$x, Variadic<I32>:$y);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: VariadicResultOp definitions
|
||||
// CHECK-LABEL: OpG definitions
|
||||
|
||||
// CHECK: void VariadicResultOp::build(Builder *, OperationState *tblgen_state, Type x, ArrayRef<Type> y)
|
||||
// CHECK: tblgen_state->addTypes({x});
|
||||
// CHECK: tblgen_state->addTypes(y);
|
||||
// CHECK: void OpG::build(Builder *, OperationState *tblgen_state, Type x, ArrayRef<Type> y)
|
||||
// CHECK-NEXT: tblgen_state->types.push_back(x);
|
||||
// CHECK-NEXT: tblgen_state->addTypes(y);
|
||||
|
||||
// CHECK: void VariadicResultOp::build
|
||||
// CHECK: void OpG::build
|
||||
// CHECK: ArrayRef<Type> resultTypes
|
||||
// CHECK: assert(resultTypes.size() >= 1u && "mismatched number of return types");
|
||||
// CHECK-NEXT: tblgen_state->addTypes(resultTypes);
|
||||
|
||||
|
||||
def OpH : Op<"all_variadic_results_op", [SameVariadicResultSize]> {
|
||||
let results = (outs Variadic<Tensor>:$output1, Variadic<Tensor>:$output2);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: Operation::result_range OpH::output1()
|
||||
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 0) / 2;
|
||||
// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 0;
|
||||
// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
|
||||
|
||||
// CHECK-LABEL: Operation::result_range OpH::output2()
|
||||
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 0) / 2;
|
||||
// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 1;
|
||||
// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
|
||||
|
||||
|
||||
// CHECK-LABEL: OpH::build
|
||||
// CHECK-NEXT: tblgen_state->addTypes(output1);
|
||||
// CHECK-NEXT: tblgen_state->addTypes(output2);
|
||||
|
||||
def OpI : Op<"mix_variadic_and_normal_results_op", [SameVariadicResultSize]> {
|
||||
let results = (outs Variadic<Tensor>:$output1, Tensor:$output2, Variadic<Tensor>:$output3);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: Operation::result_range OpI::output1()
|
||||
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2;
|
||||
// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 0;
|
||||
// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
|
||||
|
||||
// CHECK-LABEL: Value *OpI::output2()
|
||||
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2;
|
||||
// CHECK-NEXT: unsigned offset = 0 + variadicResultSize * 1;
|
||||
// CHECK-NEXT: return this->getResult(offset);
|
||||
|
||||
// CHECK-LABEL: Operation::result_range OpI::output3()
|
||||
// CHECK-NEXT: unsigned variadicResultSize = (this->getNumResults() - 1) / 2;
|
||||
// CHECK-NEXT: unsigned offset = 1 + variadicResultSize * 1;
|
||||
// CHECK-NEXT: return {std::next(result_begin(), offset), std::next(result_begin(), offset + variadicResultSize)};
|
||||
|
||||
// CHECK-LABEL: OpI::build
|
||||
// CHECK-NEXT: tblgen_state->addTypes(output1);
|
||||
// CHECK-NEXT: tblgen_state->types.push_back(output2);
|
||||
// CHECK-NEXT: tblgen_state->addTypes(output3);
|
||||
|
||||
def OpJ : Op<"one_variadic_among_multi_normal_results_op", []> {
|
||||
let results = (outs Tensor:$output1, Tensor:$output2, Variadic<Tensor>:$output3, Tensor:$output4, Tensor:$output5);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: Value *OpJ::output1()
|
||||
// CHECK-NEXT: return this->getOperation()->getResult(0);
|
||||
|
||||
// CHECK-LABEL: Value *OpJ::output2()
|
||||
// CHECK-NEXT: return this->getOperation()->getResult(1);
|
||||
|
||||
// CHECK-LABEL: Operation::result_range OpJ::output3()
|
||||
// CHECK-NEXT: return {std::next(result_begin(), 2), std::next(result_begin(), 2 + this->getNumResults() - 4)};
|
||||
|
||||
// CHECK-LABEL: Value *OpJ::output4()
|
||||
// CHECK-NEXT: return this->getOperation()->getResult(this->getNumResults() - 2);
|
||||
|
||||
// CHECK-LABEL: Value *OpJ::output5()
|
||||
// CHECK-NEXT: return this->getOperation()->getResult(this->getNumResults() - 1);
|
||||
|
||||
// Test that if the only operand is variadic, we acess the first value in the
|
||||
// pack to set result type
|
||||
// ---
|
||||
def OpK : Op<"only_input_is_variadic_with_same_value_type_op", [SameValueType]> {
|
||||
let arguments = (ins Variadic<Tensor>:$input);
|
||||
let results = (outs Tensor:$result);
|
||||
}
|
||||
|
||||
// CHECK-LABEL: OpK::build(Builder *, OperationState *tblgen_state, ArrayRef<Value *> input)
|
||||
// CHECK: tblgen_state->addTypes({input.front()->getType()});
|
||||
|
|
|
@ -75,10 +75,14 @@ static StringLoc findNextVariable(StringRef str) {
|
|||
return {startPos, endPos - startPos};
|
||||
}
|
||||
|
||||
// Check if `name` is the name of the variadic argument of `op`. The variadic
|
||||
// argument can only appear at the last position in the list of arguments.
|
||||
static bool isVariadicArgumentName(const tblgen::Operator &op, StringRef name) {
|
||||
return op.hasVariadicOperand() && op.getArgName(op.getNumArgs() - 1) == name;
|
||||
// Check if `name` is the name of the variadic operand of `op`. The variadic
|
||||
// operand can only appear at the last position in the list of operands.
|
||||
static bool isVariadicOperandName(const tblgen::Operator &op, StringRef name) {
|
||||
unsigned numOperands = op.getNumOperands();
|
||||
if (numOperands == 0)
|
||||
return false;
|
||||
const auto &operand = op.getOperand(numOperands - 1);
|
||||
return operand.isVariadic() && operand.name == name;
|
||||
}
|
||||
|
||||
// Check if `result` is a known name of a result of `op`.
|
||||
|
@ -127,9 +131,9 @@ static bool emitOneBuilder(const Record &record, raw_ostream &os) {
|
|||
// First, insert the non-matched part as is.
|
||||
bs << builderStrRef.substr(0, loc.pos);
|
||||
// Then, rewrite the name based on its kind.
|
||||
bool isVariadicArg = isVariadicArgumentName(op, name);
|
||||
bool isVariadicOperand = isVariadicOperandName(op, name);
|
||||
if (isOperandName(op, name)) {
|
||||
auto result = isVariadicArg
|
||||
auto result = isVariadicOperand
|
||||
? formatv("lookupValues(op.{0}())", name)
|
||||
: formatv("valueMapping.lookup(op.{0}())", name);
|
||||
bs << result;
|
||||
|
|
|
@ -251,8 +251,9 @@ OpMethodBody &OpMethodBody::operator<<(const FmtObjectBase &content) {
|
|||
}
|
||||
|
||||
void OpMethodBody::writeTo(raw_ostream &os) const {
|
||||
os << body;
|
||||
if (body.empty() || body.back() != '\n')
|
||||
auto bodyRef = StringRef(body).drop_while([](char c) { return c == '\n'; });
|
||||
os << bodyRef;
|
||||
if (bodyRef.empty() || bodyRef.back() != '\n')
|
||||
os << "\n";
|
||||
}
|
||||
|
||||
|
@ -455,35 +456,153 @@ void OpEmitter::genAttrGetters() {
|
|||
}
|
||||
|
||||
void OpEmitter::genNamedOperandGetters() {
|
||||
for (int i = 0, e = op.getNumOperands(); i != e; ++i) {
|
||||
const unsigned numOperands = op.getNumOperands();
|
||||
const unsigned numVariadicOperands = op.getNumVariadicOperands();
|
||||
const unsigned numNormalOperands = numOperands - numVariadicOperands;
|
||||
|
||||
// Special case for ops without variadic operands: the i-th value is for the
|
||||
// i-th operand defined in the op.
|
||||
// Special case for ops with one variadic operand: the variadic operand can
|
||||
// appear at any place, so the i-th value may not necessarily belong to the
|
||||
// i-th operand definition. we need to calculate the index (range) for each
|
||||
// operand.
|
||||
if (numVariadicOperands <= 1) {
|
||||
bool emittedVariadicOperand = false;
|
||||
for (unsigned i = 0; i != numOperands; ++i) {
|
||||
const auto &operand = op.getOperand(i);
|
||||
if (operand.name.empty())
|
||||
continue;
|
||||
|
||||
if (operand.isVariadic()) {
|
||||
auto &m = opClass.newMethod("Operation::operand_range", operand.name);
|
||||
m.body() << formatv(
|
||||
" return {{std::next(operand_begin(), {0}), "
|
||||
"std::next(operand_begin(), {0} + this->getNumOperands() - {1})};",
|
||||
i, numNormalOperands);
|
||||
emittedVariadicOperand = true;
|
||||
} else {
|
||||
auto &m = opClass.newMethod("Value *", operand.name);
|
||||
m.body() << " return this->getOperation()->getOperand(";
|
||||
if (emittedVariadicOperand)
|
||||
m.body() << "this->getNumOperands() - " << numOperands - i;
|
||||
else
|
||||
m.body() << i;
|
||||
m.body() << ");\n";
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// If we have more than one variadic operands, we need more complicated logic
|
||||
// to calculate the value range for each operand.
|
||||
|
||||
if (!op.hasTrait("SameVariadicOperandSize")) {
|
||||
PrintFatalError(op.getLoc(), "op has multiple variadic operands but no "
|
||||
"specification over their sizes");
|
||||
}
|
||||
|
||||
unsigned emittedNormalOperands = 0;
|
||||
unsigned emittedVariadicOperands = 0;
|
||||
|
||||
for (unsigned i = 0; i != numOperands; ++i) {
|
||||
const auto &operand = op.getOperand(i);
|
||||
if (operand.name.empty())
|
||||
continue;
|
||||
|
||||
if (!operand.constraint.isVariadic()) {
|
||||
auto &m = opClass.newMethod("Value *", operand.name);
|
||||
m.body() << " return this->getOperation()->getOperand(" << i << ");\n";
|
||||
} else {
|
||||
assert(i + 1 == e && "only the last operand can be variadic");
|
||||
const char *code = R"(
|
||||
unsigned variadicOperandSize = (this->getNumOperands() - {0}) / {1};
|
||||
unsigned offset = {2} + variadicOperandSize * {3};
|
||||
return )";
|
||||
auto sizeAndOffset =
|
||||
formatv(code, numNormalOperands, numVariadicOperands,
|
||||
emittedNormalOperands, emittedVariadicOperands);
|
||||
|
||||
const char *const code = R"(
|
||||
assert(getOperation()->getNumOperands() >= {0});
|
||||
return {std::next(operand_begin(), {0}), operand_end()};
|
||||
)";
|
||||
if (operand.isVariadic()) {
|
||||
auto &m = opClass.newMethod("Operation::operand_range", operand.name);
|
||||
m.body() << formatv(code, i);
|
||||
m.body() << sizeAndOffset
|
||||
<< "{std::next(operand_begin(), offset), "
|
||||
"std::next(operand_begin(), offset + variadicOperandSize)};";
|
||||
++emittedVariadicOperands;
|
||||
} else {
|
||||
auto &m = opClass.newMethod("Value *", operand.name);
|
||||
m.body() << sizeAndOffset << "this->getOperand(offset);";
|
||||
++emittedNormalOperands;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genNamedResultGetters() {
|
||||
for (int i = 0, e = op.getNumResults(); i != e; ++i) {
|
||||
const unsigned numResults = op.getNumResults();
|
||||
const unsigned numVariadicResults = op.getNumVariadicResults();
|
||||
const unsigned numNormalResults = numResults - numVariadicResults;
|
||||
|
||||
// Special case for ops without variadic results: the i-th value is for the
|
||||
// i-th result defined in the op.
|
||||
// Special case for ops with one variadic result: the variadic result can
|
||||
// appear at any place, so the i-th value may not necessarily belong to the
|
||||
// i-th result definition. we need to calculate the index (range) for each
|
||||
// result.
|
||||
if (numVariadicResults <= 1) {
|
||||
bool emittedVariadicResult = false;
|
||||
for (unsigned i = 0; i != numResults; ++i) {
|
||||
const auto &result = op.getResult(i);
|
||||
if (result.name.empty())
|
||||
continue;
|
||||
|
||||
if (result.isVariadic()) {
|
||||
auto &m = opClass.newMethod("Operation::result_range", result.name);
|
||||
m.body() << formatv(
|
||||
" return {{std::next(result_begin(), {0}), "
|
||||
"std::next(result_begin(), {0} + this->getNumResults() - {1})};",
|
||||
i, numNormalResults);
|
||||
emittedVariadicResult = true;
|
||||
} else {
|
||||
auto &m = opClass.newMethod("Value *", result.name);
|
||||
m.body() << " return this->getOperation()->getResult(";
|
||||
if (emittedVariadicResult)
|
||||
m.body() << "this->getNumResults() - " << numResults - i;
|
||||
else
|
||||
m.body() << i;
|
||||
m.body() << ");\n";
|
||||
}
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
// If we have more than one variadic results, we need more complicated logic
|
||||
// to calculate the value range for each result.
|
||||
|
||||
if (!op.hasTrait("SameVariadicResultSize")) {
|
||||
PrintFatalError(op.getLoc(), "op has multiple variadic results but no "
|
||||
"specification over their sizes");
|
||||
}
|
||||
|
||||
unsigned emittedNormalResults = 0;
|
||||
unsigned emittedVariadicResults = 0;
|
||||
|
||||
for (unsigned i = 0; i != numResults; ++i) {
|
||||
const auto &result = op.getResult(i);
|
||||
if (result.constraint.isVariadic() || result.name.empty())
|
||||
if (result.name.empty())
|
||||
continue;
|
||||
|
||||
auto &m = opClass.newMethod("Value *", result.name);
|
||||
m.body() << " return this->getOperation()->getResult(" << i << ");\n";
|
||||
const char *code = R"(
|
||||
unsigned variadicResultSize = (this->getNumResults() - {0}) / {1};
|
||||
unsigned offset = {2} + variadicResultSize * {3};
|
||||
return )";
|
||||
auto sizeAndOffset = formatv(code, numNormalResults, numVariadicResults,
|
||||
emittedNormalResults, emittedVariadicResults);
|
||||
|
||||
if (result.isVariadic()) {
|
||||
auto &m = opClass.newMethod("Operation::result_range", result.name);
|
||||
m.body() << sizeAndOffset
|
||||
<< "{std::next(result_begin(), offset), "
|
||||
"std::next(result_begin(), offset + variadicResultSize)};";
|
||||
++emittedVariadicResults;
|
||||
} else {
|
||||
auto &m = opClass.newMethod("Value *", result.name);
|
||||
m.body() << sizeAndOffset << "this->getResult(offset);";
|
||||
++emittedNormalResults;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -505,12 +624,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
|
|||
// Emit parameters for all return types
|
||||
if (!useOperandType && !useAttrType) {
|
||||
for (unsigned i = 0; i != numResults; ++i) {
|
||||
std::string resultName = op.getResultName(i);
|
||||
const auto &result = op.getResult(i);
|
||||
std::string resultName = result.name;
|
||||
if (resultName.empty())
|
||||
resultName = formatv("resultType{0}", i);
|
||||
|
||||
bool isVariadic = op.getResultTypeConstraint(i).isVariadic();
|
||||
paramList.append(isVariadic ? ", ArrayRef<Type> " : ", Type ");
|
||||
paramList.append(result.isVariadic() ? ", ArrayRef<Type> " : ", Type ");
|
||||
paramList.append(resultName);
|
||||
|
||||
resultNames.emplace_back(std::move(resultName));
|
||||
|
@ -520,12 +639,13 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
|
|||
// Emit parameters for all arguments (operands and attributes).
|
||||
int numOperands = 0;
|
||||
int numAttrs = 0;
|
||||
|
||||
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
|
||||
auto argument = op.getArg(i);
|
||||
if (argument.is<tblgen::NamedTypeConstraint *>()) {
|
||||
auto &operand = op.getOperand(numOperands);
|
||||
paramList.append(operand.constraint.isVariadic() ? ", ArrayRef<Value *> "
|
||||
: ", Value *");
|
||||
const auto &operand = op.getOperand(numOperands);
|
||||
paramList.append(operand.isVariadic() ? ", ArrayRef<Value *> "
|
||||
: ", Value *");
|
||||
paramList.append(getArgumentName(op, numOperands));
|
||||
++numOperands;
|
||||
} else {
|
||||
|
@ -542,33 +662,22 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
|
|||
}
|
||||
|
||||
if (numOperands + numAttrs != op.getNumArgs())
|
||||
return PrintFatalError(
|
||||
"op arguments must be either operands or attributes");
|
||||
PrintFatalError("op arguments must be either operands or attributes");
|
||||
|
||||
auto &method =
|
||||
opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
|
||||
|
||||
bool hasVariadicOperand = op.hasVariadicOperand();
|
||||
auto &m = opClass.newMethod("void", "build", paramList, OpMethod::MP_Static);
|
||||
|
||||
// Push all result types to the result
|
||||
if (numResults > 0) {
|
||||
if (!useOperandType && !useAttrType) {
|
||||
bool hasVariadicResult = op.hasVariadicResult();
|
||||
int numNonVariadicResults =
|
||||
numResults - static_cast<int>(hasVariadicResult);
|
||||
|
||||
if (numNonVariadicResults > 0) {
|
||||
method.body() << " " << builderOpState << "->addTypes({"
|
||||
<< resultNames.front();
|
||||
for (int i = 1; i < numNonVariadicResults; ++i) {
|
||||
method.body() << ", " << resultNames[i];
|
||||
for (unsigned i = 0; i < numResults; ++i) {
|
||||
const auto &result = op.getResult(i);
|
||||
m.body() << " " << builderOpState;
|
||||
if (result.isVariadic()) {
|
||||
m.body() << "->addTypes(";
|
||||
} else {
|
||||
m.body() << "->types.push_back(";
|
||||
}
|
||||
method.body() << "});\n";
|
||||
}
|
||||
|
||||
if (hasVariadicResult) {
|
||||
method.body() << " " << builderOpState << "->addTypes("
|
||||
<< resultNames.back() << ");\n";
|
||||
m.body() << resultNames[i] << ");\n";
|
||||
}
|
||||
} else {
|
||||
std::string resultType;
|
||||
|
@ -580,32 +689,27 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
|
|||
resultType = formatv("{0}.getType()", namedAttr.name);
|
||||
}
|
||||
} else {
|
||||
const char *index =
|
||||
(numOperands == 1 && hasVariadicOperand) ? ".front()" : "";
|
||||
const char *index = op.getOperand(0).isVariadic() ? ".front()" : "";
|
||||
resultType =
|
||||
formatv("{0}{1}->getType()", getArgumentName(op, 0), index).str();
|
||||
}
|
||||
method.body() << " " << builderOpState << "->addTypes({" << resultType;
|
||||
m.body() << " " << builderOpState << "->addTypes({" << resultType;
|
||||
for (unsigned i = 1; i != numResults; ++i)
|
||||
method.body() << ", " << resultType;
|
||||
method.body() << "});\n\n";
|
||||
m.body() << ", " << resultType;
|
||||
m.body() << "});\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
// Push all operands to the result
|
||||
int numNonVariadicOperands =
|
||||
numOperands - static_cast<int>(hasVariadicOperand);
|
||||
if (numNonVariadicOperands > 0) {
|
||||
method.body() << " " << builderOpState << "->addOperands({"
|
||||
<< getArgumentName(op, 0);
|
||||
for (int i = 1; i < numNonVariadicOperands; ++i) {
|
||||
method.body() << ", " << getArgumentName(op, i);
|
||||
for (unsigned i = 0; i < numOperands; ++i) {
|
||||
const auto &operand = op.getOperand(i);
|
||||
m.body() << " " << builderOpState;
|
||||
if (operand.isVariadic()) {
|
||||
m.body() << "->addOperands(";
|
||||
} else {
|
||||
m.body() << "->operands.push_back(";
|
||||
}
|
||||
method.body() << "});\n";
|
||||
}
|
||||
if (hasVariadicOperand) {
|
||||
method.body() << " " << builderOpState << "->addOperands("
|
||||
<< getArgumentName(op, numOperands - 1) << ");\n";
|
||||
m.body() << getArgumentName(op, i) << ");\n";
|
||||
}
|
||||
|
||||
// Push all attributes to the result
|
||||
|
@ -613,12 +717,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
|
|||
if (!namedAttr.attr.isDerivedAttr()) {
|
||||
bool emitNotNullCheck = namedAttr.attr.isOptional();
|
||||
if (emitNotNullCheck) {
|
||||
method.body() << formatv(" if ({0}) ", namedAttr.name) << "{\n";
|
||||
m.body() << formatv(" if ({0}) ", namedAttr.name) << "{\n";
|
||||
}
|
||||
method.body() << formatv(" {0}->addAttribute(\"{1}\", {1});\n",
|
||||
builderOpState, namedAttr.name);
|
||||
m.body() << formatv(" {0}->addAttribute(\"{1}\", {1});\n",
|
||||
builderOpState, namedAttr.name);
|
||||
if (emitNotNullCheck) {
|
||||
method.body() << " }\n";
|
||||
m.body() << " }\n";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -646,13 +750,13 @@ void OpEmitter::genBuilder() {
|
|||
}
|
||||
}
|
||||
|
||||
auto numResults = op.getNumResults();
|
||||
bool hasVariadicResult = op.hasVariadicResult();
|
||||
int numNonVariadicResults = numResults - int(hasVariadicResult);
|
||||
unsigned numResults = op.getNumResults();
|
||||
unsigned numVariadicResults = op.getNumVariadicResults();
|
||||
unsigned numNonVariadicResults = numResults - numVariadicResults;
|
||||
|
||||
auto numOperands = op.getNumOperands();
|
||||
bool hasVariadicOperand = op.hasVariadicOperand();
|
||||
int numNonVariadicOperands = numOperands - int(hasVariadicOperand);
|
||||
unsigned numOperands = op.getNumOperands();
|
||||
unsigned numVariadicOperands = op.getNumVariadicOperands();
|
||||
unsigned numNonVariadicOperands = numOperands - numVariadicOperands;
|
||||
|
||||
// Generate default builders that requires all result type, operands, and
|
||||
// attributes as parameters.
|
||||
|
@ -681,15 +785,16 @@ void OpEmitter::genBuilder() {
|
|||
auto &body = m.body();
|
||||
|
||||
// Result types
|
||||
if (!(hasVariadicResult && numNonVariadicResults == 0))
|
||||
if (numVariadicResults == 0 || numNonVariadicResults != 0)
|
||||
body << " assert(resultTypes.size()"
|
||||
<< (hasVariadicResult ? " >= " : " == ") << numNonVariadicResults
|
||||
<< (numVariadicResults != 0 ? " >= " : " == ") << numNonVariadicResults
|
||||
<< "u && \"mismatched number of return types\");\n";
|
||||
body << " " << builderOpState << "->addTypes(resultTypes);\n";
|
||||
|
||||
// Operands
|
||||
if (!(hasVariadicOperand && numNonVariadicOperands == 0))
|
||||
body << " assert(operands.size()" << (hasVariadicOperand ? " >= " : " == ")
|
||||
if (numVariadicOperands == 0 || numNonVariadicOperands != 0)
|
||||
body << " assert(operands.size()"
|
||||
<< (numVariadicOperands != 0 ? " >= " : " == ")
|
||||
<< numNonVariadicOperands
|
||||
<< "u && \"mismatched number of parameters\");\n";
|
||||
body << " " << builderOpState << "->addOperands(operands);\n\n";
|
||||
|
@ -703,7 +808,7 @@ void OpEmitter::genBuilder() {
|
|||
|
||||
bool useOperandType = op.hasTrait("SameOperandsAndResultType");
|
||||
bool useAttrType = op.hasTrait("FirstAttrDerivedResultType");
|
||||
if (!op.hasVariadicResult() && (useOperandType || useAttrType))
|
||||
if (numVariadicResults == 0 && (useOperandType || useAttrType))
|
||||
genStandaloneParamBuilder(useOperandType, useAttrType);
|
||||
}
|
||||
|
||||
|
@ -824,7 +929,7 @@ void OpEmitter::genVerifier() {
|
|||
auto verifyValue = [&](const tblgen::NamedTypeConstraint &value, int index,
|
||||
bool isOperand) -> void {
|
||||
// TODO: Handle variadic operand/result verification.
|
||||
if (value.constraint.isVariadic())
|
||||
if (value.isVariadic())
|
||||
return;
|
||||
|
||||
// TODO: Commonality between matchers could be extracted to have a more
|
||||
|
@ -869,12 +974,12 @@ void OpEmitter::genVerifier() {
|
|||
}
|
||||
|
||||
void OpEmitter::genTraits() {
|
||||
auto numResults = op.getNumResults();
|
||||
bool hasVariadicResult = op.hasVariadicResult();
|
||||
unsigned numResults = op.getNumResults();
|
||||
unsigned numVariadicResults = op.getNumVariadicResults();
|
||||
|
||||
// Add return size trait.
|
||||
if (hasVariadicResult) {
|
||||
if (numResults == 1)
|
||||
if (numVariadicResults != 0) {
|
||||
if (numResults == numVariadicResults)
|
||||
opClass.addTrait("VariadicResults");
|
||||
else
|
||||
opClass.addTrait("AtLeastNResults<" + Twine(numResults - 1) + ">::Impl");
|
||||
|
@ -898,12 +1003,12 @@ void OpEmitter::genTraits() {
|
|||
}
|
||||
|
||||
// Add variadic size trait and normal op traits.
|
||||
auto numOperands = op.getNumOperands();
|
||||
bool hasVariadicOperand = op.hasVariadicOperand();
|
||||
unsigned numOperands = op.getNumOperands();
|
||||
unsigned numVariadicOperands = op.getNumVariadicOperands();
|
||||
|
||||
// Add operand size trait.
|
||||
if (hasVariadicOperand) {
|
||||
if (numOperands == 1)
|
||||
if (numVariadicOperands != 0) {
|
||||
if (numOperands == numVariadicOperands)
|
||||
opClass.addTrait("VariadicOperands");
|
||||
else
|
||||
opClass.addTrait("AtLeastNOperands<" + Twine(numOperands - 1) +
|
||||
|
|
|
@ -440,7 +440,7 @@ void PatternEmitter::emit(StringRef rewriteName) {
|
|||
const Operator &rootOp = pattern.getSourceRootOp();
|
||||
auto rootName = rootOp.getOperationName();
|
||||
|
||||
if (rootOp.hasVariadicResult())
|
||||
if (rootOp.getNumVariadicResults() != 0)
|
||||
PrintFatalError(
|
||||
loc, "replacing op with variadic results not supported right now");
|
||||
|
||||
|
|
Loading…
Reference in New Issue