Unroll vector masks along with their associated vector arguments.

Updates vector ContractionOp to use proper vector masks (produced by CreateMaskOp/ConstantMaskOp).
Leverages the following canonicalizations in the unrolling unit test: CreateMaskOp -> ConstantMaskOp, StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp.
Removes IndexTupleOp (no longer needed now that we have vector mask ops).
Updates all unit tests.

PiperOrigin-RevId: 284182168
This commit is contained in:
Andy Davis 2019-12-06 07:36:55 -08:00 committed by A. Unique TensorFlower
parent 9ca53130f3
commit 41f8e105fa
7 changed files with 86 additions and 114 deletions

View File

@ -28,6 +28,8 @@
#include "mlir/IR/StandardTypes.h" #include "mlir/IR/StandardTypes.h"
namespace mlir { namespace mlir {
class MLIRContext;
class OwningRewritePatternList;
namespace vector { namespace vector {
/// Dialect for Ops on higher-dimensional vector types. /// Dialect for Ops on higher-dimensional vector types.
@ -37,6 +39,10 @@ public:
static StringRef getDialectNamespace() { return "vector"; } static StringRef getDialectNamespace() { return "vector"; }
}; };
/// Collect a set of vector-to-vector canonicalization patterns.
void populateVectorToVectorCanonicalizationPatterns(
OwningRewritePatternList &patterns, MLIRContext *context);
#define GET_OP_CLASSES #define GET_OP_CLASSES
#include "mlir/Dialect/VectorOps/VectorOps.h.inc" #include "mlir/Dialect/VectorOps/VectorOps.h.inc"

View File

@ -49,7 +49,7 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
def Vector_ContractionOp : def Vector_ContractionOp :
Vector_Op<"contract", [NoSideEffect]>, Vector_Op<"contract", [NoSideEffect]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
Variadic<TupleOf<[Index]>>:$masks, Variadic<VectorOf<[I1]>>:$masks,
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>, AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
Results<(outs AnyVector)> { Results<(outs AnyVector)> {
let summary = "vector contraction operation"; let summary = "vector contraction operation";
@ -60,8 +60,9 @@ def Vector_ContractionOp :
vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims + vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims +
num_batch_dims (see dimension type descriptions below)). num_batch_dims (see dimension type descriptions below)).
Optional vector mask arguments specify the dynamic dimension sizes of Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp)
valid data within the lhs/rhs vector arguments. specify the dynamic dimension sizes of valid data within the lhs/rhs vector
arguments.
An iterator type attribute list must be specified, where each element of An iterator type attribute list must be specified, where each element of
the list represents an iterator with one of the following types: the list represents an iterator with one of the following types:
@ -120,10 +121,8 @@ def Vector_ContractionOp :
// 4D vector contraction with two contracting dimensions and optional // 4D vector contraction with two contracting dimensions and optional
// vector mask arguments. // vector mask arguments.
%lhs_mask = vector.make_tuple %size0, %size1, %size2, %size3 %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
: tuple<index, index, index, index> %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
%rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7
: tuple<index, index, index, index>
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask %5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
@ -138,13 +137,13 @@ def Vector_ContractionOp :
VectorType getAccType() { VectorType getAccType() {
return acc()->getType().cast<VectorType>(); return acc()->getType().cast<VectorType>();
} }
TupleType getLHSVectorMaskType() { VectorType getLHSVectorMaskType() {
if (llvm::size(masks()) != 2) return TupleType(); if (llvm::size(masks()) != 2) return VectorType();
return getOperand(3)->getType().cast<TupleType>(); return getOperand(3)->getType().cast<VectorType>();
} }
TupleType getRHSVectorMaskType() { VectorType getRHSVectorMaskType() {
if (llvm::size(masks()) != 2) return TupleType(); if (llvm::size(masks()) != 2) return VectorType();
return getOperand(4)->getType().cast<TupleType>(); return getOperand(4)->getType().cast<VectorType>();
} }
VectorType getResultType() { VectorType getResultType() {
return getResult()->getType().cast<VectorType>(); return getResult()->getType().cast<VectorType>();
@ -706,20 +705,4 @@ def Vector_CreateMaskOp :
let hasCanonicalizer = 1; let hasCanonicalizer = 1;
} }
// TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask
def Vector_IndexTupleOp :
Vector_Op<"make_index_tuple", [NoSideEffect]>,
Arguments<(ins Variadic<Index>:$operands)>,
Results<(outs TupleOf<[Index]>)> {
let summary = "creates a tuple of operand values";
let description = [{
Creates and returns a tuple of its operands which must be of index type.
Example:
%1 = vector.make_index_tuple %size0, %size1, %size2
: tuple<index, index, index>
}];
}
#endif // VECTOR_OPS #endif // VECTOR_OPS

View File

@ -82,16 +82,12 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
if (masksInfo.size() != 2) if (masksInfo.size() != 2)
return parser.emitError(parser.getNameLoc(), return parser.emitError(parser.getNameLoc(),
"expected zero or exactly 2 vector mask operands"); "expected zero or exactly 2 vector mask operands");
auto indexType = parser.getBuilder().getIndexType();
auto lhsType = types[0].cast<VectorType>(); auto lhsType = types[0].cast<VectorType>();
auto rhsType = types[1].cast<VectorType>(); auto rhsType = types[1].cast<VectorType>();
auto maskElementType = parser.getBuilder().getI1Type();
SmallVector<Type, 2> maskTypes; SmallVector<Type, 2> maskTypes;
SmallVector<Type, 4> lhsMaskElementTypes(lhsType.getRank(), indexType); maskTypes.push_back(VectorType::get(lhsType.getShape(), maskElementType));
maskTypes.push_back( maskTypes.push_back(VectorType::get(rhsType.getShape(), maskElementType));
TupleType::get(lhsMaskElementTypes, parser.getBuilder().getContext()));
SmallVector<Type, 4> rhsMaskElementTypes(rhsType.getRank(), indexType);
maskTypes.push_back(
TupleType::get(rhsMaskElementTypes, parser.getBuilder().getContext()));
if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands)) if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
return failure(); return failure();
return success(); return success();
@ -231,15 +227,10 @@ static LogicalResult verify(ContractionOp op) {
if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType)) if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
return op.emitOpError("invalid number of vector masks specified"); return op.emitOpError("invalid number of vector masks specified");
if (lhsMaskType && rhsMaskType) { if (lhsMaskType && rhsMaskType) {
// Verify tuple element size is != rank. // Verify mask rank == argument rank.
if (lhsMaskType.getTypes().size() != lhsType.getShape().size() || if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
rhsMaskType.getTypes().size() != rhsType.getShape().size()) rhsMaskType.getShape().size() != rhsType.getShape().size())
return op.emitOpError("invalid number of vector mask elements"); return op.emitOpError("invalid vector mask rank");
// Verify all tuple elements are index type.
for (auto eltType : lhsMaskType.getTypes()) {
if (!eltType.isa<IndexType>())
return op.emitOpError("vector mask element must have index type");
}
} }
return success(); return success();
} }
@ -1218,33 +1209,9 @@ void CreateMaskOp::getCanonicalizationPatterns(
results.insert<CreateMaskFolder>(context); results.insert<CreateMaskFolder>(context);
} }
//===----------------------------------------------------------------------===// void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
// IndexTupleOp OwningRewritePatternList &patterns, MLIRContext *context) {
//===----------------------------------------------------------------------===// patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder>(context);
ParseResult parseIndexTupleOp(OpAsmParser &parser, OperationState &result) {
auto indexType = parser.getBuilder().getIndexType();
Type resultType;
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
return failure(
parser.parseOperandList(operandInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(resultType) ||
parser.resolveOperands(operandInfo, indexType, result.operands) ||
parser.addTypeToList(resultType, result.types));
}
static void print(OpAsmPrinter &p, IndexTupleOp &op) {
p << op.getOperationName() << ' ';
p.printOperands(op.operands());
p << " : " << op.getResult()->getType();
}
static LogicalResult verify(IndexTupleOp &op) {
for (auto operand : op.getOperands())
if (!operand->getType().isa<IndexType>())
return op.emitOpError("all operands must be of index type");
return success();
} }
namespace mlir { namespace mlir {

View File

@ -278,9 +278,8 @@ static Value *getOrCreateUnrolledOperandSlice(
// with iteration bounds 'iterationBounds' unrolled to 'targetShape'. // with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
// An iteration space index map argument 'iterationIndexMapList' must be // An iteration space index map argument 'iterationIndexMapList' must be
// specified, with a map for each structured op input and a single map for the // specified, with a map for each structured op input and a single map for the
// single result. The last map in the list must be the single result map. // single result. The map at index 'indexMapListResultIndex' in the list must
// Extra operands can be passed to unrolled instances of 'op' using the // be the single result map.
// 'extraOperands' argument.
// //
// Example: // Example:
// //
@ -310,7 +309,7 @@ static Value *getOrCreateUnrolledOperandSlice(
static Value *unrollSingleResultStructuredOp( static Value *unrollSingleResultStructuredOp(
Operation *op, ArrayRef<int64_t> iterationBounds, Operation *op, ArrayRef<int64_t> iterationBounds,
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList, std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
ArrayRef<int64_t> targetShape, ArrayRef<Value *> extraOperands, unsigned indexMapListResultIndex, ArrayRef<int64_t> targetShape,
PatternRewriter &builder) { PatternRewriter &builder) {
auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>(); auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
if (!shapedType || !shapedType.hasStaticShape()) if (!shapedType || !shapedType.hasStaticShape())
@ -334,7 +333,7 @@ static Value *unrollSingleResultStructuredOp(
auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors); auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
auto basis = computeStrides(unrollFactors); auto basis = computeStrides(unrollFactors);
auto &resultOperandState = unrolledOperandState[numMaps - 1]; auto &resultOperandState = unrolledOperandState[indexMapListResultIndex];
auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape, auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
shapedType.getElementType()); shapedType.getElementType());
@ -360,7 +359,6 @@ static Value *unrollSingleResultStructuredOp(
iterationIndexMapList[i], caches[i], builder)); iterationIndexMapList[i], caches[i], builder));
} }
// Create op on sliced vector arguments. // Create op on sliced vector arguments.
operands.append(extraOperands.begin(), extraOperands.end());
auto resultVector = auto resultVector =
cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands, cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
unrolledResultType) unrolledResultType)
@ -368,12 +366,14 @@ static Value *unrollSingleResultStructuredOp(
// Compute linear result index. // Compute linear result index.
int64_t resultIndex = getUnrolledOperandLinearIndex( int64_t resultIndex = getUnrolledOperandLinearIndex(
resultOperandState, vectorOffsets, iterationIndexMapList[numMaps - 1]); resultOperandState, vectorOffsets,
iterationIndexMapList[indexMapListResultIndex]);
// Update result cache at 'resultIndex'. // Update result cache at 'resultIndex'.
caches[numMaps - 1][resultIndex] = resultVector; caches[indexMapListResultIndex][resultIndex] = resultVector;
} }
// Make zero splat into which we will insert results from 'cache[numMaps - 1]' // Make zero splat into which we will insert results from
// 'caches[indexMapListResultIndex]'
auto resultVectorType = op->getResult(0)->getType().cast<VectorType>(); auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType); auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1); SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
@ -384,7 +384,8 @@ static Value *unrollSingleResultStructuredOp(
auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
vectorOffsets, resultOperandState.unrolledShape); vectorOffsets, resultOperandState.unrolledShape);
res = builder.create<vector::InsertStridedSliceOp>( res = builder.create<vector::InsertStridedSliceOp>(
op->getLoc(), caches[numMaps - 1][i], res, offsets, strides); op->getLoc(), caches[indexMapListResultIndex][i], res, offsets,
strides);
} }
return res; return res;
@ -434,13 +435,17 @@ Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
// Get map from iteration space index to lhs/rhs/result shape index. // Get map from iteration space index to lhs/rhs/result shape index.
std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList; std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
contractionOp.getIterationIndexMap(iterationIndexMapList); contractionOp.getIterationIndexMap(iterationIndexMapList);
// TODO(andydavis) Support unrollable vector masks. if (llvm::size(contractionOp.masks()) == 2) {
SmallVector<Value *, 2> masks(contractionOp.masks().begin(), // Add maps for lhs/rhs vector mask arguments (same lhs/rhs vector shape)
contractionOp.masks().end()); iterationIndexMapList.push_back(iterationIndexMapList[0]);
iterationIndexMapList.push_back(iterationIndexMapList[1]);
}
// Unroll 'op' 'iterationBounds' to 'targetShape'. // Unroll 'op' 'iterationBounds' to 'targetShape'.
return unrollSingleResultStructuredOp(op, iterationBounds, // TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
iterationIndexMapList, targetShape, // 'iterationIndexMapList' instead of 'indexMapListResultIndex'.
masks, builder); return unrollSingleResultStructuredOp(
op, iterationBounds, iterationIndexMapList,
/*indexMapListResultIndex=*/2, targetShape, builder);
} }
// TODO(andydavis) Create trivial iteration bounds and index map for // TODO(andydavis) Create trivial iteration bounds and index map for
// elementwise operations and call 'unrollSingleResultStructuredOp'. Remove // elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
@ -680,6 +685,7 @@ void mlir::populateVectorToVectorConversionPatterns(
MLIRContext *context, OwningRewritePatternList &patterns, MLIRContext *context, OwningRewritePatternList &patterns,
ArrayRef<int64_t> coarseVectorShape, ArrayRef<int64_t> fineVectorShape) { ArrayRef<int64_t> coarseVectorShape, ArrayRef<int64_t> fineVectorShape) {
vector::populateWithGenerated(context, &patterns); vector::populateWithGenerated(context, &patterns);
vector::populateVectorToVectorCanonicalizationPatterns(patterns, context);
patterns patterns
.insert<ConvertMatchingFakeForkFakeJoinOp, .insert<ConvertMatchingFakeForkFakeJoinOp,
ConvertFakeForkFromBlockArgsOrTransferReadOp, ConvertFakeJoinOp, ConvertFakeForkFromBlockArgsOrTransferReadOp, ConvertFakeJoinOp,

View File

@ -66,33 +66,45 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
// CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> // CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[LMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[RMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %[[LMASK0]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[A0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A1S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A1S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[LMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[RMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %[[LMASK1]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[A0S04:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A0S04:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A1S40:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A1S40:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[LMASK2:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[RMASK2:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %[[LMASK2]], %[[RMASK2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [0, 2] // Reducing output vector [0, 2]
// CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[RMASK3:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %[[LMASK0]], %[[RMASK3]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[A1S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A1S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[RMASK4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %[[LMASK1]], %[[RMASK4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[A1S42:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A1S42:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[RMASK5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %[[LMASK2]], %[[RMASK5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [2, 0] // Reducing output vector [2, 0]
// CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[LMASK3:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %[[LMASK3]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[A0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[LMASK4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %[[LMASK4]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// CHECK-NEXT: %[[A0S24:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A0S24:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[LMASK5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %[[LMASK5]], %[[RMASK2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [2, 2] // Reducing output vector [2, 2]
@ -111,9 +123,8 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>, func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
%arg2 : vector<4x4xf32>, %arg3 : index) %arg2 : vector<4x4xf32>, %arg3 : index)
-> (vector<4x4xf32>) { -> (vector<4x4xf32>) {
%lhsm = vector.constant_mask [4, 6] : vector<4x6xi1>
%lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index> %rhsm = vector.constant_mask [6, 4] : vector<6x4xi1>
%rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
%0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm
: vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32> : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
@ -138,19 +149,23 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
// CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> // CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[LMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[RMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %[[LMASK0]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [0, 2] // Reducing output vector [0, 2]
// CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[RMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %[[LMASK0]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [2, 0] // Reducing output vector [2, 0]
// CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[LMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %[[LMASK1]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
// Reducing output vector [2, 2] // Reducing output vector [2, 2]
@ -167,9 +182,8 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>, func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
%arg2 : vector<4x4xf32>, %arg3 : index) %arg2 : vector<4x4xf32>, %arg3 : index)
-> (vector<4x4xf32>) { -> (vector<4x4xf32>) {
%lhsm = vector.constant_mask [4, 2] : vector<4x2xi1>
%lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index> %rhsm = vector.constant_mask [2, 4] : vector<2x4xi1>
%rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
%0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm
: vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32> : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>

View File

@ -597,10 +597,8 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) { %arg4 : index) {
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4 %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
: tuple<index, index, index, index> %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index>
// expected-error@+1 {{expected zero or exactly 2 vector mask operands}} // expected-error@+1 {{expected zero or exactly 2 vector mask operands}}
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2, %lhs_mask %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2, %lhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>

View File

@ -114,10 +114,8 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
%1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3 %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
// Test contraction with optional vector mask arguments. // Test contraction with optional vector mask arguments.
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4 %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
: tuple<index, index, index, index> %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index>
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
%2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask, %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
%rhs_mask %rhs_mask