[mlir] Add support for VariadicOfVariadic operands

This revision adds native ODS support for VariadicOfVariadic operand
groups. An example of this is the SwitchOp, which has a variadic number
of nested operand ranges for each of the case statements, where the
number of case statements is variadic. Builtin ODS support allows for
generating proper accessors for the nested operand ranges, builder
support, and declarative format support. VariadicOfVariadic operands
are supported by providing a segment attribute to use to store the
operand groups, mapping similarly to the AttrSizedOperand trait
(but with a user defined attribute name).

`build` methods for VariadicOfVariadic operand expect inputs of the
form `ArrayRef<ValueRange>`. Accessors for the variadic ranges
return a new `OperandRangeRange` type, which represents a
contiguous range of `OperandRange`. In the declarative assembly
format, VariadicOfVariadic operands and types are by default
formatted as a comma delimited list of value lists:
`(<value>, <value>), (), (<value>)`.

Differential Revision: https://reviews.llvm.org/D107774
This commit is contained in:
River Riddle 2021-08-23 20:23:09 +00:00
parent ce4545db1d
commit 4e103a12d9
22 changed files with 716 additions and 311 deletions

View File

@ -229,6 +229,17 @@ the `SameVariadicOperandSize` or `AttrSizedOperandSegments` trait is needed to
indicate that all variable length operands have the same number of dynamic
values.
#### VariadicOfVariadic operands
To declare a variadic operand that has a variadic number of sub-ranges, wrap the
`TypeConstraint` for the operand with `VariadicOfVariadic<...,
"<segment-attribute-name>">`.
The second field of the `VariadicOfVariadic` is the name of an `I32ElementsAttr`
argument that contains the sizes of the variadic sub-ranges. This attribute will
be used when determining the size of sub-ranges, or when updating the size of
sub-ranges.
#### Optional operands
To declare an optional operand, wrap the `TypeConstraint` for the operand with
@ -717,6 +728,8 @@ declarative parameter to `parse` method argument is detailed below:
- Single: `OpAsmParser::OperandType &`
- Optional: `Optional<OpAsmParser::OperandType> &`
- Variadic: `SmallVectorImpl<OpAsmParser::OperandType> &`
- VariadicOfVariadic:
`SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &`
* Ref Directives
- A reference directive is passed to the parser using the same mapping as
the input operand. For example, a single region would be passed as a
@ -731,6 +744,7 @@ declarative parameter to `parse` method argument is detailed below:
- Single: `Type &`
- Optional: `Type &`
- Variadic: `SmallVectorImpl<Type> &`
- VariadicOfVariadic: `SmallVectorImpl<SmallVector<Type>> &`
* `attr-dict` Directive: `NamedAttrList &`
When a variable is optional, the value should only be specified if the variable
@ -749,6 +763,7 @@ declarative parameter to `print` method argument is detailed below:
- Single: `Value`
- Optional: `Value`
- Variadic: `OperandRange`
- VariadicOfVariadic: `OperandRangeRange`
* Ref Directives
- A reference directive is passed to the printer using the same mapping as
the input operand. For example, a single region would be passed as a
@ -763,6 +778,7 @@ declarative parameter to `print` method argument is detailed below:
- Single: `Type`
- Optional: `Type`
- Variadic: `TypeRange`
- VariadicOfVariadic: `TypeRangeRange`
* `attr-dict` Directive: `DictionaryAttr`
When a variable is optional, the provided value may be null.
@ -923,7 +939,7 @@ be defined.
When this boolean field is set to `true`, it indicates that the op implements a
`canonicalize` method for simple "matchAndRewrite" style canonicalization
patterns. If `hasCanonicalizer` is 0, then an implementation of
patterns. If `hasCanonicalizer` is 0, then an implementation of
`::getCanonicalizationPatterns()` is implemented to call this function.
### `hasFolder`

View File

@ -701,23 +701,25 @@ def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
[AttrSizedOperandSegments, DeclareOpInterfaceMethods<BranchOpInterface>,
NoSideEffect]> {
let arguments = (ins I32:$value,
Variadic<AnyType>:$defaultOperands,
Variadic<AnyType>:$caseOperands,
OptionalAttr<ElementsAttr>:$case_values,
OptionalAttr<ElementsAttr>:$case_operand_offsets,
OptionalAttr<ElementsAttr>:$branch_weights);
let arguments = (ins
I32:$value,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
OptionalAttr<ElementsAttr>:$case_values,
ElementsAttr:$case_operand_segments,
OptionalAttr<ElementsAttr>:$branch_weights
);
let successors = (successor
AnySuccessor:$defaultDestination,
VariadicSuccessor<AnySuccessor>:$caseDestinations);
AnySuccessor:$defaultDestination,
VariadicSuccessor<AnySuccessor>:$caseDestinations
);
let verifier = [{ return ::verify(*this); }];
let assemblyFormat = [{
$value `,`
$defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
`[` `\n` custom<SwitchOpCases>($case_values, $caseDestinations,
$caseOperands, type($caseOperands),
$case_operand_offsets) `]`
$caseOperands, type($caseOperands)) `]`
attr-dict
}];
@ -734,11 +736,15 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
let extraClassDeclaration = [{
/// Return the operands for the case destination block at the given index.
OperandRange getCaseOperands(unsigned index);
OperandRange getCaseOperands(unsigned index) {
return caseOperands()[index];
}
/// Return a mutable range of operands for the case destination block at the
/// given index.
MutableOperandRange getCaseOperandsMutable(unsigned index);
MutableOperandRange getCaseOperandsMutable(unsigned index) {
return caseOperandsMutable()[index];
}
}];
}

View File

@ -1812,14 +1812,17 @@ def SwitchOp : Std_Op<"switch",
```
}];
let arguments = (ins AnyInteger:$flag,
Variadic<AnyType>:$defaultOperands,
Variadic<AnyType>:$caseOperands,
OptionalAttr<AnyIntElementsAttr>:$case_values,
OptionalAttr<I32ElementsAttr>:$case_operand_offsets);
let arguments = (ins
AnyInteger:$flag,
Variadic<AnyType>:$defaultOperands,
VariadicOfVariadic<AnyType, "case_operand_segments">:$caseOperands,
OptionalAttr<AnyIntElementsAttr>:$case_values,
I32ElementsAttr:$case_operand_segments
);
let successors = (successor
AnySuccessor:$defaultDestination,
VariadicSuccessor<AnySuccessor>:$caseDestinations);
AnySuccessor:$defaultDestination,
VariadicSuccessor<AnySuccessor>:$caseDestinations
);
let builders = [
OpBuilder<(ins "Value":$flag,
"Block *":$defaultDestination,
@ -1849,19 +1852,22 @@ def SwitchOp : Std_Op<"switch",
$case_values,
$caseDestinations,
$caseOperands,
type($caseOperands),
$case_operand_offsets)
type($caseOperands))
`]`
attr-dict
}];
let extraClassDeclaration = [{
/// Return the operands for the case destination block at the given index.
OperandRange getCaseOperands(unsigned index);
OperandRange getCaseOperands(unsigned index) {
return caseOperands()[index];
}
/// Return a mutable range of operands for the case destination block at the
/// given index.
MutableOperandRange getCaseOperandsMutable(unsigned index);
MutableOperandRange getCaseOperandsMutable(unsigned index) {
return caseOperandsMutable()[index];
}
}];
let hasCanonicalizer = 1;

View File

@ -324,6 +324,16 @@ class Variadic<Type type> : TypeConstraint<type.predicate, type.summary> {
Type baseType = type;
}
// A nested variadic type constraint. It expands to zero or more variadic ranges
// of the base type. This class is used for supporting variadic operands and
// results. `variadicSegmentAttrName` should correspond to the name of an
// I32ElementsAttr argument that provides the sizes of the inner variadic
// operand groups.
class VariadicOfVariadic<Type type, string variadicSegmentAttrName>
: Variadic<type> {
string segmentAttrName = variadicSegmentAttrName;
}
// An optional type constraint. It expands to either zero or one of the base
// type. This class is used for supporting optional operands/results.
class Optional<Type type> : TypeConstraint<type.predicate, type.summary> {

View File

@ -267,6 +267,9 @@ LogicalResult verifyZeroSuccessor(Operation *op);
LogicalResult verifyOneSuccessor(Operation *op);
LogicalResult verifyNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyAtLeastNSuccessors(Operation *op, unsigned numSuccessors);
LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
StringRef valueGroupName,
size_t expectedCount);
LogicalResult verifyOperandSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyResultSizeAttr(Operation *op, StringRef sizeAttrName);
LogicalResult verifyNoRegionArguments(Operation *op);

View File

@ -36,12 +36,14 @@ namespace mlir {
class Dialect;
class DictionaryAttr;
class ElementsAttr;
class MutableOperandRangeRange;
class Operation;
struct OperationState;
class OpAsmParser;
class OpAsmParserResult;
class OpAsmPrinter;
class OperandRange;
class OperandRangeRange;
class OpFoldResult;
class ParseResult;
class Pattern;
@ -727,6 +729,10 @@ public:
/// must not be empty.
unsigned getBeginOperandIndex() const;
/// Split this range into a set of contiguous subranges using the given
/// elements attribute, which contains the sizes of the sub ranges.
OperandRangeRange split(ElementsAttr segmentSizes) const;
private:
/// See `llvm::detail::indexed_accessor_range_base` for details.
static OpOperand *offset_base(OpOperand *object, ptrdiff_t index) {
@ -741,6 +747,42 @@ private:
friend RangeBaseT;
};
//===----------------------------------------------------------------------===//
// OperandRangeRange
/// This class represents a contiguous range of operand ranges, e.g. from a
/// VariadicOfVariadic operand group.
class OperandRangeRange final
: public llvm::indexed_accessor_range<
OperandRangeRange, std::pair<OpOperand *, Attribute>, OperandRange,
OperandRange, OperandRange> {
using OwnerT = std::pair<OpOperand *, Attribute>;
using RangeBaseT =
llvm::indexed_accessor_range<OperandRangeRange, OwnerT, OperandRange,
OperandRange, OperandRange>;
public:
using RangeBaseT::RangeBaseT;
/// Returns the range of types of the values within this range.
TypeRangeRange getTypes() const { return TypeRangeRange(*this); }
auto getType() const { return getTypes(); }
/// Construct a range given a parent set of operands, and an I32 elements
/// attribute containing the sizes of the sub ranges.
OperandRangeRange(OperandRange operands, Attribute operandSegments);
/// Flatten all of the sub ranges into a single contiguous operand range.
OperandRange join() const;
private:
/// See `llvm::indexed_accessor_range` for details.
static OperandRange dereference(const OwnerT &object, ptrdiff_t index);
/// Allow access to `dereference_iterator`.
friend RangeBaseT;
};
//===----------------------------------------------------------------------===//
// MutableOperandRange
@ -761,8 +803,9 @@ public:
MutableOperandRange(Operation *owner);
/// Slice this range into a sub range, with the additional operand segment.
MutableOperandRange slice(unsigned subStart, unsigned subLen,
Optional<OperandSegment> segment = llvm::None);
MutableOperandRange
slice(unsigned subStart, unsigned subLen,
Optional<OperandSegment> segment = llvm::None) const;
/// Append the given values to the range.
void append(ValueRange values);
@ -782,12 +825,19 @@ public:
/// Returns the current size of the range.
unsigned size() const { return length; }
/// Returns if the current range is empty.
bool empty() const { return size() == 0; }
/// Allow implicit conversion to an OperandRange.
operator OperandRange() const;
/// Returns the owning operation.
Operation *getOwner() const { return owner; }
/// Split this range into a set of contiguous subranges using the given
/// elements attribute, which contains the sizes of the sub ranges.
MutableOperandRangeRange split(NamedAttribute segmentSizes) const;
private:
/// Update the length of this range to the one provided.
void updateLength(unsigned newLength);
@ -801,7 +851,46 @@ private:
/// Optional set of operand segments that should be updated when mutating the
/// length of this range.
SmallVector<std::pair<unsigned, NamedAttribute>, 1> operandSegments;
SmallVector<OperandSegment, 1> operandSegments;
};
//===----------------------------------------------------------------------===//
// MutableOperandRangeRange
/// This class represents a contiguous range of mutable operand ranges, e.g.
/// from a VariadicOfVariadic operand group.
class MutableOperandRangeRange final
: public llvm::indexed_accessor_range<
MutableOperandRangeRange,
std::pair<MutableOperandRange, NamedAttribute>, MutableOperandRange,
MutableOperandRange, MutableOperandRange> {
using OwnerT = std::pair<MutableOperandRange, NamedAttribute>;
using RangeBaseT =
llvm::indexed_accessor_range<MutableOperandRangeRange, OwnerT,
MutableOperandRange, MutableOperandRange,
MutableOperandRange>;
public:
using RangeBaseT::RangeBaseT;
/// Construct a range given a parent set of operands, and an I32 tensor
/// elements attribute containing the sizes of the sub ranges.
MutableOperandRangeRange(const MutableOperandRange &operands,
NamedAttribute operandSegmentAttr);
/// Flatten all of the sub ranges into a single contiguous mutable operand
/// range.
MutableOperandRange join() const;
/// Allow implicit conversion to an OperandRangeRange.
operator OperandRangeRange() const;
private:
/// See `llvm::indexed_accessor_range` for details.
static MutableOperandRange dereference(const OwnerT &object, ptrdiff_t index);
/// Allow access to `dereference_iterator`.
friend RangeBaseT;
};
//===----------------------------------------------------------------------===//

View File

@ -16,6 +16,7 @@
#include "mlir/IR/Types.h"
#include "mlir/IR/Value.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/Sequence.h"
namespace mlir {
class OperandRange;
@ -88,6 +89,35 @@ inline raw_ostream &operator<<(raw_ostream &os, const TypeRange &types) {
return os;
}
//===----------------------------------------------------------------------===//
// TypeRangeRange
using TypeRangeRangeIterator =
llvm::mapped_iterator<llvm::iota_range<unsigned>::iterator,
std::function<TypeRange(unsigned)>>;
/// This class provides an abstraction for a range of TypeRange. This is useful
/// when accessing the types of a range of ranges, such as when using
/// OperandRangeRange.
class TypeRangeRange : public llvm::iterator_range<TypeRangeRangeIterator> {
public:
template <typename RangeT>
TypeRangeRange(const RangeT &range)
: TypeRangeRange(llvm::seq<unsigned>(0, range.size()), range) {}
private:
template <typename RangeT>
TypeRangeRange(llvm::iota_range<unsigned> sizeRange, const RangeT &range)
: llvm::iterator_range<TypeRangeRangeIterator>(
{sizeRange.begin(), getRangeFn(range)},
{sizeRange.end(), nullptr}) {}
template <typename RangeT>
static std::function<TypeRange(unsigned)> getRangeFn(const RangeT &range) {
return [=](unsigned index) -> TypeRange { return TypeRange(range[index]); };
}
};
//===----------------------------------------------------------------------===//
// ValueTypeRange

View File

@ -48,6 +48,8 @@ struct NamedTypeConstraint {
bool isOptional() const;
// Returns true if this operand/result is variadic.
bool isVariadic() const;
// Returns true if this operand/result is a variadic of a variadic constraint.
bool isVariadicOfVariadic() const;
// Returns true if this is a variable length type constraint. This is either
// variadic or optional.
bool isVariableLength() const { return isOptional() || isVariadic(); }

View File

@ -40,6 +40,13 @@ public:
// Returns true if this is a variadic type constraint.
bool isVariadic() const;
// Returns true if this is a nested variadic type constraint.
bool isVariadicOfVariadic() const;
// Return the segment size attribute used if this is a variadic of variadic
// constraint. Asserts isVariadicOfVariadic() is true.
StringRef getVariadicOfVariadicSegmentSizeAttr() const;
// Returns true if this is a variable length type constraint. This is either
// variadic or optional.
bool isVariableLength() const { return isOptional() || isVariadic(); }

View File

@ -520,7 +520,7 @@ public:
/*defaultOperands=*/ValueRange(),
/*caseValues=*/caseValues,
/*caseDestinations=*/caseDest,
/*caseOperands=*/ArrayRef<ValueRange>(),
/*caseOperands=*/ArrayRef<ValueRange>({ValueRange(), ValueRange()}),
/*branchWeights=*/ArrayRef<int32_t>());
return success();

View File

@ -32,6 +32,7 @@
#include "llvm/Support/SourceMgr.h"
#include <iostream>
#include <numeric>
using namespace mlir;
using namespace mlir::LLVM;
@ -235,41 +236,27 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
ArrayRef<int32_t> caseValues, BlockRange caseDestinations,
ArrayRef<ValueRange> caseOperands,
ArrayRef<int32_t> branchWeights) {
SmallVector<Value> flattenedCaseOperands;
SmallVector<int32_t> caseOperandOffsets;
int32_t offset = 0;
for (ValueRange operands : caseOperands) {
flattenedCaseOperands.append(operands.begin(), operands.end());
caseOperandOffsets.push_back(offset);
offset += operands.size();
}
ElementsAttr caseValuesAttr;
if (!caseValues.empty())
caseValuesAttr = builder.getI32VectorAttr(caseValues);
ElementsAttr caseOperandOffsetsAttr;
if (!caseOperandOffsets.empty())
caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
ElementsAttr weightsAttr;
if (!branchWeights.empty())
weightsAttr = builder.getI32VectorAttr(llvm::to_vector<4>(branchWeights));
build(builder, result, value, defaultOperands, flattenedCaseOperands,
caseValuesAttr, caseOperandOffsetsAttr, weightsAttr, defaultDestination,
caseDestinations);
build(builder, result, value, defaultOperands, caseOperands, caseValuesAttr,
weightsAttr, defaultDestination, caseDestinations);
}
/// <cases> ::= integer `:` bb-id (`(` ssa-use-and-type-list `)`)?
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )?
static ParseResult
parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues,
SmallVectorImpl<Block *> &caseDestinations,
SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
SmallVectorImpl<Type> &caseOperandTypes,
ElementsAttr &caseOperandOffsets) {
static ParseResult parseSwitchOpCases(
OpAsmParser &parser, ElementsAttr &caseValues,
SmallVectorImpl<Block *> &caseDestinations,
SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
SmallVector<int32_t> values;
SmallVector<int32_t> offsets;
int32_t value, offset = 0;
int32_t value = 0;
do {
OptionalParseResult integerParseResult = parser.parseOptionalInteger(value);
if (values.empty() && !integerParseResult.hasValue())
@ -281,32 +268,28 @@ parseSwitchOpCases(OpAsmParser &parser, ElementsAttr &caseValues,
Block *destination;
SmallVector<OpAsmParser::OperandType> operands;
SmallVector<Type> operandTypes;
if (parser.parseColon() || parser.parseSuccessor(destination))
return failure();
if (!parser.parseOptionalLParen()) {
if (parser.parseRegionArgumentList(operands) ||
parser.parseColonTypeList(caseOperandTypes) || parser.parseRParen())
parser.parseColonTypeList(operandTypes) || parser.parseRParen())
return failure();
}
caseDestinations.push_back(destination);
caseOperands.append(operands.begin(), operands.end());
offsets.push_back(offset);
offset += operands.size();
caseOperands.emplace_back(operands);
caseOperandTypes.emplace_back(operandTypes);
} while (!parser.parseOptionalComma());
Builder &builder = parser.getBuilder();
caseValues = builder.getI32VectorAttr(values);
caseOperandOffsets = builder.getI32VectorAttr(offsets);
caseValues = parser.getBuilder().getI32VectorAttr(values);
return success();
}
static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
ElementsAttr caseValues,
SuccessorRange caseDestinations,
OperandRange caseOperands,
TypeRange caseOperandTypes,
ElementsAttr caseOperandOffsets) {
OperandRangeRange caseOperands,
TypeRangeRange caseOperandTypes) {
if (!caseValues)
return;
@ -317,7 +300,7 @@ static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op,
p << " ";
p << std::get<0>(i).getLimitedValue();
p << ": ";
p.printSuccessorAndUseList(std::get<1>(i), op.getCaseOperands(index++));
p.printSuccessorAndUseList(std::get<1>(i), caseOperands[index++]);
},
[&] {
p << ',';
@ -341,28 +324,6 @@ static LogicalResult verify(SwitchOp op) {
return success();
}
OperandRange SwitchOp::getCaseOperands(unsigned index) {
return getCaseOperandsMutable(index);
}
MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
MutableOperandRange caseOperands = caseOperandsMutable();
if (!case_operand_offsets()) {
assert(caseOperands.size() == 0 &&
"non-empty case operands must have offsets");
return caseOperands;
}
ElementsAttr offsets = case_operand_offsets().getValue();
assert(index < offsets.size() && "invalid case operand offset index");
int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
int64_t end = index + 1 == offsets.size()
? caseOperands.size()
: offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
return caseOperandsMutable().slice(begin, end - begin);
}
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");

View File

@ -28,6 +28,7 @@
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/raw_ostream.h"
#include <numeric>
#include "mlir/Dialect/StandardOps/IR/OpsDialect.cpp.inc"
@ -2130,21 +2131,8 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
DenseIntElementsAttr caseValues,
BlockRange caseDestinations,
ArrayRef<ValueRange> caseOperands) {
SmallVector<Value> flattenedCaseOperands;
SmallVector<int32_t> caseOperandOffsets;
int32_t offset = 0;
for (ValueRange operands : caseOperands) {
flattenedCaseOperands.append(operands.begin(), operands.end());
caseOperandOffsets.push_back(offset);
offset += operands.size();
}
DenseIntElementsAttr caseOperandOffsetsAttr;
if (!caseOperandOffsets.empty())
caseOperandOffsetsAttr = builder.getI32VectorAttr(caseOperandOffsets);
build(builder, result, value, defaultOperands, flattenedCaseOperands,
caseValues, caseOperandOffsetsAttr, defaultDestination,
caseDestinations);
build(builder, result, value, defaultOperands, caseOperands, caseValues,
defaultDestination, caseDestinations);
}
void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
@ -2163,16 +2151,14 @@ void SwitchOp::build(OpBuilder &builder, OperationState &result, Value value,
/// <cases> ::= `default` `:` bb-id (`(` ssa-use-and-type-list `)`)?
/// ( `,` integer `:` bb-id (`(` ssa-use-and-type-list `)`)? )*
static ParseResult
parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
Block *&defaultDestination,
SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
SmallVectorImpl<Type> &defaultOperandTypes,
DenseIntElementsAttr &caseValues,
SmallVectorImpl<Block *> &caseDestinations,
SmallVectorImpl<OpAsmParser::OperandType> &caseOperands,
SmallVectorImpl<Type> &caseOperandTypes,
DenseIntElementsAttr &caseOperandOffsets) {
static ParseResult parseSwitchOpCases(
OpAsmParser &parser, Type &flagType, Block *&defaultDestination,
SmallVectorImpl<OpAsmParser::OperandType> &defaultOperands,
SmallVectorImpl<Type> &defaultOperandTypes,
DenseIntElementsAttr &caseValues,
SmallVectorImpl<Block *> &caseDestinations,
SmallVectorImpl<SmallVector<OpAsmParser::OperandType>> &caseOperands,
SmallVectorImpl<SmallVector<Type>> &caseOperandTypes) {
if (failed(parser.parseKeyword("default")) || failed(parser.parseColon()) ||
failed(parser.parseSuccessor(defaultDestination)))
return failure();
@ -2184,9 +2170,7 @@ parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
}
SmallVector<APInt> values;
SmallVector<int32_t> offsets;
unsigned bitWidth = flagType.getIntOrFloatBitWidth();
int64_t offset = 0;
while (succeeded(parser.parseOptionalComma())) {
int64_t value = 0;
if (failed(parser.parseInteger(value)))
@ -2195,30 +2179,26 @@ parseSwitchOpCases(OpAsmParser &parser, Type &flagType,
Block *destination;
SmallVector<OpAsmParser::OperandType> operands;
SmallVector<Type> operandTypes;
if (failed(parser.parseColon()) ||
failed(parser.parseSuccessor(destination)))
return failure();
if (succeeded(parser.parseOptionalLParen())) {
if (failed(parser.parseRegionArgumentList(operands)) ||
failed(parser.parseColonTypeList(caseOperandTypes)) ||
failed(parser.parseColonTypeList(operandTypes)) ||
failed(parser.parseRParen()))
return failure();
}
caseDestinations.push_back(destination);
caseOperands.append(operands.begin(), operands.end());
offsets.push_back(offset);
offset += operands.size();
caseOperands.emplace_back(operands);
caseOperandTypes.emplace_back(operandTypes);
}
if (values.empty())
return success();
Builder &builder = parser.getBuilder();
ShapedType caseValueType =
VectorType::get(static_cast<int64_t>(values.size()), flagType);
caseValues = DenseIntElementsAttr::get(caseValueType, values);
caseOperandOffsets = builder.getI32VectorAttr(offsets);
if (!values.empty()) {
ShapedType caseValueType =
VectorType::get(static_cast<int64_t>(values.size()), flagType);
caseValues = DenseIntElementsAttr::get(caseValueType, values);
}
return success();
}
@ -2226,8 +2206,7 @@ static void printSwitchOpCases(
OpAsmPrinter &p, SwitchOp op, Type flagType, Block *defaultDestination,
OperandRange defaultOperands, TypeRange defaultOperandTypes,
DenseIntElementsAttr caseValues, SuccessorRange caseDestinations,
OperandRange caseOperands, TypeRange caseOperandTypes,
ElementsAttr caseOperandOffsets) {
OperandRangeRange caseOperands, TypeRangeRange caseOperandTypes) {
p << " default: ";
p.printSuccessorAndUseList(defaultDestination, defaultOperands);
@ -2240,7 +2219,7 @@ static void printSwitchOpCases(
p << " ";
p << caseValues.getValue<APInt>(i).getLimitedValue();
p << ": ";
p.printSuccessorAndUseList(caseDestinations[i], op.getCaseOperands(i));
p.printSuccessorAndUseList(caseDestinations[i], caseOperands[i]);
}
p.printNewline();
}
@ -2268,28 +2247,6 @@ static LogicalResult verify(SwitchOp op) {
return success();
}
OperandRange SwitchOp::getCaseOperands(unsigned index) {
return getCaseOperandsMutable(index);
}
MutableOperandRange SwitchOp::getCaseOperandsMutable(unsigned index) {
MutableOperandRange caseOperands = caseOperandsMutable();
if (!case_operand_offsets()) {
assert(caseOperands.size() == 0 &&
"non-empty case operands must have offsets");
return caseOperands;
}
ElementsAttr offsets = case_operand_offsets().getValue();
assert(index < offsets.size() && "invalid case operand offset index");
int64_t begin = offsets.getValue(index).cast<IntegerAttr>().getInt();
int64_t end = index + 1 == offsets.size()
? caseOperands.size()
: offsets.getValue(index + 1).cast<IntegerAttr>().getInt();
return caseOperandsMutable().slice(begin, end - begin);
}
Optional<MutableOperandRange>
SwitchOp::getMutableSuccessorOperands(unsigned index) {
assert(index < getNumSuccessors() && "invalid successor index");

View File

@ -996,16 +996,19 @@ OpTrait::impl::verifyResultsAreSignlessIntegerLike(Operation *op) {
return success();
}
static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
bool isOperand) {
LogicalResult OpTrait::impl::verifyValueSizeAttr(Operation *op,
StringRef attrName,
StringRef valueGroupName,
size_t expectedCount) {
auto sizeAttr = op->getAttrOfType<DenseIntElementsAttr>(attrName);
if (!sizeAttr)
return op->emitOpError("requires 1D vector attribute '") << attrName << "'";
return op->emitOpError("requires 1D i32 elements attribute '")
<< attrName << "'";
auto sizeAttrType = sizeAttr.getType().dyn_cast<VectorType>();
if (!sizeAttrType || sizeAttrType.getRank() != 1 ||
auto sizeAttrType = sizeAttr.getType();
if (sizeAttrType.getRank() != 1 ||
!sizeAttrType.getElementType().isInteger(32))
return op->emitOpError("requires 1D vector of i32 attribute '")
return op->emitOpError("requires 1D i32 elements attribute '")
<< attrName << "'";
if (llvm::any_of(sizeAttr.getIntValues(), [](const APInt &element) {
@ -1018,25 +1021,22 @@ static LogicalResult verifyValueSizeAttr(Operation *op, StringRef attrName,
sizeAttr.begin(), sizeAttr.end(), 0,
[](unsigned all, APInt one) { return all + one.getZExtValue(); });
if (isOperand && totalCount != op->getNumOperands())
return op->emitOpError("operand count (")
<< op->getNumOperands() << ") does not match with the total size ("
<< totalCount << ") specified in attribute '" << attrName << "'";
else if (!isOperand && totalCount != op->getNumResults())
return op->emitOpError("result count (")
<< op->getNumResults() << ") does not match with the total size ("
<< totalCount << ") specified in attribute '" << attrName << "'";
if (totalCount != expectedCount)
return op->emitOpError()
<< valueGroupName << " count (" << expectedCount
<< ") does not match with the total size (" << totalCount
<< ") specified in attribute '" << attrName << "'";
return success();
}
LogicalResult OpTrait::impl::verifyOperandSizeAttr(Operation *op,
StringRef attrName) {
return verifyValueSizeAttr(op, attrName, /*isOperand=*/true);
return verifyValueSizeAttr(op, attrName, "operand", op->getNumOperands());
}
LogicalResult OpTrait::impl::verifyResultSizeAttr(Operation *op,
StringRef attrName) {
return verifyValueSizeAttr(op, attrName, /*isOperand=*/false);
return verifyValueSizeAttr(op, attrName, "result", op->getNumResults());
}
LogicalResult OpTrait::impl::verifyNoRegionArguments(Operation *op) {

View File

@ -12,9 +12,11 @@
//===----------------------------------------------------------------------===//
#include "mlir/IR/OperationSupport.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/OpDefinition.h"
#include "llvm/ADT/BitVector.h"
#include <numeric>
using namespace mlir;
@ -394,13 +396,38 @@ MutableArrayRef<OpOperand> detail::OperandStorage::resize(Operation *owner,
OperandRange::OperandRange(Operation *op)
: OperandRange(op->getOpOperands().data(), op->getNumOperands()) {}
/// Return the operand index of the first element of this range. The range
/// must not be empty.
unsigned OperandRange::getBeginOperandIndex() const {
assert(!empty() && "range must not be empty");
return base->getOperandNumber();
}
OperandRangeRange OperandRange::split(ElementsAttr segmentSizes) const {
return OperandRangeRange(*this, segmentSizes);
}
//===----------------------------------------------------------------------===//
// OperandRangeRange
OperandRangeRange::OperandRangeRange(OperandRange operands,
Attribute operandSegments)
: OperandRangeRange(OwnerT(operands.getBase(), operandSegments), 0,
operandSegments.cast<DenseElementsAttr>().size()) {}
OperandRange OperandRangeRange::join() const {
const OwnerT &owner = getBase();
auto sizeData = owner.second.cast<DenseElementsAttr>().getValues<uint32_t>();
return OperandRange(owner.first,
std::accumulate(sizeData.begin(), sizeData.end(), 0));
}
OperandRange OperandRangeRange::dereference(const OwnerT &object,
ptrdiff_t index) {
auto sizeData = object.second.cast<DenseElementsAttr>().getValues<uint32_t>();
uint32_t startIndex =
std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
return OperandRange(object.first + startIndex, *(sizeData.begin() + index));
}
//===----------------------------------------------------------------------===//
// MutableOperandRange
@ -419,7 +446,7 @@ MutableOperandRange::MutableOperandRange(Operation *owner)
/// Slice this range into a sub range, with the additional operand segment.
MutableOperandRange
MutableOperandRange::slice(unsigned subStart, unsigned subLen,
Optional<OperandSegment> segment) {
Optional<OperandSegment> segment) const {
assert((subStart + subLen) <= length && "invalid sub-range");
MutableOperandRange subSlice(owner, start + subStart, subLen,
operandSegments);
@ -475,6 +502,11 @@ MutableOperandRange::operator OperandRange() const {
return owner->getOperands().slice(start, length);
}
MutableOperandRangeRange
MutableOperandRange::split(NamedAttribute segmentSizes) const {
return MutableOperandRangeRange(*this, segmentSizes);
}
/// Update the length of this range to the one provided.
void MutableOperandRange::updateLength(unsigned newLength) {
int32_t diff = int32_t(newLength) - int32_t(length);
@ -490,6 +522,35 @@ void MutableOperandRange::updateLength(unsigned newLength) {
}
}
//===----------------------------------------------------------------------===//
// MutableOperandRangeRange
MutableOperandRangeRange::MutableOperandRangeRange(
const MutableOperandRange &operands, NamedAttribute operandSegmentAttr)
: MutableOperandRangeRange(
OwnerT(operands, operandSegmentAttr), 0,
operandSegmentAttr.second.cast<DenseElementsAttr>().size()) {}
MutableOperandRange MutableOperandRangeRange::join() const {
return getBase().first;
}
MutableOperandRangeRange::operator OperandRangeRange() const {
return OperandRangeRange(getBase().first,
getBase().second.second.cast<DenseElementsAttr>());
}
MutableOperandRange MutableOperandRangeRange::dereference(const OwnerT &object,
ptrdiff_t index) {
auto sizeData =
object.second.second.cast<DenseElementsAttr>().getValues<uint32_t>();
uint32_t startIndex =
std::accumulate(sizeData.begin(), sizeData.begin() + index, 0);
return object.first.slice(
startIndex, *(sizeData.begin() + index),
MutableOperandRange::OperandSegment(index, object.second));
}
//===----------------------------------------------------------------------===//
// ValueRange

View File

@ -12,6 +12,10 @@
using namespace mlir;
using namespace mlir::tblgen;
//===----------------------------------------------------------------------===//
// NamedTypeConstraint
//===----------------------------------------------------------------------===//
bool NamedTypeConstraint::hasPredicate() const {
return !constraint.getPredicate().isNull();
}
@ -19,3 +23,7 @@ bool NamedTypeConstraint::hasPredicate() const {
bool NamedTypeConstraint::isOptional() const { return constraint.isOptional(); }
bool NamedTypeConstraint::isVariadic() const { return constraint.isVariadic(); }
bool NamedTypeConstraint::isVariadicOfVariadic() const {
return constraint.isVariadicOfVariadic();
}

View File

@ -458,6 +458,13 @@ void Operator::populateOpStructure() {
results.push_back({name, TypeConstraint(resultDef)});
if (!name.empty())
argumentsAndResultsIndex[name] = resultIndex(i);
// We currently only support VariadicOfVariadic operands.
if (results.back().constraint.isVariadicOfVariadic()) {
PrintFatalError(
def.getLoc(),
"'VariadicOfVariadic' results are currently not supported");
}
}
// Handle successors
@ -577,8 +584,7 @@ bool Operator::hasAssemblyFormat() const {
StringRef Operator::getAssemblyFormat() const {
return TypeSwitch<llvm::Init *, StringRef>(def.getValueInit("assemblyFormat"))
.Case<llvm::StringInit>(
[&](auto *init) { return init->getValue(); });
.Case<llvm::StringInit>([&](auto *init) { return init->getValue(); });
}
void Operator::print(llvm::raw_ostream &os) const {

View File

@ -36,6 +36,15 @@ bool TypeConstraint::isVariadic() const {
return def->isSubClassOf("Variadic");
}
bool TypeConstraint::isVariadicOfVariadic() const {
return def->isSubClassOf("VariadicOfVariadic");
}
StringRef TypeConstraint::getVariadicOfVariadicSegmentSizeAttr() const {
assert(isVariadicOfVariadic());
return def->getValueAsString("segmentAttrName");
}
// Returns the builder call for this constraint if this is a buildable type,
// returns None otherwise.
Optional<StringRef> TypeConstraint::getBuilderCall() const {

View File

@ -375,28 +375,28 @@ func private @foo()
// -----
func @failedMissingOperandSizeAttr(%arg: i32) {
// expected-error @+1 {{requires 1D vector attribute 'operand_segment_sizes'}}
// expected-error @+1 {{requires 1D i32 elements attribute 'operand_segment_sizes'}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) : (i32, i32, i32, i32) -> ()
}
// -----
func @failedOperandSizeAttrWrongType(%arg: i32) {
// expected-error @+1 {{requires 1D vector of i32 attribute 'operand_segment_sizes'}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[1, 1, 1, 1]>: tensor<4xi32>} : (i32, i32, i32, i32) -> ()
// expected-error @+1 {{requires 1D i32 elements attribute 'operand_segment_sizes'}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = 10} : (i32, i32, i32, i32) -> ()
}
// -----
func @failedOperandSizeAttrWrongRank(%arg: i32) {
// expected-error @+1 {{requires 1D vector of i32 attribute 'operand_segment_sizes'}}
// expected-error @+1 {{requires 1D i32 elements attribute 'operand_segment_sizes'}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[[1, 1], [1, 1]]>: vector<2x2xi32>} : (i32, i32, i32, i32) -> ()
}
// -----
func @failedOperandSizeAttrWrongElementType(%arg: i32) {
// expected-error @+1 {{requires 1D vector of i32 attribute 'operand_segment_sizes'}}
// expected-error @+1 {{requires 1D i32 elements attribute 'operand_segment_sizes'}}
"test.attr_sized_operands"(%arg, %arg, %arg, %arg) {operand_segment_sizes = dense<[1, 1, 1, 1]>: vector<4xi64>} : (i32, i32, i32, i32) -> ()
}
@ -432,28 +432,28 @@ func @succeededOperandSizeAttr(%arg: i32) {
// -----
func @failedMissingResultSizeAttr() {
// expected-error @+1 {{requires 1D vector attribute 'result_segment_sizes'}}
// expected-error @+1 {{requires 1D i32 elements attribute 'result_segment_sizes'}}
%0:4 = "test.attr_sized_results"() : () -> (i32, i32, i32, i32)
}
// -----
func @failedResultSizeAttrWrongType() {
// expected-error @+1 {{requires 1D vector of i32 attribute 'result_segment_sizes'}}
%0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[1, 1, 1, 1]>: tensor<4xi32>} : () -> (i32, i32, i32, i32)
// expected-error @+1 {{requires 1D i32 elements attribute 'result_segment_sizes'}}
%0:4 = "test.attr_sized_results"() {result_segment_sizes = 10} : () -> (i32, i32, i32, i32)
}
// -----
func @failedResultSizeAttrWrongRank() {
// expected-error @+1 {{requires 1D vector of i32 attribute 'result_segment_sizes'}}
// expected-error @+1 {{requires 1D i32 elements attribute 'result_segment_sizes'}}
%0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[[1, 1], [1, 1]]>: vector<2x2xi32>} : () -> (i32, i32, i32, i32)
}
// -----
func @failedResultSizeAttrWrongElementType() {
// expected-error @+1 {{requires 1D vector of i32 attribute 'result_segment_sizes'}}
// expected-error @+1 {{requires 1D i32 elements attribute 'result_segment_sizes'}}
%0:4 = "test.attr_sized_results"() {result_segment_sizes = dense<[1, 1, 1, 1]>: vector<4xi64>} : () -> (i32, i32, i32, i32)
}

View File

@ -1661,6 +1661,14 @@ def FormatVariadicOperand : TEST_Op<"format_variadic_operand"> {
let arguments = (ins Variadic<I64>:$operand);
let assemblyFormat = [{ $operand `:` type($operand) attr-dict}];
}
def FormatVariadicOfVariadicOperand
: TEST_Op<"format_variadic_of_variadic_operand"> {
let arguments = (ins
VariadicOfVariadic<I64, "operand_segments">:$operand,
I32ElementsAttr:$operand_segments
);
let assemblyFormat = [{ $operand `:` type($operand) attr-dict}];
}
def FormatMultipleVariadicOperands :
TEST_Op<"format_multiple_variadic_operands", [AttrSizedOperandSegments]> {

View File

@ -151,6 +151,9 @@ test.format_operand_e_op %i64, %memref : i64, memref<1xf64>
// CHECK: test.format_variadic_operand %[[I64]], %[[I64]], %[[I64]] : i64, i64, i64
test.format_variadic_operand %i64, %i64, %i64 : i64, i64, i64
// CHECK: test.format_variadic_of_variadic_operand (%[[I64]], %[[I64]]), (), (%[[I64]]) : (i64, i64), (), (i64)
test.format_variadic_of_variadic_operand (%i64, %i64), (), (%i64) : (i64, i64), (), (i64)
// CHECK: test.format_multiple_variadic_operands (%[[I64]], %[[I64]], %[[I64]]), (%[[I64]], %[[I32]] : i64, i32)
test.format_multiple_variadic_operands (%i64, %i64, %i64), (%i64, %i32 : i64, i32)

View File

@ -24,6 +24,7 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringSet.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@ -89,6 +90,23 @@ const char *attrSizedSegmentValueRangeCalcCode = R"(
unsigned size = *(sizeAttrValues.begin() + index);
return {start, size};
)";
// The logic to calculate the actual value range for a declared operand
// of an op with variadic of variadic operands within the OpAdaptor.
//
// {0}: The name of the segment attribute.
// {1}: The index of the main operand.
const char *variadicOfVariadicAdaptorCalcCode = R"(
auto tblgenTmpOperands = getODSOperands({1});
auto sizeAttrValues = {0}().getValues<uint32_t>();
auto sizeAttrIt = sizeAttrValues.begin();
::llvm::SmallVector<::mlir::ValueRange> tblgenTmpOperandGroups;
for (int i = 0, e = ::llvm::size(sizeAttrValues); i < e; ++i, ++sizeAttrIt) {{
tblgenTmpOperandGroups.push_back(tblgenTmpOperands.take_front(*sizeAttrIt));
tblgenTmpOperands = tblgenTmpOperands.drop_front(*sizeAttrIt);
}
return tblgenTmpOperandGroups;
)";
// The logic to build a range of either operand or result values.
//
@ -256,16 +274,20 @@ private:
// Builds the parameter list for build() method of this op. This method writes
// to `paramList` the comma-separated parameter list and updates
// `resultTypeNames` with the names for parameters for specifying result
// types. The given `typeParamKind` and `attrParamKind` controls how result
// types and attributes are placed in the parameter list.
// types. `inferredAttributes` is populated with any attributes that are
// elided from the build list. The given `typeParamKind` and `attrParamKind`
// controls how result types and attributes are placed in the parameter list.
void buildParamList(llvm::SmallVectorImpl<OpMethodParameter> &paramList,
llvm::StringSet<> &inferredAttributes,
SmallVectorImpl<std::string> &resultTypeNames,
TypeParamKind typeParamKind,
AttrParamKind attrParamKind = AttrParamKind::WrappedAttr);
// Adds op arguments and regions into operation state for build() methods.
void genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
bool isRawValueAttr = false);
void
genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
llvm::StringSet<> &inferredAttributes,
bool isRawValueAttr = false);
// Generates canonicalizer declaration for the operation.
void genCanonicalizerDecls();
@ -783,7 +805,7 @@ generateValueRangeStartAndEnd(Class &opClass, StringRef methodName,
// of ops, in particular for one-operand ops that may not have the
// `getOperand(unsigned)` method.
static void generateNamedOperandGetters(const Operator &op, Class &opClass,
StringRef sizeAttrInit,
bool isAdaptor, StringRef sizeAttrInit,
StringRef rangeType,
StringRef rangeBeginCall,
StringRef rangeSizeCall,
@ -838,6 +860,20 @@ static void generateNamedOperandGetters(const Operator &op, Class &opClass,
m->body()
<< " auto operands = getODSOperands(" << i << ");\n"
<< " return operands.empty() ? ::mlir::Value() : *operands.begin();";
} else if (operand.isVariadicOfVariadic()) {
StringRef segmentAttr =
operand.constraint.getVariadicOfVariadicSegmentSizeAttr();
if (isAdaptor) {
m = opClass.addMethodAndPrune("::llvm::SmallVector<::mlir::ValueRange>",
operand.name);
m->body() << llvm::formatv(variadicOfVariadicAdaptorCalcCode,
segmentAttr, i);
continue;
}
m = opClass.addMethodAndPrune("::mlir::OperandRangeRange", operand.name);
m->body() << " return getODSOperands(" << i << ").split(" << segmentAttr
<< "Attr());";
} else if (operand.isVariadic()) {
m = opClass.addMethodAndPrune(rangeType, operand.name);
m->body() << " return getODSOperands(" << i << ");";
@ -860,6 +896,7 @@ void OpEmitter::genNamedOperandGetters() {
generateNamedOperandGetters(
op, opClass,
/*isAdaptor=*/false,
/*sizeAttrInit=*/attrSizeInitCode,
/*rangeType=*/"::mlir::Operation::operand_range",
/*rangeBeginCall=*/"getOperation()->operand_begin()",
@ -874,17 +911,32 @@ void OpEmitter::genNamedOperandSetters() {
const auto &operand = op.getOperand(i);
if (operand.name.empty())
continue;
auto *m = opClass.addMethodAndPrune("::mlir::MutableOperandRange",
auto *m = opClass.addMethodAndPrune(operand.isVariadicOfVariadic()
? "::mlir::MutableOperandRangeRange"
: "::mlir::MutableOperandRange",
(operand.name + "Mutable").str());
auto &body = m->body();
body << " auto range = getODSOperandIndexAndLength(" << i << ");\n"
<< " return ::mlir::MutableOperandRange(getOperation(), "
<< " auto mutableRange = ::mlir::MutableOperandRange(getOperation(), "
"range.first, range.second";
if (attrSizedOperands)
body << ", ::mlir::MutableOperandRange::OperandSegment(" << i
<< "u, *getOperation()->getAttrDictionary().getNamed("
"operand_segment_sizesAttrName()))";
body << ");\n";
// If this operand is a nested variadic, we split the range into a
// MutableOperandRangeRange that provides a range over all of the
// sub-ranges.
if (operand.isVariadicOfVariadic()) {
body << " return "
"mutableRange.split(*(*this)->getAttrDictionary().getNamed("
<< operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
<< "AttrName()));\n";
} else {
// Otherwise, we use the full range directly.
body << " return mutableRange;\n";
}
}
}
@ -1038,7 +1090,9 @@ void OpEmitter::genSeparateArgParamBuilder() {
bool inferType) {
llvm::SmallVector<OpMethodParameter, 4> paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, paramKind, attrType);
llvm::StringSet<> inferredAttributes;
buildParamList(paramList, inferredAttributes, resultNames, paramKind,
attrType);
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
@ -1046,8 +1100,9 @@ void OpEmitter::genSeparateArgParamBuilder() {
if (!m)
return;
auto &body = m->body();
genCodeForAddingArgAndRegionForBuilder(
body, /*isRawValueAttr=*/attrType == AttrParamKind::UnwrappedValue);
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes,
/*isRawValueAttr=*/attrType ==
AttrParamKind::UnwrappedValue);
// Push all result types to the operation state
@ -1215,7 +1270,9 @@ void OpEmitter::genInferredTypeCollectiveParamBuilder() {
void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
llvm::SmallVector<OpMethodParameter, 4> paramList;
llvm::SmallVector<std::string, 4> resultNames;
buildParamList(paramList, resultNames, TypeParamKind::None);
llvm::StringSet<> inferredAttributes;
buildParamList(paramList, inferredAttributes, resultNames,
TypeParamKind::None);
auto *m = opClass.addMethodAndPrune("void", "build", OpMethod::MP_Static,
std::move(paramList));
@ -1223,7 +1280,7 @@ void OpEmitter::genUseOperandAsResultTypeSeparateParamBuilder() {
if (!m)
return;
auto &body = m->body();
genCodeForAddingArgAndRegionForBuilder(body);
genCodeForAddingArgAndRegionForBuilder(body, inferredAttributes);
auto numResults = op.getNumResults();
if (numResults == 0)
@ -1415,6 +1472,7 @@ void OpEmitter::genCollectiveParamBuilder() {
}
void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
llvm::StringSet<> &inferredAttributes,
SmallVectorImpl<std::string> &resultTypeNames,
TypeParamKind typeParamKind,
AttrParamKind attrParamKind) {
@ -1453,10 +1511,6 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
}
// Add parameters for all arguments (operands and attributes).
int numOperands = 0;
int numAttrs = 0;
int defaultValuedAttrStartIndex = op.getNumArgs();
if (attrParamKind == AttrParamKind::UnwrappedValue) {
// Calculate the start index from which we can attach default values in the
@ -1482,56 +1536,70 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
}
}
for (int i = 0, e = op.getNumArgs(); i < e; ++i) {
auto argument = op.getArg(i);
if (argument.is<tblgen::NamedTypeConstraint *>()) {
const auto &operand = op.getOperand(numOperands);
StringRef type =
operand.isVariadic() ? "::mlir::ValueRange" : "::mlir::Value";
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (operand.isOptional())
properties = OpMethodParameter::PP_Optional;
paramList.emplace_back(type, getArgumentName(op, numOperands),
properties);
++numOperands;
} else {
const auto &namedAttr = op.getAttribute(numAttrs);
const auto &attr = namedAttr.attr;
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (attr.isOptional())
properties = OpMethodParameter::PP_Optional;
StringRef type;
switch (attrParamKind) {
case AttrParamKind::WrappedAttr:
type = attr.getStorageType();
break;
case AttrParamKind::UnwrappedValue:
if (canUseUnwrappedRawValue(attr))
type = attr.getReturnType();
else
type = attr.getStorageType();
break;
}
std::string defaultValue;
// Attach default value if requested and possible.
if (attrParamKind == AttrParamKind::UnwrappedValue &&
i >= defaultValuedAttrStartIndex) {
bool isString = attr.getReturnType() == "::llvm::StringRef";
if (isString)
defaultValue.append("\"");
defaultValue += attr.getDefaultValue();
if (isString)
defaultValue.append("\"");
}
paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
++numAttrs;
/// Collect any inferred attributes.
for (const NamedTypeConstraint &operand : op.getOperands()) {
if (operand.isVariadicOfVariadic()) {
inferredAttributes.insert(
operand.constraint.getVariadicOfVariadicSegmentSizeAttr());
}
}
for (int i = 0, e = op.getNumArgs(), numOperands = 0; i < e; ++i) {
Argument arg = op.getArg(i);
if (const auto *operand = arg.dyn_cast<NamedTypeConstraint *>()) {
StringRef type;
if (operand->isVariadicOfVariadic())
type = "::llvm::ArrayRef<::mlir::ValueRange>";
else if (operand->isVariadic())
type = "::mlir::ValueRange";
else
type = "::mlir::Value";
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (operand->isOptional())
properties = OpMethodParameter::PP_Optional;
paramList.emplace_back(type, getArgumentName(op, numOperands++),
properties);
continue;
}
const NamedAttribute &namedAttr = *arg.get<NamedAttribute *>();
const Attribute &attr = namedAttr.attr;
// inferred attributes don't need to be added to the param list.
if (inferredAttributes.contains(namedAttr.name))
continue;
OpMethodParameter::Property properties = OpMethodParameter::PP_None;
if (attr.isOptional())
properties = OpMethodParameter::PP_Optional;
StringRef type;
switch (attrParamKind) {
case AttrParamKind::WrappedAttr:
type = attr.getStorageType();
break;
case AttrParamKind::UnwrappedValue:
if (canUseUnwrappedRawValue(attr))
type = attr.getReturnType();
else
type = attr.getStorageType();
break;
}
// Attach default value if requested and possible.
std::string defaultValue;
if (attrParamKind == AttrParamKind::UnwrappedValue &&
i >= defaultValuedAttrStartIndex) {
bool isString = attr.getReturnType() == "::llvm::StringRef";
if (isString)
defaultValue.append("\"");
defaultValue += attr.getDefaultValue();
if (isString)
defaultValue.append("\"");
}
paramList.emplace_back(type, namedAttr.name, defaultValue, properties);
}
/// Insert parameters for each successor.
for (const NamedSuccessor &succ : op.getSuccessors()) {
StringRef type =
@ -1546,12 +1614,31 @@ void OpEmitter::buildParamList(SmallVectorImpl<OpMethodParameter> &paramList,
llvm::formatv("{0}Count", region.name).str());
}
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
bool isRawValueAttr) {
void OpEmitter::genCodeForAddingArgAndRegionForBuilder(
OpMethodBody &body, llvm::StringSet<> &inferredAttributes,
bool isRawValueAttr) {
// Push all operands to the result.
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
std::string argName = getArgumentName(op, i);
if (op.getOperand(i).isOptional())
NamedTypeConstraint &operand = op.getOperand(i);
if (operand.constraint.isVariadicOfVariadic()) {
body << " for (::mlir::ValueRange range : " << argName << ")\n "
<< builderOpState << ".addOperands(range);\n";
// Add the segment attribute.
body << " {\n"
<< " SmallVector<int32_t> rangeSegments;\n"
<< " for (::mlir::ValueRange range : " << argName << ")\n"
<< " rangeSegments.push_back(range.size());\n"
<< " " << builderOpState << ".addAttribute("
<< operand.constraint.getVariadicOfVariadicSegmentSizeAttr()
<< "AttrName(" << builderOpState << ".name), " << odsBuilder
<< ".getI32TensorAttr(rangeSegments));"
<< " }\n";
continue;
}
if (operand.isOptional())
body << " if (" << argName << ")\n ";
body << " " << builderOpState << ".addOperands(" << argName << ");\n";
}
@ -1563,12 +1650,24 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
<< ".name), "
<< "odsBuilder.getI32VectorAttr({";
interleaveComma(llvm::seq<int>(0, op.getNumOperands()), body, [&](int i) {
if (op.getOperand(i).isOptional())
body << "(" << getArgumentName(op, i) << " ? 1 : 0)";
else if (op.getOperand(i).isVariadic())
body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
else
const NamedTypeConstraint &operand = op.getOperand(i);
if (!operand.isVariableLength()) {
body << "1";
return;
}
std::string operandName = getArgumentName(op, i);
if (operand.isOptional()) {
body << "(" << operandName << " ? 1 : 0)";
} else if (operand.isVariadicOfVariadic()) {
body << llvm::formatv(
"static_cast<int32_t>(std::accumulate({0}.begin(), {0}.end(), 0, "
"[](int32_t curSum, ::mlir::ValueRange range) {{ return curSum + "
"range.size(); }))",
operandName);
} else {
body << "static_cast<int32_t>(" << getArgumentName(op, i) << ".size())";
}
});
body << "}));\n";
}
@ -1576,38 +1675,38 @@ void OpEmitter::genCodeForAddingArgAndRegionForBuilder(OpMethodBody &body,
// Push all attributes to the result.
for (const auto &namedAttr : op.getAttributes()) {
auto &attr = namedAttr.attr;
if (!attr.isDerivedAttr()) {
bool emitNotNullCheck = attr.isOptional();
if (emitNotNullCheck)
body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
if (attr.isDerivedAttr() || inferredAttributes.contains(namedAttr.name))
continue;
if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
// If this is a raw value, then we need to wrap it in an Attribute
// instance.
FmtContext fctx;
fctx.withBuilder("odsBuilder");
bool emitNotNullCheck = attr.isOptional();
if (emitNotNullCheck)
body << formatv(" if ({0}) ", namedAttr.name) << "{\n";
std::string builderTemplate =
std::string(attr.getConstBuilderTemplate());
if (isRawValueAttr && canUseUnwrappedRawValue(attr)) {
// If this is a raw value, then we need to wrap it in an Attribute
// instance.
FmtContext fctx;
fctx.withBuilder("odsBuilder");
// For StringAttr, its constant builder call will wrap the input in
// quotes, which is correct for normal string literals, but incorrect
// here given we use function arguments. So we need to strip the
// wrapping quotes.
if (StringRef(builderTemplate).contains("\"$0\""))
builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
std::string builderTemplate = std::string(attr.getConstBuilderTemplate());
std::string value =
std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
builderOpState, namedAttr.name, value);
} else {
body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {1});\n",
builderOpState, namedAttr.name);
}
if (emitNotNullCheck)
body << " }\n";
// For StringAttr, its constant builder call will wrap the input in
// quotes, which is correct for normal string literals, but incorrect
// here given we use function arguments. So we need to strip the
// wrapping quotes.
if (StringRef(builderTemplate).contains("\"$0\""))
builderTemplate = replaceAllSubstrs(builderTemplate, "\"$0\"", "$0");
std::string value =
std::string(tgfmt(builderTemplate, &fctx, namedAttr.name));
body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {2});\n",
builderOpState, namedAttr.name, value);
} else {
body << formatv(" {0}.addAttribute({1}AttrName({0}.name), {1});\n",
builderOpState, namedAttr.name);
}
if (emitNotNullCheck)
body << " }\n";
}
// Create the correct number of regions.
@ -1960,9 +2059,12 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
body << " unsigned index = 0; (void)index;\n";
for (auto staticValue : llvm::enumerate(values)) {
bool hasPredicate = staticValue.value().hasPredicate();
bool isOptional = staticValue.value().isOptional();
if (!hasPredicate && !isOptional)
const NamedTypeConstraint &value = staticValue.value();
bool hasPredicate = value.hasPredicate();
bool isOptional = value.isOptional();
bool isVariadicOfVariadic = value.isVariadicOfVariadic();
if (!hasPredicate && !isOptional && !isVariadicOfVariadic)
continue;
body << formatv(" auto valueGroup{2} = getODS{0}{1}s({2});\n",
// Capitalize the first letter to match the function name
@ -1977,14 +2079,21 @@ void OpEmitter::genOperandResultVerifier(OpMethodBody &body,
"<< index << \" requires 0 or 1 element, but found \" << "
"valueGroup{0}.size();\n",
staticValue.index(), valueKind);
} else if (isVariadicOfVariadic) {
body << formatv(
" if (::mlir::failed(::mlir::OpTrait::impl::verifyValueSizeAttr("
"*this, \"{0}\", \"{1}\", valueGroup{2}.size())))\n"
" return ::mlir::failure();\n",
value.constraint.getVariadicOfVariadicSegmentSizeAttr(), value.name,
staticValue.index());
}
// Otherwise, if there is no predicate there is nothing left to do.
if (!hasPredicate)
continue;
// Emit a loop to check all the dynamic values in the pack.
StringRef constraintFn = staticVerifierEmitter.getTypeConstraintFn(
staticValue.value().constraint);
StringRef constraintFn =
staticVerifierEmitter.getTypeConstraintFn(value.constraint);
body << " for (::mlir::Value v : valueGroup" << staticValue.index()
<< ") {\n"
<< " if (::mlir::failed(" << constraintFn
@ -2257,7 +2366,8 @@ OpOperandAdaptorEmitter::OpOperandAdaptorEmitter(const Operator &op)
}
std::string sizeAttrInit =
formatv(adapterSegmentSizeAttrInitCode, "operand_segment_sizes");
generateNamedOperandGetters(op, adaptor, sizeAttrInit,
generateNamedOperandGetters(op, adaptor,
/*isAdaptor=*/true, sizeAttrInit,
/*rangeType=*/"::mlir::ValueRange",
/*rangeBeginCall=*/"odsOperands.begin()",
/*rangeSizeCall=*/"odsOperands.size()",

View File

@ -497,6 +497,7 @@ struct OperationFormat {
/// The set of attributes explicitly used within the format.
SmallVector<const NamedAttribute *, 8> usedAttributes;
llvm::StringSet<> inferredAttributes;
};
} // end anonymous namespace
@ -616,10 +617,38 @@ const char *const operandParserCode = R"(
if (parser.parseOperand({0}RawOperands[0]))
return ::mlir::failure();
)";
/// The code snippet used to generate a parser call for a VariadicOfVariadic
/// operand.
///
/// {0}: The name of the operand.
/// {1}: The name of segment size attribute.
const char *const variadicOfVariadicOperandParserCode = R"(
{
{0}OperandsLoc = parser.getCurrentLocation();
int32_t curSize = 0;
do {
if (parser.parseOptionalLParen())
break;
if (parser.parseOperandList({0}Operands) || parser.parseRParen())
return ::mlir::failure();
{0}OperandGroupSizes.push_back({0}Operands.size() - curSize);
curSize = {0}Operands.size();
} while (succeeded(parser.parseOptionalComma()));
}
)";
/// The code snippet used to generate a parser call for a type list.
///
/// {0}: The name for the type list.
const char *const variadicOfVariadicTypeParserCode = R"(
do {
if (parser.parseOptionalLParen())
break;
if (parser.parseOptionalRParen() &&
(parser.parseTypeList({0}Types) || parser.parseRParen()))
return ::mlir::failure();
} while (succeeded(parser.parseOptionalComma()));
)";
const char *const variadicTypeParserCode = R"(
if (parser.parseTypeList({0}Types))
return ::mlir::failure();
@ -758,6 +787,9 @@ const char *successorParserCode = R"(
namespace {
/// The type of length for a given parse argument.
enum class ArgumentLengthKind {
/// The argument is a variadic of a variadic, and may contain 0->N range
/// elements.
VariadicOfVariadic,
/// The argument is variadic, and may contain 0->N elements.
Variadic,
/// The argument is optional, and may contain 0 or 1 elements.
@ -772,6 +804,8 @@ static ArgumentLengthKind
getArgumentLengthKind(const NamedTypeConstraint *var) {
if (var->isOptional())
return ArgumentLengthKind::Optional;
if (var->isVariadicOfVariadic())
return ArgumentLengthKind::VariadicOfVariadic;
if (var->isVariadic())
return ArgumentLengthKind::Variadic;
return ArgumentLengthKind::Single;
@ -863,6 +897,10 @@ static void genElementParserStorage(Element *element, OpMethodBody &body) {
if (operand->getVar()->isVariableLength()) {
body << " ::mlir::SmallVector<::mlir::OpAsmParser::OperandType, 4> "
<< name << "Operands;\n";
if (operand->getVar()->isVariadicOfVariadic()) {
body << " llvm::SmallVector<int32_t> " << name
<< "OperandGroupSizes;\n";
}
} else {
body << " ::mlir::OpAsmParser::OperandType " << name
<< "RawOperands[1];\n"
@ -924,7 +962,9 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
} else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
StringRef name = operand->getVar()->name;
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
if (lengthKind == ArgumentLengthKind::Variadic)
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
body << llvm::formatv("{0}OperandGroups", name);
else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv("{0}Operands", name);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv("{0}Operand", name);
@ -951,7 +991,9 @@ static void genCustomParameterParser(Element &param, OpMethodBody &body) {
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
if (lengthKind == ArgumentLengthKind::Variadic)
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
body << llvm::formatv("{0}TypeGroups", listName);
else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv("{0}Types", listName);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv("{0}Type", listName);
@ -972,19 +1014,32 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
// * Set the location of operand variables.
for (Element &param : dir->getArguments()) {
if (auto *operand = dyn_cast<OperandVariable>(&param)) {
body << " " << operand->getVar()->name
auto *var = operand->getVar();
body << " " << var->name
<< "OperandsLoc = parser.getCurrentLocation();\n";
if (operand->getVar()->isOptional()) {
if (var->isOptional()) {
body << llvm::formatv(
" llvm::Optional<::mlir::OpAsmParser::OperandType> "
"{0}Operand;\n",
operand->getVar()->name);
var->name);
} else if (var->isVariadicOfVariadic()) {
body << llvm::formatv(" "
"llvm::SmallVector<llvm::SmallVector<::mlir::"
"OpAsmParser::OperandType>> "
"{0}OperandGroups;\n",
var->name);
}
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
if (lengthKind == ArgumentLengthKind::Optional)
if (lengthKind == ArgumentLengthKind::Optional) {
body << llvm::formatv(" ::mlir::Type {0}Type;\n", listName);
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
body << llvm::formatv(
" llvm::SmallVector<llvm::SmallVector<::mlir::Type>> "
"{0}TypeGroups;\n",
listName);
}
} else if (auto *dir = dyn_cast<RefDirective>(&param)) {
Element *input = dir->getOperand();
if (auto *operand = dyn_cast<OperandVariable>(input)) {
@ -1028,11 +1083,18 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
var->name);
} else if (auto *operand = dyn_cast<OperandVariable>(&param)) {
const NamedTypeConstraint *var = operand->getVar();
if (!var->isOptional())
continue;
body << llvm::formatv(" if ({0}Operand.hasValue())\n"
" {0}Operands.push_back(*{0}Operand);\n",
var->name);
if (var->isOptional()) {
body << llvm::formatv(" if ({0}Operand.hasValue())\n"
" {0}Operands.push_back(*{0}Operand);\n",
var->name);
} else if (var->isVariadicOfVariadic()) {
body << llvm::formatv(
" for (const auto &subRange : {0}OperandGroups) {{\n"
" {0}Operands.append(subRange.begin(), subRange.end());\n"
" {0}OperandGroupSizes.push_back(subRange.size());\n"
" }\n",
var->name, var->constraint.getVariadicOfVariadicSegmentSizeAttr());
}
} else if (auto *dir = dyn_cast<TypeDirective>(&param)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
@ -1040,6 +1102,11 @@ static void genCustomDirectiveParser(CustomDirective *dir, OpMethodBody &body) {
body << llvm::formatv(" if ({0}Type)\n"
" {0}Types.push_back({0}Type);\n",
listName);
} else if (lengthKind == ArgumentLengthKind::VariadicOfVariadic) {
body << llvm::formatv(
" for (const auto &subRange : {0}TypeGroups)\n"
" {0}Types.append(subRange.begin(), subRange.end());\n",
listName);
}
}
}
@ -1229,7 +1296,11 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
ArgumentLengthKind lengthKind = getArgumentLengthKind(operand->getVar());
StringRef name = operand->getVar()->name;
if (lengthKind == ArgumentLengthKind::Variadic)
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
body << llvm::formatv(
variadicOfVariadicOperandParserCode, name,
operand->getVar()->constraint.getVariadicOfVariadicSegmentSizeAttr());
else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv(variadicOperandParserCode, name);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv(optionalOperandParserCode, name);
@ -1281,7 +1352,9 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body,
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
ArgumentLengthKind lengthKind;
StringRef listName = getTypeListName(dir->getOperand(), lengthKind);
if (lengthKind == ArgumentLengthKind::Variadic)
if (lengthKind == ArgumentLengthKind::VariadicOfVariadic)
body << llvm::formatv(variadicOfVariadicTypeParserCode, listName);
else if (lengthKind == ArgumentLengthKind::Variadic)
body << llvm::formatv(variadicTypeParserCode, listName);
else if (lengthKind == ArgumentLengthKind::Optional)
body << llvm::formatv(optionalTypeParserCode, listName);
@ -1501,19 +1574,29 @@ void OperationFormat::genParserSuccessorResolution(Operator &op,
void OperationFormat::genParserVariadicSegmentResolution(Operator &op,
OpMethodBody &body) {
if (!allOperands &&
op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
body << " result.addAttribute(\"operand_segment_sizes\", "
<< "parser.getBuilder().getI32VectorAttr({";
auto interleaveFn = [&](const NamedTypeConstraint &operand) {
// If the operand is variadic emit the parsed size.
if (operand.isVariableLength())
body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
else
body << "1";
};
llvm::interleaveComma(op.getOperands(), body, interleaveFn);
body << "}));\n";
if (!allOperands) {
if (op.getTrait("::mlir::OpTrait::AttrSizedOperandSegments")) {
body << " result.addAttribute(\"operand_segment_sizes\", "
<< "parser.getBuilder().getI32VectorAttr({";
auto interleaveFn = [&](const NamedTypeConstraint &operand) {
// If the operand is variadic emit the parsed size.
if (operand.isVariableLength())
body << "static_cast<int32_t>(" << operand.name << "Operands.size())";
else
body << "1";
};
llvm::interleaveComma(op.getOperands(), body, interleaveFn);
body << "}));\n";
}
for (const NamedTypeConstraint &operand : op.getOperands()) {
if (!operand.isVariadicOfVariadic())
continue;
body << llvm::formatv(
" result.addAttribute(\"{0}\", "
"parser.getBuilder().getI32TensorAttr({1}OperandGroupSizes));\n",
operand.constraint.getVariadicOfVariadicSegmentSizeAttr(),
operand.name);
}
}
if (!allResultTypes &&
@ -1575,6 +1658,10 @@ static void genAttrDictPrinter(OperationFormat &fmt, Operator &op,
if (!fmt.allResultTypes &&
op.getTrait("::mlir::OpTrait::AttrSizedResultSegments"))
body << "\"result_segment_sizes\", ";
if (!fmt.inferredAttributes.empty()) {
for (const auto &attr : fmt.inferredAttributes)
body << "\"" << attr.getKey() << "\", ";
}
llvm::interleaveComma(
fmt.usedAttributes, body,
[&](const NamedAttribute *attr) { body << "\"" << attr->name << "\""; });
@ -1693,6 +1780,8 @@ static OpMethodBody &genTypeOperandPrinter(Element *arg, OpMethodBody &body) {
return body << "getOperation()->getResultTypes()";
auto *operand = dyn_cast<OperandVariable>(arg);
auto *var = operand ? operand->getVar() : cast<ResultVariable>(arg)->getVar();
if (var->isVariadicOfVariadic())
return body << llvm::formatv("{0}().join().getTypes()", var->name);
if (var->isVariadic())
return body << var->name << "().getTypes()";
if (var->isOptional())
@ -1896,7 +1985,12 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
else
body << " p.printAttribute(" << var->name << "Attr());\n";
} else if (auto *operand = dyn_cast<OperandVariable>(element)) {
if (operand->getVar()->isOptional()) {
if (operand->getVar()->isVariadicOfVariadic()) {
body << " ::llvm::interleaveComma(" << operand->getVar()->name
<< "(), p, [&](const auto &operands) { p << \"(\" << operands << "
"\")\"; });\n";
} else if (operand->getVar()->isOptional()) {
body << " if (::mlir::Value value = " << operand->getVar()->name
<< "())\n"
<< " p << value;\n";
@ -1926,6 +2020,15 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body,
} else if (isa<SuccessorsDirective>(element)) {
body << " ::llvm::interleaveComma(getOperation()->getSuccessors(), p);\n";
} else if (auto *dir = dyn_cast<TypeDirective>(element)) {
if (auto *operand = dyn_cast<OperandVariable>(dir->getOperand())) {
if (operand->getVar()->isVariadicOfVariadic()) {
body << llvm::formatv(" ::llvm::interleaveComma({0}().getTypes(), p, "
"[&](::mlir::TypeRange types) {{ p << \"(\" << "
"types << \")\"; });\n",
operand->getVar()->name);
return;
}
}
body << " p << ";
genTypeOperandPrinter(dir->getOperand(), body) << ";\n";
} else if (auto *dir = dyn_cast<FunctionalTypeDirective>(element)) {
@ -2449,6 +2552,16 @@ LogicalResult FormatParser::verifyAttributes(llvm::SMLoc loc) {
while (!iteratorStack.empty())
if (failed(verifyAttributes(loc, iteratorStack)))
return ::mlir::failure();
// Check for VariadicOfVariadic variables. The segment attribute of those
// variables will be infered.
for (const NamedTypeConstraint *var : seenOperands) {
if (var->constraint.isVariadicOfVariadic()) {
fmt.inferredAttributes.insert(
var->constraint.getVariadicOfVariadicSegmentSizeAttr());
}
}
return ::mlir::success();
}
/// Verify the attribute elements at the back of the given stack of iterators.