Add a "kind" attribute to ContractionOp and OuterProductOp.

Currently, vector.contract joins the intermediate result and the accumulator
argument (of ranks K) using summation. We desire more joining operations ---
such as max --- to help vector.contract express reductions. This change extends
Vector_ContractionOp to take an optional attribute (called "kind", of enum type
CombiningKind) specifying the joining operation to be add/mul/min/max for int/fp
, and and/or/xor for int only. By default this attribute has value "add".

To implement this we also need to extend vector.outerproduct, since
vector.contract gets transformed to vector.outerproduct (and that to
vector.fma). The extension for vector.outerproduct is also an optional kind
attribute that uses the same enum type and possible values. The default is
"add". In case of max/min we transform vector.outerproduct to a combination of
compare and select.

Reviewed By: nicolasvasilache

Differential Revision: https://reviews.llvm.org/D93280
This commit is contained in:
Praveen Narayanan 2021-02-12 20:14:51 +00:00 committed by Mehdi Amini
parent 48fcce1aea
commit a65fb1916c
12 changed files with 469 additions and 68 deletions

View File

@ -1,2 +1,8 @@
add_mlir_dialect(VectorOps vector)
add_mlir_doc(VectorOps -gen-op-doc VectorOps Dialects/)
set(LLVM_TARGET_DEFINITIONS VectorOps.td)
mlir_tablegen(VectorOpsEnums.h.inc -gen-enum-decls)
mlir_tablegen(VectorOpsEnums.cpp.inc -gen-enum-defs)
add_public_tablegen_target(MLIRVectorOpsEnumsIncGen)
add_dependencies(mlir-headers MLIRVectorOpsEnumsIncGen)

View File

@ -21,11 +21,21 @@
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Interfaces/VectorInterfaces.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/StringExtras.h"
// Pull in all enum type definitions and utility function declarations.
#include "mlir/Dialect/Vector/VectorOpsEnums.h.inc"
namespace mlir {
class MLIRContext;
class OwningRewritePatternList;
namespace vector {
class VectorDialect;
namespace detail {
struct BitmaskEnumStorage;
} // namespace detail
/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(
@ -63,6 +73,22 @@ void populateBubbleVectorBitCastOpPatterns(OwningRewritePatternList &patterns,
void populateVectorSlicesLoweringPatterns(OwningRewritePatternList &patterns,
MLIRContext *context);
/// An attribute that specifies the combining function for `vector.contract`,
/// and `vector.reduction`.
class CombiningKindAttr
: public Attribute::AttrBase<CombiningKindAttr, Attribute,
detail::BitmaskEnumStorage> {
public:
using Base::Base;
static CombiningKindAttr get(CombiningKind kind, MLIRContext *context);
CombiningKind getKind() const;
void print(DialectAsmPrinter &p) const;
static Attribute parse(DialectAsmParser &parser);
};
/// Enum to control the lowering of `vector.contract` operations.
enum class VectorContractLowering {
/// Progressively lower to finer grained `vector.contract` and dot-products.

View File

@ -37,6 +37,35 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
let parser = [{ return ::parse$cppClass(parser, result); }];
}
// The "kind" of combining function for contractions and reductions.
def COMBINING_KIND_ADD : BitEnumAttrCase<"ADD", 0x1, "add">;
def COMBINING_KIND_MUL : BitEnumAttrCase<"MUL", 0x2, "mul">;
def COMBINING_KIND_MIN : BitEnumAttrCase<"MIN", 0x4, "min">;
def COMBINING_KIND_MAX : BitEnumAttrCase<"MAX", 0x8, "max">;
def COMBINING_KIND_AND : BitEnumAttrCase<"AND", 0x10, "and">;
def COMBINING_KIND_OR : BitEnumAttrCase<"OR", 0x20, "or">;
def COMBINING_KIND_XOR : BitEnumAttrCase<"XOR", 0x40, "xor">;
def CombiningKind : BitEnumAttr<
"CombiningKind",
"Kind of combining function for contractions and reductions",
[COMBINING_KIND_ADD, COMBINING_KIND_MUL, COMBINING_KIND_MIN,
COMBINING_KIND_MAX, COMBINING_KIND_AND, COMBINING_KIND_OR,
COMBINING_KIND_XOR]> {
let cppNamespace = "::mlir::vector";
}
def Vector_CombiningKindAttr : DialectAttr<
Vector_Dialect,
CPred<"$_self.isa<::mlir::vector::CombiningKindAttr>()">,
"Kind of combining function for contractions and reductions"> {
let storageType = "::mlir::vector::CombiningKindAttr";
let returnType = "::mlir::vector::CombiningKind";
let convertFromStorage = "$_self.getKind()";
let constBuilderCall =
"::mlir::vector::CombiningKindAttr::get($0, $_builder.getContext())";
}
// TODO: Add an attribute to specify a different algebra with operators other
// than the current set: {*, +}.
def Vector_ContractionOp :
@ -49,7 +78,9 @@ def Vector_ContractionOp :
]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
Variadic<VectorOf<[I1]>>:$masks,
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types,
DefaultValuedAttr<Vector_CombiningKindAttr,
"CombiningKind::ADD">:$kind)>,
Results<(outs AnyType)> {
let summary = "vector contraction operation";
let description = [{
@ -88,6 +119,11 @@ def Vector_ContractionOp :
and acc arguments. An indexing map attribute specifies a mapping from each
iterator in the iterator type list, to each dimension of an N-D vector.
An optional kind attribute may be used to specify the combining function
between the intermediate result and accumulator argument of rank K. This
attribute can take the values add/mul/min/max for int/fp, and/or/xor for
int only. The default is "add".
Example:
```mlir
@ -146,6 +182,20 @@ def Vector_ContractionOp :
// types than accumulator/result.
%6 = vector.contract #contraction_trait %0, %1, %2
: vector<10xf16>, vector<10xf16> into f32
// Contract with max (K = 0).
#contraction_accesses = [
affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>,
affine_map<(i) -> ()>
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["reduction"],
kind = #vector.kind<max>
}
%7 = vector.contract #contraction_trait %0, %1, %2
: vector<10xf32>, vector<10xf32> into f32
```
}];
let builders = [
@ -189,6 +239,12 @@ def Vector_ContractionOp :
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
static constexpr StringRef getKindAttrName() { return "kind"; }
static CombiningKind getDefaultKind() {
return CombiningKind::ADD;
}
}];
}
@ -820,7 +876,9 @@ def Vector_OuterProductOp :
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"rhs operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>]>,
Arguments<(ins AnyVector:$lhs, AnyType:$rhs, Variadic<AnyVector>:$acc)>,
Arguments<(ins AnyVector:$lhs, AnyType:$rhs,
Variadic<AnyVector>:$acc,
DefaultValuedAttr<Vector_CombiningKindAttr, "CombiningKind::ADD">:$kind)>,
Results<(outs AnyVector)> {
let summary = "vector outerproduct with optional fused add";
let description = [{
@ -846,6 +904,12 @@ def Vector_OuterProductOp :
lowered to the LLVMIR dialect, this form emits `llvm.intr.fma`, which
is guaranteed to lower to actual `fma` instructions on x86.
An optional kind attribute may be specified to be add/mul/min/max
for int/fp, and and/or/xor for int only. The default is "add", in which
case the operation returns a fused multiply-add. In other cases it returns
a multiply followed by the appropriate operation (for example, a compare and
select for "max").
Example:
```
@ -856,6 +920,10 @@ def Vector_OuterProductOp :
vector<4xf32>, vector<8xf32>, vector<4x8xf32>
return %3: vector<4x8xf32>
%4 = vector.outerproduct %0, %1, %2 {kind = #vector.kind<max>}:
vector<4xf32>, vector<8xf32>, vector<4x8xf32>
return %3: vector<4x8xf32>
%6 = vector.outerproduct %4, %5: vector<10xf32>, f32
return %6: vector<10xf32>
@ -880,6 +948,12 @@ def Vector_OuterProductOp :
VectorType getVectorType() {
return getResult().getType().cast<VectorType>();
}
static constexpr StringRef getKindAttrName() {
return "kind";
}
static CombiningKind getDefaultKind() {
return CombiningKind::ADD;
}
}];
}

View File

@ -1131,9 +1131,9 @@ class I64EnumAttrCase<string sym, int val, string str = sym>
// A bit enum case stored with 32-bit IntegerAttr. `val` here is *not* the
// ordinal number of the bit that is set. It is the 32-bit integer with only
// one bit set.
class BitEnumAttrCase<string sym, int val> :
EnumAttrCaseInfo<sym, val, sym>,
SignlessIntegerAttrBase<I32, "case " # sym> {
class BitEnumAttrCase<string sym, int val, string str = sym> :
EnumAttrCaseInfo<sym, val, str>,
SignlessIntegerAttrBase<I32, "case " # str> {
let predicate = CPred<
"$_self.cast<::mlir::IntegerAttr>().getValue().getZExtValue() & "
# val # "u">;

View File

@ -10,6 +10,7 @@ add_mlir_dialect_library(MLIRVector
DEPENDS
MLIRVectorOpsIncGen
MLIRVectorOpsEnumsIncGen
LINK_LIBS PUBLIC
MLIRAffineEDSC

View File

@ -19,6 +19,7 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/DialectImplementation.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
@ -28,6 +29,9 @@
#include "llvm/ADT/bit.h"
#include <numeric>
// Pull in all enum type and utility function definitions.
#include "mlir/Dialect/Vector/VectorOpsEnums.cpp.inc"
using namespace mlir;
using namespace mlir::vector;
@ -77,11 +81,30 @@ static MaskFormat get1DMaskFormat(Value mask) {
return MaskFormat::Unknown;
}
// Helper for verifying combining kinds in contractions and reductions.
static bool isSupportedCombiningKind(CombiningKind combiningKind,
Type elementType) {
switch (combiningKind) {
case CombiningKind::ADD:
case CombiningKind::MUL:
case CombiningKind::MIN:
case CombiningKind::MAX:
return elementType.isIntOrIndexOrFloat();
case CombiningKind::AND:
case CombiningKind::OR:
case CombiningKind::XOR:
return elementType.isIntOrIndex();
}
return false;
}
//===----------------------------------------------------------------------===//
// VectorDialect
//===----------------------------------------------------------------------===//
void VectorDialect::initialize() {
addAttributes<CombiningKindAttr>();
addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/Vector/VectorOps.cpp.inc"
@ -105,6 +128,106 @@ ArrayAttr vector::getVectorSubscriptAttr(Builder &builder,
return builder.getI64ArrayAttr(values);
}
//===----------------------------------------------------------------------===//
// CombiningKindAttr
//===----------------------------------------------------------------------===//
namespace mlir {
namespace vector {
namespace detail {
struct BitmaskEnumStorage : public AttributeStorage {
using KeyTy = uint64_t;
BitmaskEnumStorage(KeyTy val) : value(val) {}
bool operator==(const KeyTy &key) const { return value == key; }
static BitmaskEnumStorage *construct(AttributeStorageAllocator &allocator,
const KeyTy &key) {
return new (allocator.allocate<BitmaskEnumStorage>())
BitmaskEnumStorage(key);
}
KeyTy value = 0;
};
} // namespace detail
} // namespace vector
} // namespace mlir
CombiningKindAttr CombiningKindAttr::get(CombiningKind kind,
MLIRContext *context) {
return Base::get(context, static_cast<uint64_t>(kind));
}
CombiningKind CombiningKindAttr::getKind() const {
return static_cast<CombiningKind>(getImpl()->value);
}
static constexpr const CombiningKind combiningKindsList[] = {
// clang-format off
CombiningKind::ADD,
CombiningKind::MUL,
CombiningKind::MIN,
CombiningKind::MAX,
CombiningKind::AND,
CombiningKind::OR,
CombiningKind::XOR,
// clang-format on
};
void CombiningKindAttr::print(DialectAsmPrinter &printer) const {
printer << "kind<";
auto kinds = llvm::make_filter_range(combiningKindsList, [&](auto kind) {
return bitEnumContains(this->getKind(), kind);
});
llvm::interleaveComma(kinds, printer,
[&](auto kind) { printer << stringifyEnum(kind); });
printer << ">";
}
Attribute CombiningKindAttr::parse(DialectAsmParser &parser) {
if (failed(parser.parseLess()))
return {};
StringRef elemName;
if (failed(parser.parseKeyword(&elemName)))
return {};
auto kind = symbolizeCombiningKind(elemName);
if (!kind) {
parser.emitError(parser.getNameLoc(), "Unknown combining kind: ")
<< elemName;
return {};
}
if (failed(parser.parseGreater()))
return {};
return CombiningKindAttr::get(kind.getValue(),
parser.getBuilder().getContext());
}
Attribute VectorDialect::parseAttribute(DialectAsmParser &parser,
Type type) const {
StringRef attrKind;
if (parser.parseKeyword(&attrKind))
return {};
if (attrKind == "kind")
return CombiningKindAttr::parse(parser);
parser.emitError(parser.getNameLoc(), "Unknown attribute type: ") << attrKind;
return {};
}
void VectorDialect::printAttribute(Attribute attr,
DialectAsmPrinter &os) const {
if (auto ck = attr.dyn_cast<CombiningKindAttr>())
ck.print(os);
else
llvm_unreachable("Unknown attribute type");
}
//===----------------------------------------------------------------------===//
// ReductionOp
//===----------------------------------------------------------------------===//
@ -193,6 +316,9 @@ void vector::ContractionOp::build(OpBuilder &builder, OperationState &result,
result.addTypes(acc.getType());
result.addAttribute(getIndexingMapsAttrName(), indexingMaps);
result.addAttribute(getIteratorTypesAttrName(), iteratorTypes);
result.addAttribute(ContractionOp::getKindAttrName(),
CombiningKindAttr::get(ContractionOp::getDefaultKind(),
builder.getContext()));
}
static ParseResult parseContractionOp(OpAsmParser &parser,
@ -221,6 +347,11 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
return failure();
result.attributes.assign(dictAttr.getValue().begin(),
dictAttr.getValue().end());
if (!result.attributes.get(ContractionOp::getKindAttrName())) {
result.addAttribute(ContractionOp::getKindAttrName(),
CombiningKindAttr::get(ContractionOp::getDefaultKind(),
result.getContext()));
}
if (masksInfo.empty())
return success();
if (masksInfo.size() != 2)
@ -421,12 +552,20 @@ static LogicalResult verify(ContractionOp op) {
rhsMaskType.getShape().size() != rhsType.getShape().size())
return op.emitOpError("invalid vector mask rank");
}
// Verify supported combining kind.
auto vectorType = resType.dyn_cast<VectorType>();
auto elementType = vectorType ? vectorType.getElementType() : resType;
if (!isSupportedCombiningKind(op.kind(), elementType))
return op.emitOpError("unsupported contraction type");
return success();
}
ArrayRef<StringRef> ContractionOp::getTraitAttrNames() {
static constexpr StringRef names[2] = {getIndexingMapsAttrName(),
getIteratorTypesAttrName()};
static constexpr StringRef names[3] = {getIndexingMapsAttrName(),
getIteratorTypesAttrName(),
ContractionOp::getKindAttrName()};
return llvm::makeArrayRef(names);
}
@ -1497,8 +1636,10 @@ void OuterProductOp::build(OpBuilder &builder, OperationState &result,
static void print(OpAsmPrinter &p, OuterProductOp op) {
p << op.getOperationName() << " " << op.lhs() << ", " << op.rhs();
if (!op.acc().empty())
if (!op.acc().empty()) {
p << ", " << op.acc();
p.printOptionalAttrDict(op.getAttrs());
}
p << " : " << op.lhs().getType() << ", " << op.rhs().getType();
}
@ -1506,8 +1647,10 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser,
OperationState &result) {
SmallVector<OpAsmParser::OperandType, 3> operandsInfo;
Type tLHS, tRHS;
if (parser.parseOperandList(operandsInfo) || parser.parseColonType(tLHS) ||
parser.parseComma() || parser.parseType(tRHS))
if (parser.parseOperandList(operandsInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(tLHS) || parser.parseComma() ||
parser.parseType(tRHS))
return failure();
if (operandsInfo.size() < 2)
return parser.emitError(parser.getNameLoc(),
@ -1521,6 +1664,14 @@ static ParseResult parseOuterProductOp(OpAsmParser &parser,
vRHS ? VectorType::get({vLHS.getDimSize(0), vRHS.getDimSize(0)},
vLHS.getElementType())
: VectorType::get({vLHS.getDimSize(0)}, vLHS.getElementType());
if (!result.attributes.get(OuterProductOp::getKindAttrName())) {
result.attributes.append(
OuterProductOp::getKindAttrName(),
CombiningKindAttr::get(OuterProductOp::getDefaultKind(),
result.getContext()));
}
return failure(
parser.resolveOperand(operandsInfo[0], tLHS, result.operands) ||
parser.resolveOperand(operandsInfo[1], tRHS, result.operands) ||
@ -1558,6 +1709,11 @@ static LogicalResult verify(OuterProductOp op) {
if (vACC && vACC != vRES)
return op.emitOpError("expected operand #3 of same type as result type");
// Verify supported combining kind.
if (!isSupportedCombiningKind(op.kind(), vRES.getElementType()))
return op.emitOpError("unsupported outerproduct type");
return success();
}

View File

@ -1354,11 +1354,17 @@ public:
Type eltType = resType.getElementType();
bool isInt = eltType.isa<IntegerType>();
Value acc = (op.acc().empty()) ? nullptr : op.acc()[0];
vector::CombiningKind kind = op.kind();
if (!rhsType) {
// Special case: AXPY operation.
Value b = rewriter.create<vector::BroadcastOp>(loc, lhsType, op.rhs());
rewriter.replaceOp(op, genMult(loc, op.lhs(), b, acc, isInt, rewriter));
Optional<Value> mult =
isInt ? genMultI(loc, op.lhs(), b, acc, kind, rewriter)
: genMultF(loc, op.lhs(), b, acc, kind, rewriter);
if (!mult.hasValue())
return failure();
rewriter.replaceOp(op, mult.getValue());
return success();
}
@ -1371,25 +1377,95 @@ public:
Value r = nullptr;
if (acc)
r = rewriter.create<vector::ExtractOp>(loc, rhsType, acc, pos);
Value m = genMult(loc, a, op.rhs(), r, isInt, rewriter);
result = rewriter.create<vector::InsertOp>(loc, resType, m, result, pos);
Optional<Value> m = isInt ? genMultI(loc, a, op.rhs(), r, kind, rewriter)
: genMultF(loc, a, op.rhs(), r, kind, rewriter);
if (!m.hasValue())
return failure();
result = rewriter.create<vector::InsertOp>(loc, resType, m.getValue(),
result, pos);
}
rewriter.replaceOp(op, result);
return success();
}
private:
static Value genMult(Location loc, Value x, Value y, Value acc, bool isInt,
PatternRewriter &rewriter) {
if (acc) {
if (isInt)
return rewriter.create<AddIOp>(loc, rewriter.create<MulIOp>(loc, x, y),
acc);
return rewriter.create<vector::FMAOp>(loc, x, y, acc);
static Optional<Value> genMultI(Location loc, Value x, Value y, Value acc,
vector::CombiningKind kind,
PatternRewriter &rewriter) {
using vector::CombiningKind;
MulIOp mul = rewriter.create<MulIOp>(loc, x, y);
if (!acc)
return Optional<Value>(mul);
Value combinedResult;
switch (kind) {
case CombiningKind::ADD:
combinedResult = rewriter.create<AddIOp>(loc, mul, acc);
break;
case CombiningKind::MUL:
combinedResult = rewriter.create<MulIOp>(loc, mul, acc);
break;
case CombiningKind::MIN:
combinedResult = rewriter.create<SelectOp>(
loc, rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, mul, acc), mul,
acc);
break;
case CombiningKind::MAX:
combinedResult = rewriter.create<SelectOp>(
loc, rewriter.create<CmpIOp>(loc, CmpIPredicate::sge, mul, acc), mul,
acc);
break;
case CombiningKind::AND:
combinedResult = rewriter.create<AndOp>(loc, mul, acc);
break;
case CombiningKind::OR:
combinedResult = rewriter.create<OrOp>(loc, mul, acc);
break;
case CombiningKind::XOR:
combinedResult = rewriter.create<XOrOp>(loc, mul, acc);
break;
}
if (isInt)
return rewriter.create<MulIOp>(loc, x, y);
return rewriter.create<MulFOp>(loc, x, y);
return Optional<Value>(combinedResult);
}
static Optional<Value> genMultF(Location loc, Value x, Value y, Value acc,
vector::CombiningKind kind,
PatternRewriter &rewriter) {
using vector::CombiningKind;
// Special case for fused multiply-add.
if (acc && kind == CombiningKind::ADD) {
return Optional<Value>(rewriter.create<vector::FMAOp>(loc, x, y, acc));
}
MulFOp mul = rewriter.create<MulFOp>(loc, x, y);
if (!acc)
return Optional<Value>(mul);
Value combinedResult;
switch (kind) {
case CombiningKind::MUL:
combinedResult = rewriter.create<MulFOp>(loc, mul, acc);
break;
case CombiningKind::MIN:
combinedResult = rewriter.create<SelectOp>(
loc, rewriter.create<CmpFOp>(loc, CmpFPredicate::OLE, mul, acc), mul,
acc);
break;
case CombiningKind::MAX:
combinedResult = rewriter.create<SelectOp>(
loc, rewriter.create<CmpFOp>(loc, CmpFPredicate::OGT, mul, acc), mul,
acc);
break;
case CombiningKind::ADD: // Already handled this special case above.
case CombiningKind::AND: // Only valid for integer types.
case CombiningKind::OR: // Only valid for integer types.
case CombiningKind::XOR: // Only valid for integer types.
return Optional<Value>();
}
return Optional<Value>(combinedResult);
}
};
@ -1804,7 +1880,8 @@ LogicalResult ContractionOpToOuterProductOpLowering::matchAndRewrite(
for (int64_t k = 0; k < reductionSize; ++k) {
Value a = rewriter.create<vector::ExtractOp>(op.getLoc(), lhs, k);
Value b = rewriter.create<vector::ExtractOp>(op.getLoc(), rhs, k);
res = rewriter.create<vector::OuterProductOp>(op.getLoc(), a, b, res);
res = rewriter.create<vector::OuterProductOp>(op.getLoc(), res.getType(), a,
b, res, op.kind());
}
rewriter.replaceOp(op, res);
return success();

View File

@ -355,7 +355,7 @@ func @matmul_tensors(
//
// linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
// a later canonicalization fuses the add into vector.contract.
// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0]], %[[V1]], %[[VEC_C0]] : vector<8x4xf32>, vector<4x12xf32> into vector<8x12xf32>
// CHECK: %[[C2:.*]] = addf %[[V2]], %[[C]] : vector<8x12xf32>
// CHECK: %[[W:.*]] = vector.transfer_write %[[C2]], %[[ARG2]][%[[C0]], %[[C0]]] {masked = [false, false]} : vector<8x12xf32>, tensor<8x12xf32>
%0 = linalg.matmul ins(%arg0, %arg1: tensor<8x4xf32>, tensor<4x12xf32>)
@ -380,7 +380,7 @@ func @matmul_i8_i8_i32(%a: memref<4x6xi8>, %b: memref<6x12xi8>, %c: memref<4x12x
//
// linalg contraction lowers to %tmp = vector.contract %a, %b, %c0 followed by addf %c, %tmp.
// a later canonicalization fuses the add into vector.contract.
// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"]} %[[V0]], %[[V1]], %[[VEC_C0]]
// CHECK: %[[C:.*]] = vector.contract {{.*}} iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[V0]], %[[V1]], %[[VEC_C0]]
// CHECK-SAME: vector<4x6xi8>, vector<6x12xi8> into vector<4x12xi8>
// CHECK: %[[C32:.*]] = sexti %[[C]] : vector<4x12xi8> to vector<4x12xi32>
// CHECK: %[[RES:.*]] = addi %[[V2]], %[[C32]] : vector<4x12xi32>

View File

@ -198,13 +198,34 @@ func @extract_strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
// CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32
%f0 = constant 0.0: f32
// CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"]} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32
// CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<add>} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32
%0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0
: vector<10xf32>, vector<10xf32> into f32
// CHECK: return %[[X]] : f32
return %0 : f32
}
#contraction_to_scalar_max_accesses = [
affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>,
affine_map<(i) -> ()>
]
#contraction_to_scalar_max_trait = {
indexing_maps = #contraction_to_scalar_max_accesses,
iterator_types = ["reduction"],
kind = #vector.kind<max>
}
// CHECK-LABEL: @contraction_to_scalar_with_max
func @contraction_to_scalar_with_max(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
// CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32
%f0 = constant 0.0: f32
// CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"], kind = #vector.kind<max>} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32
%0 = vector.contract #contraction_to_scalar_max_trait %arg0, %arg1, %f0
: vector<10xf32>, vector<10xf32> into f32
// CHECK: return %[[X]] : f32
return %0 : f32
}
#contraction_accesses0 = [
affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>,
affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>,
@ -221,36 +242,46 @@ func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32
// 8, 8, 15, 5
affine_map<(f0, f1, f2, f3, c0, c1) -> (f0, f1, f2, f3)>
]
#iterator_types1 = ["parallel", "parallel", "parallel", "parallel", "reduction",
"reduction"]
#contraction_trait1 = {
indexing_maps = #contraction_accesses1,
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction",
"reduction"]
iterator_types = #iterator_types1
}
#contraction_trait2 = {
indexing_maps = #contraction_accesses1,
iterator_types = #iterator_types1,
kind = #vector.kind<max>
}
// CHECK-LABEL: @contraction
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
%arg2 : vector<8x15x5xf32>, %arg3 : vector<8x8x15x5xf32>,
%arg4 : vector<7x8x16x15xf16>, %arg5 : vector<8x16x7x5xf16>) {
// Test contraction with batch and contracting dims.
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
%0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
// Test contraction with only contracting dims. In this case the lhs/rhs
// dimension of size 8 will be considered a parallel dim for lhs/rhs and will
// appear twice in the output.
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
%1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
// Test contraction with optional vector mask arguments.
%lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
%rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
%2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
%rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
// Test contraction with mixed type.
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<add>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
%3 = vector.contract #contraction_trait1 %arg4, %arg5, %arg3
: vector<7x8x16x15xf16>, vector<8x16x7x5xf16> into vector<8x8x15x5xf32>
// Test contraction with "max" instead of "add".
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"], kind = #vector.kind<max>} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
%4 = vector.contract #contraction_trait2 %arg0, %arg1, %arg3
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x8x15x5xf32>
return
}

View File

@ -9,6 +9,11 @@
indexing_maps = #matvec_accesses,
iterator_types = ["parallel", "reduction"]
}
#matvecmax_trait = {
indexing_maps = #matvec_accesses,
iterator_types = ["parallel", "reduction"],
kind = #vector.kind<max>
}
#mattransvec_accesses = [
affine_map<(i, j) -> (j, i)>,
@ -50,10 +55,10 @@
// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: store %[[T9]], %[[C]][] : memref<vector<2xf32>>
// CHECK: return
func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
@ -66,6 +71,32 @@ func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
return
}
// CHECK-LABEL: func @matvecmax2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
// CHECK-SAME: %[[C:.*2]]: memref<vector<2xf32>>
// CHECK: %[[T0:.*]] = load %[[A]][] : memref<vector<2x2xf32>>
// CHECK: %[[T1:.*]] = load %[[B]][] : memref<vector<2xf32>>
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<max>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<max>} : vector<2xf32>, f32
// CHECK: store %[[T9]], %[[C]][] : memref<vector<2xf32>>
// CHECK: return
func @matvecmax2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
%arg2: memref<vector<2xf32>>) {
%A = load %arg0[] : memref<vector<2x2xf32>>
%x = load %arg1[] : memref<vector<2xf32>>
%b = load %arg2[] : memref<vector<2xf32>>
%0 = vector.contract #matvecmax_trait %A, %x, %b : vector<2x2xf32>, vector<2xf32> into vector<2xf32>
store %0, %arg2[] : memref<vector<2xf32>>
return
}
// CHECK-LABEL: func @mattransvec2x2
// CHECK-SAME: %[[A:.*0]]: memref<vector<2x2xf32>>
// CHECK-SAME: %[[B:.*1]]: memref<vector<2xf32>>
@ -75,10 +106,10 @@ func @matvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: store %[[T8]], %[[C]][] : memref<vector<2xf32>>
// CHECK: return
func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
@ -101,10 +132,10 @@ func @mattransvec2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>
// CHECK: %[[T3:.*]] = vector.transpose %[[T0]], [1, 0] : vector<2x2xf32> to vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T3]][0] : vector<2x2xf32>
// CHECK: %[[T5:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.outerproduct %[[T4]], %[[T5]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T7:.*]] = vector.extract %[[T3]][1] : vector<2x2xf32>
// CHECK: %[[T8:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] : vector<2xf32>, f32
// CHECK: %[[T9:.*]] = vector.outerproduct %[[T7]], %[[T8]], %[[T6]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: store %[[T9]], %[[C]][] : memref<vector<2xf32>>
// CHECK: return
func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
@ -126,10 +157,10 @@ func @vecmat2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,
// CHECK: %[[T2:.*]] = load %[[C]][] : memref<vector<2xf32>>
// CHECK: %[[T3:.*]] = vector.extract %[[T0]][0] : vector<2x2xf32>
// CHECK: %[[T4:.*]] = vector.extract %[[T1]][0] : vector<2xf32>
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] : vector<2xf32>, f32
// CHECK: %[[T5:.*]] = vector.outerproduct %[[T3]], %[[T4]], %[[T2]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: %[[T6:.*]] = vector.extract %[[T0]][1] : vector<2x2xf32>
// CHECK: %[[T7:.*]] = vector.extract %[[T1]][1] : vector<2xf32>
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] : vector<2xf32>, f32
// CHECK: %[[T8:.*]] = vector.outerproduct %[[T6]], %[[T7]], %[[T5]] {kind = #vector.kind<add>} : vector<2xf32>, f32
// CHECK: store %[[T8]], %[[C]][] : memref<vector<2xf32>>
// CHECK: return
func @vecmattrans2x2(%arg0: memref<vector<2x2xf32>>, %arg1: memref<vector<2xf32>>,

View File

@ -92,56 +92,56 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES2]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES5]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG6]], %[[TG7]], %[[R1S00]], %[[TG8]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES1]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES2]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES4]], 2 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[TG13:.*]] = vector.tuple_get %[[ES5]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG10]], %[[TG11]], %[[R2S00]], %[[TG12]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [0, 2]
// CHECK-NEXT: %[[TG14:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG15:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG16:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG1]], %[[TG14]], %[[TG15]], %[[TG4]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[TG17:.*]] = vector.tuple_get %[[ES2]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG18:.*]] = vector.tuple_get %[[ES5]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG6]], %[[TG17]], %[[R1S02]], %[[TG8]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[TG19:.*]] = vector.tuple_get %[[ES2]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG20:.*]] = vector.tuple_get %[[ES5]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG10]], %[[TG19]], %[[R2S02]], %[[TG12]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [2, 0]
// CHECK-NEXT: %[[TG21:.*]] = vector.tuple_get %[[ES1]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG22:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG23:.*]] = vector.tuple_get %[[ES4]], 3 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG21]], %[[TG2]], %[[TG22]], %[[TG23]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[TG24:.*]] = vector.tuple_get %[[ES1]], 4 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG25:.*]] = vector.tuple_get %[[ES4]], 4 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG24]], %[[TG7]], %[[R1S20]], %[[TG25]], %[[TG9]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[TG26:.*]] = vector.tuple_get %[[ES1]], 5 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG27:.*]] = vector.tuple_get %[[ES4]], 5 : tuple<vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG26]], %[[TG11]], %[[R2S20]], %[[TG27]], %[[TG13]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [2, 2]
// CHECK-NEXT: %[[TG28:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG21]], %[[TG14]], %[[TG28]], %[[TG23]], %[[TG16]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG24]], %[[TG17]], %[[R1S22]], %[[TG25]], %[[TG18]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3S22:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"], kind = #vector.kind<add>} %[[TG26]], %[[TG19]], %[[R2S22]], %[[TG27]], %[[TG20]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R3S00]], %[[R3S02]], %[[R3S20]], %[[R3S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
@ -187,26 +187,26 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
// CHECK-NEXT: %[[TG3:.*]] = vector.tuple_get %[[ES3]], 0 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG4:.*]] = vector.tuple_get %[[ES4]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[TG5:.*]] = vector.tuple_get %[[ES5]], 0 : tuple<vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[TG1]], %[[TG2]], %[[TG3]], %[[TG4]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [0, 2]
// CHECK-NEXT: %[[TG6:.*]] = vector.tuple_get %[[ES2]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG7:.*]] = vector.tuple_get %[[ES3]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG8:.*]] = vector.tuple_get %[[ES5]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[TG1]], %[[TG6]], %[[TG7]], %[[TG4]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [2, 0]
// CHECK-NEXT: %[[TG9:.*]] = vector.tuple_get %[[ES1]], 1 : tuple<vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG10:.*]] = vector.tuple_get %[[ES3]], 2 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[TG11:.*]] = vector.tuple_get %[[ES4]], 1 : tuple<vector<2x2xi1>, vector<2x2xi1>>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[TG9]], %[[TG2]], %[[TG10]], %[[TG11]], %[[TG5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [2, 2]
// CHECK-NEXT: %[[TG12:.*]] = vector.tuple_get %[[ES3]], 3 : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>>
// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1S22:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[TG9]], %[[TG6]], %[[TG12]], %[[TG11]], %[[TG8]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[RES0:.*]] = vector.tuple %[[R1S00]], %[[R1S02]], %[[R1S20]], %[[R1S22]] : vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[RES1:.*]] = vector.insert_slices %[[RES0]], [2, 2], [1, 1] : tuple<vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>, vector<2x2xf32>> into vector<4x4xf32>
@ -241,10 +241,10 @@ func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : memref<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
// CHECK-NEXT: vector.transfer_write %[[R1]], %{{.*}}[%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, memref<4x4xf32>
@ -572,10 +572,10 @@ func @elementwise_unroll(%arg0 : memref<4x4xf32>, %arg1 : memref<4x4xf32>) {
// CHECK-NEXT: %[[VTR6:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C0]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[VTR7:.*]] = vector.transfer_read %{{.*}}[%[[C2]], %[[C2]]], %{{.*}} : tensor<4x4xf32>, vector<2x2xf32>
// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R0:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR2]], %[[VTR4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R1:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR0]], %[[VTR3]], %[[VTR5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R2:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR2]], %[[VTR6]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[R3:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"], kind = #vector.kind<add>} %[[VTR1]], %[[VTR3]], %[[VTR7]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[VTW0:.*]] = vector.transfer_write %[[R0]], %{{.*}}[%[[C0]], %[[C0]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>
// CHECK-NEXT: %[[VTW1:.*]] = vector.transfer_write %[[R1]], %[[VTW0]][%[[C0]], %[[C2]]] {masked = [false, false]} : vector<2x2xf32>, tensor<4x4xf32>

View File

@ -199,7 +199,7 @@ static void emitSymToStrFnForBitEnum(const Record &enumDef, raw_ostream &os) {
if (auto val = enumerant.getValue())
os << formatv(" if ({0}u & val) {{ strs.push_back(\"{1}\"); "
"val &= ~{0}u; }\n",
val, enumerant.getSymbol());
val, enumerant.getStr());
}
// If we have unknown bit set, return an empty string to signal errors.
os << "\n if (val) return \"\";\n";
@ -261,8 +261,7 @@ static void emitStrToSymFnForBitEnum(const Record &enumDef, raw_ostream &os) {
for (const auto &enumerant : enumerants) {
// Skip the special enumerant for None.
if (auto val = enumerant.getValue())
os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getSymbol(),
val);
os.indent(6) << formatv(".Case(\"{0}\", {1})\n", enumerant.getStr(), val);
}
os.indent(6) << ".Default(::llvm::None);\n";