Unroll vector masks along with their associated vector arguments.

Updates vector ContractionOp to use proper vector masks (produced by
CreateMaskOp/ConstantMaskOp). Leverages the following canonicalizations
in the unrolling unit test: CreateMaskOp -> ConstantMaskOp and
StridedSliceOp(ConstantMaskOp) -> ConstantMaskOp. Removes IndexTupleOp
(no longer needed now that we have vector mask ops). Updates all unit
tests.

PiperOrigin-RevId: 284182168
commit 41f8e105fa (parent 9ca53130f3)
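For context (an illustrative sketch, not part of the diff): with this change, contraction masks are i1 vectors whose shape matches the masked argument, produced by vector.create_mask (dynamic sizes) or vector.constant_mask (static sizes); a create_mask whose operands are constants canonicalizes to a constant_mask. The SSA names and the #contraction_trait0 alias below are assumed for illustration.

    %c4 = constant 4 : index
    %c6 = constant 6 : index
    // create_mask with constant operands canonicalizes to:
    //   vector.constant_mask [4, 6] : vector<4x6xi1>
    %lhs_mask = vector.create_mask %c4, %c6 : vector<4x6xi1>
    %rhs_mask = vector.constant_mask [6, 4] : vector<6x4xi1>
    %res = vector.contract #contraction_trait0 %lhs, %rhs, %acc, %lhs_mask, %rhs_mask
      : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>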
@@ -28,6 +28,8 @@
 #include "mlir/IR/StandardTypes.h"
 
 namespace mlir {
+class MLIRContext;
+class OwningRewritePatternList;
 namespace vector {
 
 /// Dialect for Ops on higher-dimensional vector types.
@@ -37,6 +39,10 @@ public:
   static StringRef getDialectNamespace() { return "vector"; }
 };
 
+/// Collect a set of vector-to-vector canonicalization patterns.
+void populateVectorToVectorCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context);
+
 #define GET_OP_CLASSES
 #include "mlir/Dialect/VectorOps/VectorOps.h.inc"
 
@@ -49,7 +49,7 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
 def Vector_ContractionOp :
   Vector_Op<"contract", [NoSideEffect]>,
     Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
-               Variadic<TupleOf<[Index]>>:$masks,
+               Variadic<VectorOf<[I1]>>:$masks,
                AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
     Results<(outs AnyVector)> {
   let summary = "vector contraction operation";
@@ -60,8 +60,9 @@ def Vector_ContractionOp :
     vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims +
     num_batch_dims (see dimension type descriptions below)).
 
-    Optional vector mask arguments specify the dynamic dimension sizes of
-    valid data within the lhs/rhs vector arguments.
+    Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp)
+    specify the dynamic dimension sizes of valid data within the lhs/rhs vector
+    arguments.
 
     An iterator type attribute list must be specified, where each element of
     the list represents an iterator with one of the following types:
@@ -120,10 +121,8 @@ def Vector_ContractionOp :
 
     // 4D vector contraction with two contracting dimensions and optional
     // vector mask arguments.
-    %lhs_mask = vector.make_tuple %size0, %size1, %size2, %size3
-      : tuple<index, index, index, index>
-    %rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7
-      : tuple<index, index, index, index>
+    %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
+    %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
 
     %5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
       : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
@@ -138,13 +137,13 @@ def Vector_ContractionOp :
     VectorType getAccType() {
       return acc()->getType().cast<VectorType>();
     }
-    TupleType getLHSVectorMaskType() {
-      if (llvm::size(masks()) != 2) return TupleType();
-      return getOperand(3)->getType().cast<TupleType>();
+    VectorType getLHSVectorMaskType() {
+      if (llvm::size(masks()) != 2) return VectorType();
+      return getOperand(3)->getType().cast<VectorType>();
     }
-    TupleType getRHSVectorMaskType() {
-      if (llvm::size(masks()) != 2) return TupleType();
-      return getOperand(4)->getType().cast<TupleType>();
+    VectorType getRHSVectorMaskType() {
+      if (llvm::size(masks()) != 2) return VectorType();
+      return getOperand(4)->getType().cast<VectorType>();
     }
     VectorType getResultType() {
       return getResult()->getType().cast<VectorType>();
@@ -706,20 +705,4 @@ def Vector_CreateMaskOp :
   let hasCanonicalizer = 1;
 }
 
-// TODO(andydavis) Delete this op once ContractOp is converted to use VectorMask
-def Vector_IndexTupleOp :
-  Vector_Op<"make_index_tuple", [NoSideEffect]>,
-    Arguments<(ins Variadic<Index>:$operands)>,
-    Results<(outs TupleOf<[Index]>)> {
-  let summary = "creates a tuple of operand values";
-  let description = [{
-    Creates and returns a tuple of its operands which must be of index type.
-
-    Example:
-
-      %1 = vector.make_index_tuple %size0, %size1, %size2
-        : tuple<index, index, index>
-
-  }];
-}
 #endif // VECTOR_OPS
@@ -82,16 +82,12 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
   if (masksInfo.size() != 2)
     return parser.emitError(parser.getNameLoc(),
                             "expected zero or exactly 2 vector mask operands");
-  auto indexType = parser.getBuilder().getIndexType();
   auto lhsType = types[0].cast<VectorType>();
   auto rhsType = types[1].cast<VectorType>();
+  auto maskElementType = parser.getBuilder().getI1Type();
   SmallVector<Type, 2> maskTypes;
-  SmallVector<Type, 4> lhsMaskElementTypes(lhsType.getRank(), indexType);
-  maskTypes.push_back(
-      TupleType::get(lhsMaskElementTypes, parser.getBuilder().getContext()));
-  SmallVector<Type, 4> rhsMaskElementTypes(rhsType.getRank(), indexType);
-  maskTypes.push_back(
-      TupleType::get(rhsMaskElementTypes, parser.getBuilder().getContext()));
+  maskTypes.push_back(VectorType::get(lhsType.getShape(), maskElementType));
+  maskTypes.push_back(VectorType::get(rhsType.getShape(), maskElementType));
   if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
     return failure();
   return success();
@@ -231,15 +227,10 @@ static LogicalResult verify(ContractionOp op) {
   if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
     return op.emitOpError("invalid number of vector masks specified");
   if (lhsMaskType && rhsMaskType) {
-    // Verify tuple element size is != rank.
-    if (lhsMaskType.getTypes().size() != lhsType.getShape().size() ||
-        rhsMaskType.getTypes().size() != rhsType.getShape().size())
-      return op.emitOpError("invalid number of vector mask elements");
-    // Verify all tuple elements are index type.
-    for (auto eltType : lhsMaskType.getTypes()) {
-      if (!eltType.isa<IndexType>())
-        return op.emitOpError("vector mask element must have index type");
-    }
+    // Verify mask rank == argument rank.
+    if (lhsMaskType.getShape().size() != lhsType.getShape().size() ||
+        rhsMaskType.getShape().size() != rhsType.getShape().size())
+      return op.emitOpError("invalid vector mask rank");
   }
   return success();
 }
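A note on the verifier change above (illustrative, not from the diff): a mask is now validated by rank rather than by tuple arity, so each mask must be an i1 vector whose rank matches the argument it masks.

    // Accepted for a vector<4x6xf32> lhs: a rank-2 i1 mask.
    %m_ok = vector.constant_mask [4, 6] : vector<4x6xi1>
    // Rejected with "invalid vector mask rank": rank-1 mask for a rank-2 argument.
    %m_bad = vector.constant_mask [4] : vector<4xi1>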
@@ -1218,33 +1209,9 @@ void CreateMaskOp::getCanonicalizationPatterns(
   results.insert<CreateMaskFolder>(context);
 }
 
-//===----------------------------------------------------------------------===//
-// IndexTupleOp
-//===----------------------------------------------------------------------===//
-
-ParseResult parseIndexTupleOp(OpAsmParser &parser, OperationState &result) {
-  auto indexType = parser.getBuilder().getIndexType();
-  Type resultType;
-  SmallVector<OpAsmParser::OperandType, 4> operandInfo;
-  return failure(
-      parser.parseOperandList(operandInfo) ||
-      parser.parseOptionalAttrDict(result.attributes) ||
-      parser.parseColonType(resultType) ||
-      parser.resolveOperands(operandInfo, indexType, result.operands) ||
-      parser.addTypeToList(resultType, result.types));
-}
-
-static void print(OpAsmPrinter &p, IndexTupleOp &op) {
-  p << op.getOperationName() << ' ';
-  p.printOperands(op.operands());
-  p << " : " << op.getResult()->getType();
-}
-
-static LogicalResult verify(IndexTupleOp &op) {
-  for (auto operand : op.getOperands())
-    if (!operand->getType().isa<IndexType>())
-      return op.emitOpError("all operands must be of index type");
-  return success();
-}
+void mlir::vector::populateVectorToVectorCanonicalizationPatterns(
+    OwningRewritePatternList &patterns, MLIRContext *context) {
+  patterns.insert<CreateMaskFolder, StridedSliceConstantMaskFolder>(context);
+}
 
 namespace mlir {
@@ -278,9 +278,8 @@ static Value *getOrCreateUnrolledOperandSlice(
 // with iteration bounds 'iterationBounds' unrolled to 'targetShape'.
 // An iteration space index map argument 'iterationIndexMapList' must be
 // specified, with a map for each structured op input and a single map for the
-// single result. The last map in the list must be the single result map.
-// Extra operands can be passed to unrolled instances of 'op' using the
-// 'extraOperands' argument.
+// single result. The map at index 'indexMapListResultIndex' in the list must
+// be the single result map.
 //
 // Example:
 //
@@ -310,7 +309,7 @@ static Value *getOrCreateUnrolledOperandSlice(
 static Value *unrollSingleResultStructuredOp(
     Operation *op, ArrayRef<int64_t> iterationBounds,
     std::vector<DenseMap<int64_t, int64_t>> &iterationIndexMapList,
-    ArrayRef<int64_t> targetShape, ArrayRef<Value *> extraOperands,
+    unsigned indexMapListResultIndex, ArrayRef<int64_t> targetShape,
     PatternRewriter &builder) {
   auto shapedType = op->getResult(0)->getType().dyn_cast_or_null<ShapedType>();
   if (!shapedType || !shapedType.hasStaticShape())
@@ -334,7 +333,7 @@ static Value *unrollSingleResultStructuredOp(
   auto numUnrolledInstances = computeMaxLinearIndex(unrollFactors);
   auto basis = computeStrides(unrollFactors);
 
-  auto &resultOperandState = unrolledOperandState[numMaps - 1];
+  auto &resultOperandState = unrolledOperandState[indexMapListResultIndex];
   auto unrolledResultType = VectorType::get(resultOperandState.unrolledShape,
                                             shapedType.getElementType());
 
@@ -360,7 +359,6 @@ static Value *unrollSingleResultStructuredOp(
           iterationIndexMapList[i], caches[i], builder));
     }
     // Create op on sliced vector arguments.
-    operands.append(extraOperands.begin(), extraOperands.end());
     auto resultVector =
         cloneOpWithOperandsAndTypes(builder, op->getLoc(), op, operands,
                                     unrolledResultType)
@@ -368,12 +366,14 @@ static Value *unrollSingleResultStructuredOp(
 
     // Compute linear result index.
     int64_t resultIndex = getUnrolledOperandLinearIndex(
-        resultOperandState, vectorOffsets, iterationIndexMapList[numMaps - 1]);
+        resultOperandState, vectorOffsets,
+        iterationIndexMapList[indexMapListResultIndex]);
     // Update result cache at 'resultIndex'.
-    caches[numMaps - 1][resultIndex] = resultVector;
+    caches[indexMapListResultIndex][resultIndex] = resultVector;
   }
 
-  // Make zero splat into which we will insert results from 'cache[numMaps - 1]'
+  // Make zero splat into which we will insert results from
+  // 'cache[indexMapListResultIndex]'
   auto resultVectorType = op->getResult(0)->getType().cast<VectorType>();
   auto *res = makeSplatZero(op->getLoc(), builder, resultVectorType);
   SmallVector<int64_t, 4> strides(resultOperandState.unrollFactors.size(), 1);
@@ -384,7 +384,8 @@ static Value *unrollSingleResultStructuredOp(
     auto offsets = zipMap([](int64_t v1, int64_t v2) { return v1 * v2; },
                           vectorOffsets, resultOperandState.unrolledShape);
     res = builder.create<vector::InsertStridedSliceOp>(
-        op->getLoc(), caches[numMaps - 1][i], res, offsets, strides);
+        op->getLoc(), caches[indexMapListResultIndex][i], res, offsets,
+        strides);
   }
 
   return res;
@@ -434,13 +435,17 @@ Value * mlir::vector::unrollSingleResultOpMatchingType(PatternRewriter &builder,
     // Get map from iteration space index to lhs/rhs/result shape index.
     std::vector<DenseMap<int64_t, int64_t>> iterationIndexMapList;
     contractionOp.getIterationIndexMap(iterationIndexMapList);
-    // TODO(andydavis) Support unrollable vector masks.
-    SmallVector<Value *, 2> masks(contractionOp.masks().begin(),
-                                  contractionOp.masks().end());
+    if (llvm::size(contractionOp.masks()) == 2) {
+      // Add maps for lhs/rhs vector mask arguments (same lhs/rhs vector shape)
+      iterationIndexMapList.push_back(iterationIndexMapList[0]);
+      iterationIndexMapList.push_back(iterationIndexMapList[1]);
+    }
     // Unroll 'op' 'iterationBounds' to 'targetShape'.
-    return unrollSingleResultStructuredOp(op, iterationBounds,
-                                          iterationIndexMapList, targetShape,
-                                          masks, builder);
+    // TODO(andydavis) Use linalg style 'args_in'/'args_out' to partition
+    // 'iterationIndexMapList' instead of 'indexMapListResultIndex'.
+    return unrollSingleResultStructuredOp(
+        op, iterationBounds, iterationIndexMapList,
+        /*indexMapListResultIndex=*/2, targetShape, builder);
   }
   // TODO(andydavis) Create trivial iteration bounds and index map for
   // elementwise operations and call 'unrollSingleResultStructuredOp'. Remove
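Net effect of the unrolling change above, sketched on the 4x4 test case further down (shapes taken from the tests; SSA names assumed): lhs/rhs masks are now unrolled along with their vector arguments, so each unrolled contraction receives masks sliced to the target shape, which canonicalize to constant masks.

    // Before unrolling: one masked 4x4 contraction.
    %lhsm = vector.constant_mask [4, 6] : vector<4x6xi1>
    %rhsm = vector.constant_mask [6, 4] : vector<6x4xi1>
    %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm
      : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
    // After unrolling to 2x2: each instance gets matching 2x2 masks.
    %lmask = vector.constant_mask [2, 2] : vector<2x2xi1>
    %rmask = vector.constant_mask [2, 2] : vector<2x2xi1>
    %r = vector.contract #contraction_trait0 %a0, %a1, %acc, %lmask, %rmask
      : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>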
@@ -680,6 +685,7 @@ void mlir::populateVectorToVectorConversionPatterns(
     MLIRContext *context, OwningRewritePatternList &patterns,
     ArrayRef<int64_t> coarseVectorShape, ArrayRef<int64_t> fineVectorShape) {
   vector::populateWithGenerated(context, &patterns);
+  vector::populateVectorToVectorCanonicalizationPatterns(patterns, context);
   patterns
       .insert<ConvertMatchingFakeForkFakeJoinOp,
               ConvertFakeForkFromBlockArgsOrTransferReadOp, ConvertFakeJoinOp,
@@ -66,33 +66,45 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
 // CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[LMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[RMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %[[LMASK0]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // CHECK-NEXT: %[[A0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[A1S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[LMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[RMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R2S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S20]], %[[R1S00]], %[[LMASK1]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // CHECK-NEXT: %[[A0S04:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[A1S40:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 0], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[LMASK2:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[RMASK2:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R3S00:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S40]], %[[R2S00]], %[[LMASK2]], %[[RMASK2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // Reducing output vector [0, 2]
 
 // CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[RMASK3:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %[[LMASK0]], %[[RMASK3]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // CHECK-NEXT: %[[A1S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[RMASK4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R2S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S02]], %[[A1S22]], %[[R1S02]], %[[LMASK1]], %[[RMASK4]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // CHECK-NEXT: %[[A1S42:.*]] = vector.strided_slice %{{.*}} {offsets = [4, 2], sizes = [2, 2], strides = [1, 1]} : vector<6x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[RMASK5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R3S02:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S04]], %[[A1S42]], %[[R2S02]], %[[LMASK2]], %[[RMASK5]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // Reducing output vector [2, 0]
 
 // CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[LMASK3:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %[[LMASK3]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // CHECK-NEXT: %[[A0S22:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[LMASK4:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R2S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S22]], %[[A1S20]], %[[R1S20]], %[[LMASK4]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 // CHECK-NEXT: %[[A0S24:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 4], sizes = [2, 2], strides = [1, 1]} : vector<4x6xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[LMASK5:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R3S20:.*]] = vector.contract {indexing_maps = [#map0, #map1, #map2], iterator_types = ["parallel", "parallel", "reduction"]} %[[A0S24]], %[[A1S40]], %[[R2S20]], %[[LMASK5]], %[[RMASK2]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // Reducing output vector [2, 2]
 
@@ -111,9 +123,8 @@ func @add4x4(%0: vector<4x4xf32>, %1: vector<4x4xf32>) -> vector<4x4xf32> {
 func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
                          %arg2 : vector<4x4xf32>, %arg3 : index)
                          -> (vector<4x4xf32>) {
-  %lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
-  %rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+  %lhsm = vector.constant_mask [4, 6] : vector<4x6xi1>
+  %rhsm = vector.constant_mask [6, 4] : vector<6x4xi1>
   %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2, %lhsm, %rhsm
     : vector<4x6xf32>, vector<6x4xf32> into vector<4x4xf32>
 
@@ -138,19 +149,23 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
 // CHECK: %[[A0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[A1S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[R0S00:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[LMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[RMASK0:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R1S00:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S00]], %[[R0S00]], %[[LMASK0]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // Reducing output vector [0, 2]
 
 // CHECK-NEXT: %[[A1S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<2x4xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[R0S02:.*]] = vector.strided_slice %{{.*}} {offsets = [0, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[RMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R1S02:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S00]], %[[A1S02]], %[[R0S02]], %[[LMASK0]], %[[RMASK1]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // Reducing output vector [2, 0]
 
 // CHECK-NEXT: %[[A0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x2xf32> to vector<2x2xf32>
 // CHECK-NEXT: %[[R0S20:.*]] = vector.strided_slice %{{.*}} {offsets = [2, 0], sizes = [2, 2], strides = [1, 1]} : vector<4x4xf32> to vector<2x2xf32>
-// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %{{.*}}, %{{.*}} : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
+// CHECK-NEXT: %[[LMASK1:.*]] = vector.constant_mask [2, 2] : vector<2x2xi1>
+// CHECK-NEXT: %[[R1S20:.*]] = vector.contract {indexing_maps = [#map2, #map3, #map0], iterator_types = ["parallel", "reduction", "parallel"]} %[[A0S20]], %[[A1S00]], %[[R0S20]], %[[LMASK1]], %[[RMASK0]] : vector<2x2xf32>, vector<2x2xf32> into vector<2x2xf32>
 
 // Reducing output vector [2, 2]
 
@@ -167,9 +182,8 @@ func @contraction4x4_ijk(%arg0 : vector<4x6xf32>, %arg1 : vector<6x4xf32>,
 func @contraction4x4_ikj(%arg0 : vector<4x2xf32>, %arg1 : vector<2x4xf32>,
                          %arg2 : vector<4x4xf32>, %arg3 : index)
                          -> (vector<4x4xf32>) {
-  %lhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
-  %rhsm = vector.make_index_tuple %arg3, %arg3 : tuple<index, index>
+  %lhsm = vector.constant_mask [4, 2] : vector<4x2xi1>
+  %rhsm = vector.constant_mask [2, 4] : vector<2x4xi1>
   %0 = vector.contract #contraction_trait1 %arg0, %arg1, %arg2, %lhsm, %rhsm
     : vector<4x2xf32>, vector<2x4xf32> into vector<4x4xf32>
 
@@ -597,10 +597,8 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
 func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
                   %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
                   %arg4 : index) {
-  %lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
-    : tuple<index, index, index, index>
-  %rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
-    : tuple<index, index, index, index>
+  %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
+  %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
   // expected-error@+1 {{expected zero or exactly 2 vector mask operands}}
   %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2, %lhs_mask
     : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
@@ -114,10 +114,8 @@ func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
   %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3
     : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
   // Test contraction with optional vector mask arguments.
-  %lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
-    : tuple<index, index, index, index>
-  %rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
-    : tuple<index, index, index, index>
+  %lhs_mask = vector.constant_mask [7, 8, 16, 15] : vector<7x8x16x15xi1>
+  %rhs_mask = vector.constant_mask [8, 16, 7, 5] : vector<8x16x7x5xi1>
   // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
   %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
     %rhs_mask