diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h index 668eaa5c9d56..8cb0d8516b4a 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.h +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.h @@ -28,6 +28,8 @@ #include "mlir/IR/StandardTypes.h" namespace mlir { +class MLIRContext; +class OwningRewritePatternList; namespace vector { /// Dialect for Ops on higher-dimensional vector types. @@ -37,6 +39,10 @@ public: static StringRef getDialectNamespace() { return "vector"; } }; +/// Collect a set of vector-to-vector canonicalization patterns. +void populateVectorToVectorCanonicalizationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context); + #define GET_OP_CLASSES #include "mlir/Dialect/VectorOps/VectorOps.h.inc" diff --git a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td index f4bfeb73dd77..ebeecfbb715a 100644 --- a/mlir/include/mlir/Dialect/VectorOps/VectorOps.td +++ b/mlir/include/mlir/Dialect/VectorOps/VectorOps.td @@ -49,7 +49,7 @@ class Vector_Op traits = []> : def Vector_ContractionOp : Vector_Op<"contract", [NoSideEffect]>, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc, - Variadic>:$masks, + Variadic>:$masks, AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>, Results<(outs AnyVector)> { let summary = "vector contraction operation"; @@ -60,8 +60,9 @@ def Vector_ContractionOp : vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims + num_batch_dims (see dimension type descriptions below)). - Optional vector mask arguments specify the dynamic dimension sizes of - valid data within the lhs/rhs vector arguments. + Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp) + specify the dynamic dimension sizes of valid data within the lhs/rhs vector + arguments. An iterator type attribute list must be specified, where each element of the list represents an iterator with one of the following types: @@ -120,10 +121,8 @@ def Vector_ContractionOp : // 4D vector contraction with two contracting dimensions and optional // vector mask arguments. - %lhs_mask = vector.make_tuple %size0, %size1, %size2, %size3 - : tuple - %rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7 - : tuple + %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1> + %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1> %5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> @@ -138,13 +137,13 @@ def Vector_ContractionOp : VectorType getAccType() { return acc()->getType().cast(); } - TupleType getLHSVectorMaskType() { - if (llvm::size(masks()) != 2) return TupleType(); - return getOperand(3)->getType().cast(); + VectorType getLHSVectorMaskType() { + if (llvm::size(masks()) != 2) return VectorType(); + return getOperand(3)->getType().cast(); } - TupleType getRHSVectorMaskType() { - if (llvm::size(masks()) != 2) return TupleType(); - return getOperand(4)->getType().cast(); + VectorType getRHSVectorMaskType() { + if (llvm::size(masks()) != 2) return VectorType(); + return getOperand(4)->getType().cast(); } VectorType getResultType() { return getResult()->getType().cast(); @@ -706,20 +705,4 @@ def Vector_CreateMaskOp : let hasCanonicalizer = 1; } -// TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask -def Vector_IndexTupleOp : - Vector_Op<"make_index_tuple", [NoSideEffect]>, - Arguments<(ins Variadic:$operands)>, - Results<(outs TupleOf<[Index]>)> { - let summary = "creates a tuple of operand values"; - let description = [{ - Creates and returns a tuple of its operands which must be of index type. - - Example: - - %1 = vector.make_index_tuple %size0, %size1, %size2 - : tuple - - }]; -} #endif // VECTOR_OPS diff --git a/mlir/lib/Dialect/VectorOps/VectorOps.cpp b/mlir/lib/Dialect/VectorOps/VectorOps.cpp index f96d3bacacf9..5d596f388ed0 100644 --- a/mlir/lib/Dialect/VectorOps/VectorOps.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorOps.cpp @@ -82,16 +82,12 @@ static ParseResult parseContractionOp(OpAsmParser &parser, if (masksInfo.size() != 2) return parser.emitError(parser.getNameLoc(), "expected zero or exactly 2 vector mask operands"); - auto indexType = parser.getBuilder().getIndexType(); auto lhsType = types[0].cast(); auto rhsType = types[1].cast(); + auto maskElementType = parser.getBuilder().getI1Type(); SmallVector maskTypes; - SmallVector lhsMaskElementTypes(lhsType.getRank(), indexType); - maskTypes.push_back( - TupleType::get(lhsMaskElementTypes, parser.getBuilder().getContext())); - SmallVector rhsMaskElementTypes(rhsType.getRank(), indexType); - maskTypes.push_back( - TupleType::get(rhsMaskElementTypes, parser.getBuilder().getContext())); + maskTypes.push_back(VectorType::get(lhsType.getShape(), maskElementType)); + maskTypes.push_back(VectorType::get(rhsType.getShape(), maskElementType)); if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands)) return failure(); return success(); @@ -231,15 +227,10 @@ static LogicalResult verify(ContractionOp op) { if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType)) return op.emitOpError("invalid number of vector masks specified"); if (lhsMaskType && rhsMaskType) { - // Verify tuple element size is != rank. - if (lhsMaskType.getTypes().size() != lhsType.getShape().size() || - rhsMaskType.getTypes().size() != rhsType.getShape().size()) - return op.emitOpError("invalid number of vector mask elements"); - // Verify all tuple elements are index type. - for (auto eltType : lhsMaskType.getTypes()) { - if (!eltType.isa()) - return op.emitOpError("vector mask element must have index type"); - } + // Verify mask rank == argument rank. + if (lhsMaskType.getShape().size() != lhsType.getShape().size() || + rhsMaskType.getShape().size() != rhsType.getShape().size()) + return op.emitOpError("invalid vector mask rank"); } return success(); } @@ -1218,33 +1209,9 @@ void CreateMaskOp::getCanonicalizationPatterns( results.insert(context); } -//===----------------------------------------------------------------------===// -// IndexTupleOp -//===----------------------------------------------------------------------===// - -ParseResult parseIndexTupleOp(OpAsmParser &parser, OperationState &result) { - auto indexType = parser.getBuilder().getIndexType(); - Type resultType; - SmallVector operandInfo; - return failure( - parser.parseOperandList(operandInfo) || - parser.parseOptionalAttrDict(result.attributes) || - parser.parseColonType(resultType) || - parser.resolveOperands(operandInfo, indexType, result.operands) || - parser.addTypeToList(resultType, result.types)); -} - -static void print(OpAsmPrinter &p, IndexTupleOp &op) { - p << op.getOperationName() << ' '; - p.printOperands(op.operands()); - p << " : " << op.getResult()->getType(); -} - -static LogicalResult verify(IndexTupleOp &op) { - for (auto operand : op.getOperands()) - if (!operand->getType().isa()) - return op.emitOpError("all operands must be of index type"); - return success(); +void mlir::vector::populateVectorToVectorCanonicalizationPatterns( + OwningRewritePatternList &patterns, MLIRContext *context) { + patterns.insert(context); } namespace mlir { diff --git a/mlir/lib/Dialect/VectorOps/VectorToVector.cpp b/mlir/lib/Dialect/VectorOps/VectorToVector.cpp index 4654aff45821..c2726edd9bf1 100644 --- a/mlir/lib/Dialect/VectorOps/VectorToVector.cpp +++ b/mlir/lib/Dialect/VectorOps/VectorToVector.cpp @@ -278,9 +278,8 @@ static Value *getOrCreateUnrolledOperandSlice( // with iteration bounds 'iterationBounds' unrolled to 'targetShape'. // An iteration space index map argument 'iterationIndexMapList' must be // specified, with a map for each structured op input and a single map for the -// single result. The last map in the list must be the single result map. -// Extra operands can be passed to unrolled instances of 'op' using the -// 'extraOperands' argument. +// single result. The map at index 'indexMapListResultIndex' in the list must +// be the single result map. // // Example: // @@ -310,7 +309,7 @@ static Value *getOrCreateUnrolledOperandSlice( static Value *unrollSingleResultStructuredOp( Operation *op, ArrayRef iterationBounds, std::vector> &iterationIndexMapList, - ArrayRef targetShape, ArrayRef extraOperands, + unsigned indexMapListResultIndex, ArrayRef targetShape, PatternRewriter &builder) { auto shapedType = op->getResult(0)->getType().dyn_cast_or_null(); if (!shapedType || !shapedType.hasStaticShape()) @@ -334,7 +333,7 @@ static Value *unrollSingleResultStructuredOp( auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors); auto basis = computeStrides(unrollFactors); - auto &resultOperandState = unrolledOperandState[numMaps - 1]; + auto &resultOperandState = unrolledOperandState[indexMapListResultIndex]; auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape, shapedType.getElementType()); @@ -360,7 +359,6 @@ static Value *unrollSingleResultStructuredOp( iterationIndexMapList[i], caches[i], builder)); } // Create op on sliced vector arguments. - operands.append(extraOperands.begin(), extraOperands.end()); auto resultVector = cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands, unrolledResultType) @@ -368,12 +366,14 @@ static Value *unrollSingleResultStructuredOp( // Compute linear result index. int64_t resultIndex = getUnrolledOperandLinearIndex( - resultOperandState, vectorOffsets, iterationIndexMapList[numMaps - 1]); + resultOperandState, vectorOffsets, + iterationIndexMapList[indexMapListResultIndex]); // Update result cache at 'resultIndex'. - caches[numMaps - 1][resultIndex] = resultVector; + caches[indexMapListResultIndex][resultIndex] = resultVector; } - // Make zero splat into which we will insert results from 'cache[numMaps - 1]' + // Make zero splat into which we will insert results from + // 'cache[indexMapListResultIndex]' auto resultVectorType = op->getResult(0)->getType().cast(); auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType); SmallVector strides(resultOperandState.unrollFactors.size(), 1); @@ -384,7 +384,8 @@ static Value *unrollSingleResultStructuredOp( auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; }, vectorOffsets, resultOperandState.unrolledShape); res = builder.create( - op->getLoc(), caches[numMaps - 1][i], res, offsets, strides); + op->getLoc(), caches[indexMapListResultIndex][i], res, offsets, + strides); } return res; @@ -434,13 +435,17 @@ Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder, // Get map from iteration space index to lhs/rhs/result shape index. std::vector> iterationIndexMapList; contractionOp.getIterationIndexMap(iterationIndexMapList); - // TODO(andydavis) Support unrollable vector masks. - SmallVector masks(contractionOp.masks().begin(), - contractionOp.masks().end()); + if (llvm::size(contractionOp.masks()) == 2) { + // Add maps for lhs/rhs vector mask arguments (same lhs/rhs vector shape) + iterationIndexMapList.push_back(iterationIndexMapList[0]); + iterationIndexMapList.push_back(iterationIndexMapList[1]); + } // Unroll 'op' 'iterationBounds' to 'targetShape'. - return unrollSingleResultStructuredOp(op, iterationBounds, - iterationIndexMapList, targetShape, - masks, builder); + // TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition + // 'iterationIndexMapList' instead of 'indexMapListResultIndex'. + return unrollSingleResultStructuredOp( + op, iterationBounds, iterationIndexMapList, + /*indexMapListResultIndex=*/2, targetShape, builder); } // TODO(andydavis) Create trivial iteration bounds and index map for // elementwise operations and call 'unrollSingleResultStructuredOp'. Remove @@ -680,6 +685,7 @@ void mlir::populateVectorToVectorConversionPatterns( MLIRContext *context, OwningRewritePatternList &patterns, ArrayRef coarseVectorShape, ArrayRef fineVectorShape) { vector::populateWithGenerated(context, &patterns); + vector::populateVectorToVectorCanonicalizationPatterns(patterns, context); patterns .insert, %1: vector<4x4xf32>) -> vector<4x4xf32> { // CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[LMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[RMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %[[LMASK0]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[A0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A1S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[LMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[RMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %[[LMASK1]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[A0S04:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A1S40:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[LMASK2:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[RMASK2:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %[[LMASK2]], %[[RMASK2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [0, 2] // CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[RMASK3:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %[[LMASK0]], %[[RMASK3]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[A1S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[RMASK4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %[[LMASK1]], %[[RMASK4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[A1S42:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[RMASK5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %[[LMASK2]], %[[RMASK5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 0] // CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> // CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[LMASK3:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %[[LMASK3]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[A0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[LMASK4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %[[LMASK4]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // CHECK-NEXT: %[[A0S24:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[LMASK5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %[[LMASK5]], %[[RMASK2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 2] @@ -111,9 +123,8 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> { func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>, %arg2 : vector<4x4xf32>, %arg3 : index) -> (vector<4x4xf32>) { - - %lhsm = vector.make_index_tuple %arg3, %arg3 : tuple - %rhsm = vector.make_index_tuple %arg3, %arg3 : tuple + %lhsm = vector.constant_mask [4, 6] : vector<4x6xi1> + %rhsm = vector.constant_mask [6, 4] : vector<6x4xi1> %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32> @@ -138,19 +149,23 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>, // CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> // CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[LMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[RMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %[[LMASK0]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [0, 2] // CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32> // CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[RMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %[[LMASK0]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 0] // CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32> // CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32> -// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> +// CHECK-NEXT: %[[LMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1> +// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %[[LMASK1]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32> // Reducing output vector [2, 2] @@ -167,9 +182,8 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>, func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>, %arg2 : vector<4x4xf32>, %arg3 : index) -> (vector<4x4xf32>) { - - %lhsm = vector.make_index_tuple %arg3, %arg3 : tuple - %rhsm = vector.make_index_tuple %arg3, %arg3 : tuple + %lhsm = vector.constant_mask [4, 2] : vector<4x2xi1> + %rhsm = vector.constant_mask [2, 4] : vector<2x4xi1> %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32> diff --git a/mlir/test/Dialect/VectorOps/invalid.mlir b/mlir/test/Dialect/VectorOps/invalid.mlir index bd664f715755..3b521f6e9ba4 100644 --- a/mlir/test/Dialect/VectorOps/invalid.mlir +++ b/mlir/test/Dialect/VectorOps/invalid.mlir @@ -597,10 +597,8 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, %arg4 : index) { - %lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4 - : tuple - %rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4 - : tuple + %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1> + %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1> // expected-error@+1 {{expected zero or exactly 2 vector mask operands}} %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2, %lhs_mask : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> diff --git a/mlir/test/Dialect/VectorOps/ops.mlir b/mlir/test/Dialect/VectorOps/ops.mlir index cb87c20a2b90..c1c911098ae2 100644 --- a/mlir/test/Dialect/VectorOps/ops.mlir +++ b/mlir/test/Dialect/VectorOps/ops.mlir @@ -114,10 +114,8 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>, %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3 : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> // Test contraction with optional vector mask arguments. - %lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4 - : tuple - %rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4 - : tuple + %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1> + %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1> // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask, %rhs_mask