forked from OSchip/llvm-project
Unroll vector masks along with their associated vector arguments.
Updates vector ContractionOp to use proper vector masks (produced by CreateMaskOp/ConstantMaskOp). Leverages the following canonicalizations in the unrolling unit test: CreateMaskOp -> ConstantMaskOp and StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp. Removes IndexTupleOp (no longer needed now that we have vector mask ops). Updates all unit tests. PiperOrigin-RevId: 284182168
This commit is contained in:
parent
9ca53130f3
commit
41f8e105fa
|
@ -28,6 +28,8 @@
|
|||
#include "mlir/IR/StandardTypes.h"
|
||||
|
||||
namespace mlir {
|
||||
class MLIRContext;
|
||||
class OwningRewritePatternList;
|
||||
namespace vector {
|
||||
|
||||
/// Dialect for Ops on higher-dimensional vector types.
|
||||
|
@ -37,6 +39,10 @@ public:
|
|||
static StringRef getDialectNamespace() { return "vector"; }
|
||||
};
|
||||
|
||||
/// Collect a set of vector-to-vector canonicalization patterns.
|
||||
void populateVectorToVectorCanonicalizationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context);
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/VectorOps/VectorOps.h.inc"
|
||||
|
||||
|
|
|
@ -49,7 +49,7 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
def Vector_ContractionOp :
|
||||
Vector_Op<"contract", [NoSideEffect]>,
|
||||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
|
||||
Variadic<TupleOf<[Index]>>:$masks,
|
||||
Variadic<VectorOf<[I1]>>:$masks,
|
||||
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
|
||||
Results<(outs AnyVector)> {
|
||||
let summary = "vector contraction operation";
|
||||
|
@ -60,8 +60,9 @@ def Vector_ContractionOp :
|
|||
vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims +
|
||||
num_batch_dims (see dimension type descriptions below)).
|
||||
|
||||
Optional vector mask arguments specify the dynamic dimension sizes of
|
||||
valid data within the lhs/rhs vector arguments.
|
||||
Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp)
|
||||
specify the dynamic dimension sizes of valid data within the lhs/rhs vector
|
||||
arguments.
|
||||
|
||||
An iterator type attribute list must be specified, where each element of
|
||||
the list represents an iterator with one of the following types:
|
||||
|
@ -120,10 +121,8 @@ def Vector_ContractionOp :
|
|||
|
||||
// 4D vector contraction with two contracting dimensions and optional
|
||||
// vector mask arguments.
|
||||
%lhs_mask = vector.make_tuple %size0, %size1, %size2, %size3
|
||||
: tuple<index, index, index, index>
|
||||
%rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7
|
||||
: tuple<index, index, index, index>
|
||||
%lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
|
||||
%rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
|
||||
|
||||
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||
|
@ -138,13 +137,13 @@ def Vector_ContractionOp :
|
|||
VectorType getAccType() {
|
||||
return acc()->getType().cast<VectorType>();
|
||||
}
|
||||
TupleType getLHSVectorMaskType() {
|
||||
if (llvm::size(masks()) != 2) return TupleType();
|
||||
return getOperand(3)->getType().cast<TupleType>();
|
||||
VectorType getLHSVectorMaskType() {
|
||||
if (llvm::size(masks()) != 2) return VectorType();
|
||||
return getOperand(3)->getType().cast<VectorType>();
|
||||
}
|
||||
TupleType getRHSVectorMaskType() {
|
||||
if (llvm::size(masks()) != 2) return TupleType();
|
||||
return getOperand(4)->getType().cast<TupleType>();
|
||||
VectorType getRHSVectorMaskType() {
|
||||
if (llvm::size(masks()) != 2) return VectorType();
|
||||
return getOperand(4)->getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getResultType() {
|
||||
return getResult()->getType().cast<VectorType>();
|
||||
|
@ -706,20 +705,4 @@ def Vector_CreateMaskOp :
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
// TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask
|
||||
def Vector_IndexTupleOp :
|
||||
Vector_Op<"make_index_tuple", [NoSideEffect]>,
|
||||
Arguments<(ins Variadic<Index>:$operands)>,
|
||||
Results<(outs TupleOf<[Index]>)> {
|
||||
let summary = "creates a tuple of operand values";
|
||||
let description = [{
|
||||
Creates and returns a tuple of its operands which must be of index type.
|
||||
|
||||
Example:
|
||||
|
||||
%1 = vector.make_index_tuple %size0, %size1, %size2
|
||||
: tuple<index, index, index>
|
||||
|
||||
}];
|
||||
}
|
||||
#endif // VECTOR_OPS
|
||||
|
|
|
@ -82,16 +82,12 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
|
|||
if (masksInfo.size() != 2)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expected zero or exactly 2 vector mask operands");
|
||||
auto indexType = parser.getBuilder().getIndexType();
|
||||
auto lhsType = types[0].cast<VectorType>();
|
||||
auto rhsType = types[1].cast<VectorType>();
|
||||
auto maskElementType = parser.getBuilder().getI1Type();
|
||||
SmallVector<Type, 2> maskTypes;
|
||||
SmallVector<Type, 4> lhsMaskElementTypes(lhsType.getRank(), indexType);
|
||||
maskTypes.push_back(
|
||||
TupleType::get(lhsMaskElementTypes, parser.getBuilder().getContext()));
|
||||
SmallVector<Type, 4> rhsMaskElementTypes(rhsType.getRank(), indexType);
|
||||
maskTypes.push_back(
|
||||
TupleType::get(rhsMaskElementTypes, parser.getBuilder().getContext()));
|
||||
maskTypes.push_back(VectorType::get(lhsType.getShape(), maskElementType));
|
||||
maskTypes.push_back(VectorType::get(rhsType.getShape(), maskElementType));
|
||||
if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
|
||||
return failure();
|
||||
return success();
|
||||
|
@ -231,15 +227,10 @@ static LogicalResult verify(ContractionOp op) {
|
|||
if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
|
||||
return op.emitOpError("invalid number of vector masks specified");
|
||||
if (lhsMaskType && rhsMaskType) {
|
||||
// Verify tuple element size is != rank.
|
||||
if (lhsMaskType.getTypes().size() != lhsType.getShape().size() ||
|
||||
rhsMaskType.getTypes().size() != rhsType.getShape().size())
|
||||
return op.emitOpError("invalid number of vector mask elements");
|
||||
// Verify all tuple elements are index type.
|
||||
for (auto eltType : lhsMaskType.getTypes()) {
|
||||
if (!eltType.isa<IndexType>())
|
||||
return op.emitOpError("vector mask element must have index type");
|
||||
}
|
||||
// Verify mask rank == argument rank.
|
||||
if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
|
||||
rhsMaskType.getShape().size() != rhsType.getShape().size())
|
||||
return op.emitOpError("invalid vector mask rank");
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -1218,33 +1209,9 @@ void CreateMaskOp::getCanonicalizationPatterns(
|
|||
results.insert<CreateMaskFolder>(context);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IndexTupleOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult parseIndexTupleOp(OpAsmParser &parser, OperationState &result) {
|
||||
auto indexType = parser.getBuilder().getIndexType();
|
||||
Type resultType;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
|
||||
return failure(
|
||||
parser.parseOperandList(operandInfo) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(resultType) ||
|
||||
parser.resolveOperands(operandInfo, indexType, result.operands) ||
|
||||
parser.addTypeToList(resultType, result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, IndexTupleOp &op) {
|
||||
p << op.getOperationName() << ' ';
|
||||
p.printOperands(op.operands());
|
||||
p << " : " << op.getResult()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(IndexTupleOp &op) {
|
||||
for (auto operand : op.getOperands())
|
||||
if (!operand->getType().isa<IndexType>())
|
||||
return op.emitOpError("all operands must be of index type");
|
||||
return success();
|
||||
void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
|
||||
OwningRewritePatternList &patterns, MLIRContext *context) {
|
||||
patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder>(context);
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -278,9 +278,8 @@ static Value *getOrCreateUnrolledOperandSlice(
|
|||
// with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
|
||||
// An iteration space index map argument 'iterationIndexMapList' must be
|
||||
// specified, with a map for each structured op input and a single map for the
|
||||
// single result. The last map in the list must be the single result map.
|
||||
// Extra operands can be passed to unrolled instances of 'op' using the
|
||||
// 'extraOperands' argument.
|
||||
// single result. The map at index 'indexMapListResultIndex' in the list must
|
||||
// be the single result map.
|
||||
//
|
||||
// Example:
|
||||
//
|
||||
|
@ -310,7 +309,7 @@ static Value *getOrCreateUnrolledOperandSlice(
|
|||
static Value *unrollSingleResultStructuredOp(
|
||||
Operation *op, ArrayRef<int64_t> iterationBounds,
|
||||
std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
|
||||
ArrayRef<int64_t> targetShape, ArrayRef<Value *> extraOperands,
|
||||
unsigned indexMapListResultIndex, ArrayRef<int64_t> targetShape,
|
||||
PatternRewriter &builder) {
|
||||
auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
|
||||
if (!shapedType || !shapedType.hasStaticShape())
|
||||
|
@ -334,7 +333,7 @@ static Value *unrollSingleResultStructuredOp(
|
|||
auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
|
||||
auto basis = computeStrides(unrollFactors);
|
||||
|
||||
auto &resultOperandState = unrolledOperandState[numMaps - 1];
|
||||
auto &resultOperandState = unrolledOperandState[indexMapListResultIndex];
|
||||
auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
|
||||
shapedType.getElementType());
|
||||
|
||||
|
@ -360,7 +359,6 @@ static Value *unrollSingleResultStructuredOp(
|
|||
iterationIndexMapList[i], caches[i], builder));
|
||||
}
|
||||
// Create op on sliced vector arguments.
|
||||
operands.append(extraOperands.begin(), extraOperands.end());
|
||||
auto resultVector =
|
||||
cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
|
||||
unrolledResultType)
|
||||
|
@ -368,12 +366,14 @@ static Value *unrollSingleResultStructuredOp(
|
|||
|
||||
// Compute linear result index.
|
||||
int64_t resultIndex = getUnrolledOperandLinearIndex(
|
||||
resultOperandState, vectorOffsets, iterationIndexMapList[numMaps - 1]);
|
||||
resultOperandState, vectorOffsets,
|
||||
iterationIndexMapList[indexMapListResultIndex]);
|
||||
// Update result cache at 'resultIndex'.
|
||||
caches[numMaps - 1][resultIndex] = resultVector;
|
||||
caches[indexMapListResultIndex][resultIndex] = resultVector;
|
||||
}
|
||||
|
||||
// Make zero splat into which we will insert results from 'cache[numMaps - 1]'
|
||||
// Make zero splat into which we will insert results from
|
||||
// 'cache[indexMapListResultIndex]'
|
||||
auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
|
||||
auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
|
||||
SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
|
||||
|
@ -384,7 +384,8 @@ static Value *unrollSingleResultStructuredOp(
|
|||
auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
|
||||
vectorOffsets, resultOperandState.unrolledShape);
|
||||
res = builder.create<vector::InsertStridedSliceOp>(
|
||||
op->getLoc(), caches[numMaps - 1][i], res, offsets, strides);
|
||||
op->getLoc(), caches[indexMapListResultIndex][i], res, offsets,
|
||||
strides);
|
||||
}
|
||||
|
||||
return res;
|
||||
|
@ -434,13 +435,17 @@ Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
|
|||
// Get map from iteration space index to lhs/rhs/result shape index.
|
||||
std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
|
||||
contractionOp.getIterationIndexMap(iterationIndexMapList);
|
||||
// TODO(andydavis) Support unrollable vector masks.
|
||||
SmallVector<Value *, 2> masks(contractionOp.masks().begin(),
|
||||
contractionOp.masks().end());
|
||||
if (llvm::size(contractionOp.masks()) == 2) {
|
||||
// Add maps for lhs/rhs vector mask arguments (same lhs/rhs vector shape)
|
||||
iterationIndexMapList.push_back(iterationIndexMapList[0]);
|
||||
iterationIndexMapList.push_back(iterationIndexMapList[1]);
|
||||
}
|
||||
// Unroll 'op' 'iterationBounds' to 'targetShape'.
|
||||
return unrollSingleResultStructuredOp(op, iterationBounds,
|
||||
iterationIndexMapList, targetShape,
|
||||
masks, builder);
|
||||
// TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
|
||||
// 'iterationIndexMapList' instead of 'indexMapListResultIndex'.
|
||||
return unrollSingleResultStructuredOp(
|
||||
op, iterationBounds, iterationIndexMapList,
|
||||
/*indexMapListResultIndex=*/2, targetShape, builder);
|
||||
}
|
||||
// TODO(andydavis) Create trivial iteration bounds and index map for
|
||||
// elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
|
||||
|
@ -680,6 +685,7 @@ void mlir::populateVectorToVectorConversionPatterns(
|
|||
MLIRContext *context, OwningRewritePatternList &patterns,
|
||||
ArrayRef<int64_t> coarseVectorShape, ArrayRef<int64_t> fineVectorShape) {
|
||||
vector::populateWithGenerated(context, &patterns);
|
||||
vector::populateVectorToVectorCanonicalizationPatterns(patterns, context);
|
||||
patterns
|
||||
.insert<ConvertMatchingFakeForkFakeJoinOp,
|
||||
ConvertFakeForkFromBlockArgsOrTransferReadOp, ConvertFakeJoinOp,
|
||||
|
|
|
@ -66,33 +66,45 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
|
|||
// CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[LMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[RMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %[[LMASK0]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[A0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[A1S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[LMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[RMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %[[LMASK1]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[A0S04:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[A1S40:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[LMASK2:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[RMASK2:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %[[LMASK2]], %[[RMASK2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
|
||||
// Reducing output vector [0, 2]
|
||||
|
||||
// CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[RMASK3:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %[[LMASK0]], %[[RMASK3]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[A1S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[RMASK4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %[[LMASK1]], %[[RMASK4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[A1S42:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[RMASK5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %[[LMASK2]], %[[RMASK5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
|
||||
// Reducing output vector [2, 0]
|
||||
|
||||
// CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[LMASK3:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %[[LMASK3]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[A0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[LMASK4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %[[LMASK4]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[A0S24:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[LMASK5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %[[LMASK5]], %[[RMASK2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
|
||||
// Reducing output vector [2, 2]
|
||||
|
||||
|
@ -111,9 +123,8 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
|
|||
func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
|
||||
%arg2 : vector<4x4xf32>, %arg3 : index)
|
||||
-> (vector<4x4xf32>) {
|
||||
|
||||
%lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
|
||||
%rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
|
||||
%lhsm = vector.constant_mask [4, 6] : vector<4x6xi1>
|
||||
%rhsm = vector.constant_mask [6, 4] : vector<6x4xi1>
|
||||
%0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm
|
||||
: vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
|
||||
|
||||
|
@ -138,19 +149,23 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
|
|||
// CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[LMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[RMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %[[LMASK0]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
|
||||
// Reducing output vector [0, 2]
|
||||
|
||||
// CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[RMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %[[LMASK0]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
|
||||
// Reducing output vector [2, 0]
|
||||
|
||||
// CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
// CHECK-NEXT: %[[LMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
|
||||
// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %[[LMASK1]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
|
||||
|
||||
// Reducing output vector [2, 2]
|
||||
|
||||
|
@ -167,9 +182,8 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
|
|||
func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
|
||||
%arg2 : vector<4x4xf32>, %arg3 : index)
|
||||
-> (vector<4x4xf32>) {
|
||||
|
||||
%lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
|
||||
%rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
|
||||
%lhsm = vector.constant_mask [4, 2] : vector<4x2xi1>
|
||||
%rhsm = vector.constant_mask [2, 4] : vector<2x4xi1>
|
||||
%0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm
|
||||
: vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
|
||||
|
||||
|
|
|
@ -597,10 +597,8 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
|||
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||
%arg4 : index) {
|
||||
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||
: tuple<index, index, index, index>
|
||||
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||
: tuple<index, index, index, index>
|
||||
%lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
|
||||
%rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
|
||||
// expected-error@+1 {{expected zero or exactly 2 vector mask operands}}
|
||||
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2, %lhs_mask
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||
|
|
|
@ -114,10 +114,8 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
|
|||
%1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||
// Test contraction with optional vector mask arguments.
|
||||
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||
: tuple<index, index, index, index>
|
||||
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||
: tuple<index, index, index, index>
|
||||
%lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
|
||||
%rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
|
||||
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||
%2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
|
||||
%rhs_mask
|
||||
|
|
Loading…
Reference in New Issue