Unroll vector masks along with their associated vector arguments.

Updates vector ContractionOp to use proper vector masks (produced by CreateMaskOp/ConstantMaskOp).
Leverages the following canonicalizations in the unrolling unit test: CreateMaskOp -> ConstantMaskOp, StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
Removes IndexTupleOp (no longer needed now that we have vector mask ops).
Updates all unit tests.

PiperOrigin-RevId: 284182168
Author: Andy Davis, 2019-12-06 07:36:55 -08:00; committed by A. Unique TensorFlower
Commit 41f8e105fa (parent 9ca53130f3)
7 changed files with 86 additions and 114 deletions
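For reference, a minimal sketch of the two folds the unrolling test leverages (illustrative IR, not taken from this commit; assumes the standard-dialect `constant` op of this era):

// CreateMaskOp with statically known sizes folds to ConstantMaskOp.
%c2 = constant 2 : index
%m = vector.create_mask %c2, %c2 : vector<4x4xi1>
// ---> %m = vector.constant_mask [2, 2] : vector<4x4xi1>
// StridedSliceOp of a ConstantMaskOp folds to a ConstantMaskOp of the sliced shape.
%s = vector.strided_slice %m {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]}
    : vector<4x4xi1> to vector<2x2xi1>
// ---> %s = vector.constant_mask [2, 2] : vector<2x2xi1>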


@@ -28,6 +28,8 @@
#include "mlir/IR/StandardTypes.h"
namespace mlir {
class MLIRContext;
class OwningRewritePatternList;
namespace vector {
/// Dialect for Ops on higher-dimensional vector types.
@@ -37,6 +39,10 @@ public:
static StringRef getDialectNamespace() { return "vector"; }
};
/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context);
#define GET_OP_CLASSES
#include "mlir/Dialect/VectorOps/VectorOps.h.inc"


@@ -49,7 +49,7 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
def Vector_ContractionOp :
Vector_Op<"contract", [NoSideEffect]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
Variadic<TupleOf<[Index]>>:$masks,
Variadic<VectorOf<[I1]>>:$masks,
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
Results<(outs AnyVector)> {
let summary = "vector contraction operation";
@@ -60,8 +60,9 @@ def Vector_ContractionOp :
vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims +
num_batch_dims (see dimension type descriptions below)).
Optional vector mask arguments specify the dynamic dimension sizes of
valid data within the lhs/rhs vector arguments.
Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp)
specify the dynamic dimension sizes of valid data within the lhs/rhs vector
arguments.
An iterator type attribute list must be specified, where each element of
the list represents an iterator with one of the following types:
@@ -120,10 +121,8 @@ def Vector_ContractionOp :
// 4D vector contraction with two contracting dimensions and optional
// vector mask arguments.
%lhs_mask = vector.make_tuple %size0, %size1, %size2, %size3
: tuple<index, index, index, index>
%rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7
: tuple<index, index, index, index>
%lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
%rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
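Were the valid sizes dynamic, equivalent masks could be built with CreateMaskOp instead; a hypothetical sketch reusing the %size operands from the old example:

%lhs_mask = vector.create_mask %size0, %size1, %size2, %size3
    : vector<7x8x16x15xi1>
%rhs_mask = vector.create_mask %size4, %size5, %size6, %size7
    : vector<8x16x7x5xi1>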
@@ -138,13 +137,13 @@ def Vector_ContractionOp :
VectorType getAccType() {
return acc()->getType().cast<VectorType>();
}
TupleType getLHSVectorMaskType() {
if (llvm::size(masks()) != 2) return TupleType();
return getOperand(3)->getType().cast<TupleType>();
VectorType getLHSVectorMaskType() {
if (llvm::size(masks()) != 2) return VectorType();
return getOperand(3)->getType().cast<VectorType>();
}
TupleType getRHSVectorMaskType() {
if (llvm::size(masks()) != 2) return TupleType();
return getOperand(4)->getType().cast<TupleType>();
VectorType getRHSVectorMaskType() {
if (llvm::size(masks()) != 2) return VectorType();
return getOperand(4)->getType().cast<VectorType>();
}
VectorType getResultType() {
return getResult()->getType().cast<VectorType>();
@@ -706,20 +705,4 @@ def Vector_CreateMaskOp :
let hasCanonicalizer = 1;
}
// TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask
def Vector_IndexTupleOp :
Vector_Op<"make_index_tuple", [NoSideEffect]>,
Arguments<(ins Variadic<Index>:$operands)>,
Results<(outs TupleOf<[Index]>)> {
let summary = "creates a tuple of operand values";
let description = [{
Creates and returns a tuple of its operands which must be of index type.
Example:
%1 = vector.make_index_tuple %size0, %size1, %size2
: tuple<index, index, index>
}];
}
#endif // VECTOR_OPS


@@ -82,16 +82,12 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
if (masksInfo.size() != 2)
return parser.emitError(parser.getNameLoc(),
"expected zero or exactly 2 vector mask operands");
auto indexType = parser.getBuilder().getIndexType();
auto lhsType = types[0].cast<VectorType>();
auto rhsType = types[1].cast<VectorType>();
auto maskElementType = parser.getBuilder().getI1Type();
SmallVector<Type, 2> maskTypes;
SmallVector<Type, 4> lhsMaskElementTypes(lhsType.getRank(), indexType);
maskTypes.push_back(
TupleType::get(lhsMaskElementTypes, parser.getBuilder().getContext()));
SmallVector<Type, 4> rhsMaskElementTypes(rhsType.getRank(), indexType);
maskTypes.push_back(
TupleType::get(rhsMaskElementTypes, parser.getBuilder().getContext()));
maskTypes.push_back(VectorType::get(lhsType.getShape(), maskElementType));
maskTypes.push_back(VectorType::get(rhsType.getShape(), maskElementType));
if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
return failure();
return success();
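Note that the mask types are not spelled in the trailing type list; the parser now derives them from the printed lhs/rhs vector types. A sketch of the textual form (operand names illustrative):

%0 = vector.contract #contraction_trait %lhs, %rhs, %acc, %lhs_mask, %rhs_mask
    : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
// %lhs_mask resolves to vector<7x8x16x15xi1>, %rhs_mask to vector<8x16x7x5xi1>.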
@@ -231,15 +227,10 @@ static LogicalResult verify(ContractionOp op) {
if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
return op.emitOpError("invalid number of vector masks specified");
if (lhsMaskType && rhsMaskType) {
// Verify tuple element size is != rank.
if (lhsMaskType.getTypes().size() != lhsType.getShape().size() ||
rhsMaskType.getTypes().size() != rhsType.getShape().size())
return op.emitOpError("invalid number of vector mask elements");
// Verify all tuple elements are index type.
for (auto eltType : lhsMaskType.getTypes()) {
if (!eltType.isa<IndexType>())
return op.emitOpError("vector mask element must have index type");
}
// Verify mask rank == argument rank.
if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
rhsMaskType.getShape().size() != rhsType.getShape().size())
return op.emitOpError("invalid vector mask rank");
}
return success();
}
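For illustration, the new check rejects a mask whose rank differs from its argument's rank; a hypothetical case (the custom parser already rules this form out textually, but ops built through the C++ API reach this verifier):

// %lhs : vector<7x8x16x15xf32>, but the mask below is rank 1:
// %lhs_mask = vector.constant_mask [7] : vector<7xi1>
// ---> error: 'vector.contract' op invalid vector mask rank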
@@ -1218,33 +1209,9 @@ void CreateMaskOp::getCanonicalizationPatterns(
results.insert<CreateMaskFolder>(context);
}
//===----------------------------------------------------------------------===//
// IndexTupleOp
//===----------------------------------------------------------------------===//
ParseResult parseIndexTupleOp(OpAsmParser &parser, OperationState &result) {
auto indexType = parser.getBuilder().getIndexType();
Type resultType;
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
return failure(
parser.parseOperandList(operandInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(resultType) ||
parser.resolveOperands(operandInfo, indexType, result.operands) ||
parser.addTypeToList(resultType, result.types));
}
static void print(OpAsmPrinter &p, IndexTupleOp &op) {
p << op.getOperationName() << ' ';
p.printOperands(op.operands());
p << " : " << op.getResult()->getType();
}
static LogicalResult verify(IndexTupleOp &op) {
for (auto operand : op.getOperands())
if (!operand->getType().isa<IndexType>())
return op.emitOpError("all operands must be of index type");
return success();
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context) {
patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder>(context);
}
namespace mlir {


@@ -278,9 +278,8 @@ static Value *getOrCreateUnrolledOperandSlice(
// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
// An iteration space index map argument 'iterationIndexMapList' must be
// specified, with a map for each structured op input and a single map for the
// single result. The last map in the list must be the single result map.
// Extra operands can be passed to unrolled instances of 'op' using the
// 'extraOperands' argument.
// single result. The map at index 'indexMapListResultIndex' in the list must
// be the single result map.
//
// Example:
//
@@ -310,7 +309,7 @@ static Value *getOrCreateUnrolledOperandSlice(
static Value *unrollSingleResultStructuredOp(
Operation *op, ArrayRef<int64_t> iterationBounds,
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
ArrayRef<int64_t> targetShape, ArrayRef<Value *> extraOperands,
unsigned indexMapListResultIndex, ArrayRef<int64_t> targetShape,
PatternRewriter &builder) {
auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape())
@@ -334,7 +333,7 @@ static Value *unrollSingleResultStructuredOp(
auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
auto basis = computeStrides(unrollFactors);
auto &resultOperandState = unrolledOperandState[numMaps - 1];
auto &resultOperandState = unrolledOperandState[indexMapListResultIndex];
auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
shapedType.getElementType());
@@ -360,7 +359,6 @@ static Value *unrollSingleResultStructuredOp(
iterationIndexMapList[i], caches[i], builder));
}
// Create op on sliced vector arguments.
operands.append(extraOperands.begin(), extraOperands.end());
auto resultVector =
cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
unrolledResultType)
@@ -368,12 +366,14 @@ static Value *unrollSingleResultStructuredOp(
// Compute linear result index.
int64_t resultIndex = getUnrolledOperandLinearIndex(
resultOperandState, vectorOffsets, iterationIndexMapList[numMaps - 1]);
resultOperandState, vectorOffsets,
iterationIndexMapList[indexMapListResultIndex]);
// Update result cache at 'resultIndex'.
caches[numMaps - 1][resultIndex] = resultVector;
caches[indexMapListResultIndex][resultIndex] = resultVector;
}
// Make zero splat into which we will insert results from 'cache[numMaps - 1]'
// Make zero splat into which we will insert results from
// 'cache[indexMapListResultIndex]'
auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
@@ -384,7 +384,8 @@ static Value *unrollSingleResultStructuredOp(
auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
vectorOffsets, resultOperandState.unrolledShape);
res = builder.create<vector::InsertStridedSliceOp>(
op->getLoc(), caches[numMaps - 1][i], res, offsets, strides);
op->getLoc(), caches[indexMapListResultIndex][i], res, offsets,
strides);
}
return res;
@@ -434,13 +435,17 @@ Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
// Get map from iteration space index to lhs/rhs/result shape index.
std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
contractionOp.getIterationIndexMap(iterationIndexMapList);
// TODO(andydavis) Support unrollable vector masks.
SmallVector<Value *, 2> masks(contractionOp.masks().begin(),
contractionOp.masks().end());
if (llvm::size(contractionOp.masks()) == 2) {
// Add maps for lhs/rhs vector mask arguments (same lhs/rhs vector shape)
iterationIndexMapList.push_back(iterationIndexMapList[0]);
iterationIndexMapList.push_back(iterationIndexMapList[1]);
}
// Unroll 'op' 'iterationBounds' to 'targetShape'.
return unrollSingleResultStructuredOp(op, iterationBounds,
iterationIndexMapList, targetShape,
masks, builder);
// TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
// 'iterationIndexMapList' instead of 'indexMapListResultIndex'.
return unrollSingleResultStructuredOp(
op, iterationBounds, iterationIndexMapList,
/*indexMapListResultIndex=*/2, targetShape, builder);
}
// TODO(andydavis) Create trivial iteration bounds and index map for
// elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
@@ -680,6 +685,7 @@ void mlir::populateVectorToVectorConversionPatterns(
MLIRContext *context, OwningRewritePatternList &patterns,
ArrayRef<int64_t> coarseVectorShape, ArrayRef<int64_t> fineVectorShape) {
vector::populateWithGenerated(context, &patterns);
vector::populateVectorToVectorCanonicalizationPatterns(patterns, context);
patterns
.insert<ConvertMatchingFakeForkFakeJoinOp,
ConvertFakeForkFromBlockArgsOrTransferReadOp, ConvertFakeJoinOp,


@@ -66,33 +66,45 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
// CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[LMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[RMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %[[LMASK0]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[A0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A1S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[LMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[RMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %[[LMASK1]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[A0S04:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A1S40:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[LMASK2:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[RMASK2:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %[[LMASK2]], %[[RMASK2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [0, 2]
// CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[RMASK3:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %[[LMASK0]], %[[RMASK3]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[A1S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[RMASK4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %[[LMASK1]], %[[RMASK4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[A1S42:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[RMASK5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %[[LMASK2]], %[[RMASK5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [2, 0]
// CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[LMASK3:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %[[LMASK3]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[A0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[LMASK4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %[[LMASK4]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[A0S24:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[LMASK5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %[[LMASK5]], %[[RMASK2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [2, 2]
@@ -111,9 +123,8 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
%arg2 : vector<4x4xf32>, %arg3 : index)
-> (vector<4x4xf32>) {
%lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
%rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
%lhsm = vector.constant_mask [4, 6] : vector<4x6xi1>
%rhsm = vector.constant_mask [6, 4] : vector<6x4xi1>
%0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm
: vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
@@ -138,19 +149,23 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
// CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[LMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[RMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %[[LMASK0]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [0, 2]
// CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[RMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %[[LMASK0]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [2, 0]
// CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[LMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %[[LMASK1]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [2, 2]
@@ -167,9 +182,8 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
%arg2 : vector<4x4xf32>, %arg3 : index)
-> (vector<4x4xf32>) {
%lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
%rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
%lhsm = vector.constant_mask [4, 2] : vector<4x2xi1>
%rhsm = vector.constant_mask [2, 4] : vector<2x4xi1>
%0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm
: vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>


@@ -597,10 +597,8 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) {
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index>
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index>
%lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
%rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
// expected-error@+1 {{expected zero or exactly 2 vector mask operands}}
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2, %lhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>


@@ -114,10 +114,8 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
%1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
// Test contraction with optional vector mask arguments.
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index>
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index>
%lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
%rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
%2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
%rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>