forked from OSchip/llvm-project
Add VectorContractionOp to the VectorOps dialect.
PiperOrigin-RevId: 281605471
This commit is contained in:
parent
1145cebdab
commit
d6a70b31be
|
@ -49,6 +49,104 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
// TODO(andydavis, ntv) Add an attribute to specify a different algebra
|
||||
// with operators other than the current set: {*, +}.
|
||||
// TODO(andydavis) Consider using AffineMaps to express contracting, batch
|
||||
// and free dimension pairs.
|
||||
def VectorContractionOp :
|
||||
Vector_Op<"contract", [NoSideEffect]>,
|
||||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
|
||||
Variadic<TupleOf<[Index]>>:$masks)>,
|
||||
Results<(outs AnyVector)> {
|
||||
let summary = "vector contraction operation";
|
||||
let description = [{
|
||||
Computes the sum of products of vector elements along contracting
|
||||
dimension pairs from 2 vectors of rank M and N respectively, adds this
|
||||
intermediate result to the accumulator argument of rank K, and returns a
|
||||
vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims +
|
||||
num_batch_dims (see dimension type descriptions below)).
|
||||
|
||||
Optional vector mask arguments specify the dynamic dimension sizes of
|
||||
valid data within the lhs/rhs vector arguments.
|
||||
|
||||
Dimensions for the arguments and result type fall into three categories:
|
||||
*) Contracting: contracting dimensions are present in the lhs and rhs
|
||||
arguments but not in the output (or optional accumulator
|
||||
argument). These are the dimensions along which the vector
|
||||
contraction op computes the sum of products, and contracting
|
||||
dimension pair dimension sizes must match between lhs/rhs.
|
||||
*) Batch: batch dimensions are non-contracting dimensions and so are
|
||||
present in the output and in the accumulator argument. The lhs
|
||||
and rhs co-iterate along the batch dimension and so dimension
|
||||
sizes must match across all arguments and result.
|
||||
*) Free: free dimensions are non-contraction, non-batch dimensions and
|
||||
are present in the output and accumulator argument. The lhs and
|
||||
rhs free dimensions are unrelated to each other and do not
|
||||
co-iterate.
|
||||
|
||||
Contracting and batch dimensions are specified as dimension pairs
|
||||
of logical dimension numbers: the first in the pair represents the lhs
|
||||
logical dimension number and the second in the pair represents the
|
||||
associated rhs logical dimension number. A dimension pair binds together
|
||||
logical dimension numbers from the lhs/rhs which co-iterate together, either
|
||||
as contracting or batch dimensions.
|
||||
|
||||
Examples:
|
||||
|
||||
// 2D vector contraction with one contracting dimension (matmul).
|
||||
%3 = vector.contract %0, %1, %2
|
||||
{ contracting_dim_map = [[1, 0]] }
|
||||
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
|
||||
|
||||
// 4D to 3D vector contraction with two contracting dimensions and
|
||||
// one batch dimension.
|
||||
%4 = vector.contract %0, %1, %2
|
||||
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||
|
||||
// 4D vector contraction with two contracting dimensions and optional
|
||||
// vector mask arguments.
|
||||
%lhs_mask = vector.make_tuple %size0, %size1, %size2, %size3
|
||||
: tuple<index, index, index, index>
|
||||
%rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7
|
||||
: tuple<index, index, index, index>
|
||||
|
||||
%5 = vector.contract %0, %1, %2, %lhs_mask, %rhs_mask
|
||||
{ contracting_dim_map = [[0, 2], [2, 1]] }
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||
}];
|
||||
let extraClassDeclaration = [{
|
||||
VectorType getLhsType() {
|
||||
return lhs()->getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getRhsType() {
|
||||
return rhs()->getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getAccType() {
|
||||
return acc()->getType().cast<VectorType>();
|
||||
}
|
||||
TupleType getLHSVectorMaskType() {
|
||||
if (llvm::size(masks()) != 2) return TupleType();
|
||||
return getOperand(3)->getType().cast<TupleType>();
|
||||
}
|
||||
TupleType getRHSVectorMaskType() {
|
||||
if (llvm::size(masks()) != 2) return TupleType();
|
||||
return getOperand(4)->getType().cast<TupleType>();
|
||||
}
|
||||
VectorType getResultType() {
|
||||
return getResult()->getType().cast<VectorType>();
|
||||
}
|
||||
static StringRef getContractingDimMapAttrName() {
|
||||
return "contracting_dim_map";
|
||||
}
|
||||
static StringRef getBatchDimMapAttrName() {
|
||||
return "batch_dim_map";
|
||||
}
|
||||
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
|
||||
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
|
||||
}];
|
||||
}
|
||||
|
||||
def VectorExtractElementOp :
|
||||
Vector_Op<"extractelement", [NoSideEffect,
|
||||
PredOpTrait<"operand and result have same element type",
|
||||
|
@ -391,4 +489,21 @@ def VectorTypeCastOp :
|
|||
}
|
||||
}];
|
||||
}
|
||||
|
||||
// TODO(andydavis) Morph this operation into a VectorMaskOp.
|
||||
def VectorIndexTupleOp :
|
||||
Vector_Op<"make_index_tuple", [NoSideEffect]>,
|
||||
Arguments<(ins Variadic<Index>:$operands)>,
|
||||
Results<(outs TupleOf<[Index]>)> {
|
||||
let summary = "creates a tuple of operand values";
|
||||
let description = [{
|
||||
Creates and returns a tuple of its operands which must be of index type.
|
||||
|
||||
Example:
|
||||
|
||||
%1 = vector.make_index_tuple %size0, %size1, %size2
|
||||
: tuple<index, index, index>
|
||||
|
||||
}];
|
||||
}
|
||||
#endif // VECTOR_OPS
|
||||
|
|
|
@ -43,6 +43,185 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
|
|||
>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorContractionOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static ParseResult parseVectorContractionOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
OpAsmParser::OperandType lhsInfo;
|
||||
OpAsmParser::OperandType rhsInfo;
|
||||
OpAsmParser::OperandType accInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 2> masksInfo;
|
||||
SmallVector<Type, 2> types;
|
||||
Type resultVectorType;
|
||||
auto loc = parser.getCurrentLocation();
|
||||
if (parser.parseOperand(lhsInfo) || parser.parseComma() ||
|
||||
parser.parseOperand(rhsInfo) || parser.parseComma() ||
|
||||
parser.parseOperand(accInfo) ||
|
||||
parser.parseTrailingOperandList(masksInfo) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonTypeList(types) ||
|
||||
parser.parseKeywordType("into", resultVectorType) ||
|
||||
parser.resolveOperand(lhsInfo, types[0], result.operands) ||
|
||||
parser.resolveOperand(rhsInfo, types[1], result.operands) ||
|
||||
parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
|
||||
parser.addTypeToList(resultVectorType, result.types))
|
||||
return failure();
|
||||
|
||||
if (masksInfo.empty())
|
||||
return success();
|
||||
if (masksInfo.size() != 2)
|
||||
return parser.emitError(parser.getNameLoc(),
|
||||
"expected zero or exactly 2 vector mask operands");
|
||||
auto indexType = parser.getBuilder().getIndexType();
|
||||
auto lhsType = types[0].cast<VectorType>();
|
||||
auto rhsType = types[1].cast<VectorType>();
|
||||
SmallVector<Type, 2> maskTypes;
|
||||
SmallVector<Type, 4> lhsMaskElementTypes(lhsType.getRank(), indexType);
|
||||
maskTypes.push_back(
|
||||
TupleType::get(lhsMaskElementTypes, parser.getBuilder().getContext()));
|
||||
SmallVector<Type, 4> rhsMaskElementTypes(rhsType.getRank(), indexType);
|
||||
maskTypes.push_back(
|
||||
TupleType::get(rhsMaskElementTypes, parser.getBuilder().getContext()));
|
||||
if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, VectorContractionOp op) {
|
||||
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
|
||||
p << ", " << *op.acc();
|
||||
if (llvm::size(op.masks()) == 2) {
|
||||
p << ", " << **op.masks().begin();
|
||||
p << ", " << **(op.masks().begin() + 1);
|
||||
}
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
|
||||
<< op.getResultType();
|
||||
}
|
||||
|
||||
static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
|
||||
const std::vector<std::pair<int64_t, int64_t>> &map) {
|
||||
for (auto &dimPair : map) {
|
||||
if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
|
||||
dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
|
||||
lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static bool verifyOutputShape(
|
||||
VectorType lhsType, VectorType rhsType, VectorType accType,
|
||||
VectorType resType,
|
||||
const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
|
||||
const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
|
||||
DenseSet<int64_t> lhsContractingDimSet;
|
||||
DenseSet<int64_t> rhsContractingDimSet;
|
||||
for (auto &dimPair : contractingDimMap) {
|
||||
lhsContractingDimSet.insert(dimPair.first);
|
||||
rhsContractingDimSet.insert(dimPair.second);
|
||||
}
|
||||
DenseSet<int64_t> rhsBatchDimSet;
|
||||
for (auto &dimPair : batchDimMap)
|
||||
rhsBatchDimSet.insert(dimPair.second);
|
||||
|
||||
// Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
|
||||
SmallVector<int64_t, 4> expectedResultDims;
|
||||
for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
|
||||
if (lhsContractingDimSet.count(i) > 0)
|
||||
continue;
|
||||
expectedResultDims.push_back(lhsType.getDimSize(i));
|
||||
}
|
||||
|
||||
// Add free dimensions from 'rhsType' to 'expectedResultDims'.
|
||||
for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
|
||||
if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
|
||||
continue;
|
||||
expectedResultDims.push_back(rhsType.getDimSize(i));
|
||||
}
|
||||
|
||||
// Verify dimension from 'resType' against 'expectedResultDims'.
|
||||
if (resType.getShape().size() != expectedResultDims.size() ||
|
||||
accType.getShape().size() != expectedResultDims.size())
|
||||
return false;
|
||||
for (int64_t i = 0, e = resType.getRank(); i < e; ++i) {
|
||||
if (resType.getDimSize(i) != expectedResultDims[i] ||
|
||||
accType.getDimSize(i) != expectedResultDims[i])
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
static LogicalResult verify(VectorContractionOp op) {
|
||||
auto lhsType = op.getLhsType();
|
||||
auto rhsType = op.getRhsType();
|
||||
auto accType = op.getAccType();
|
||||
auto resType = op.getResultType();
|
||||
auto contractingDimMap = op.getContractingDimMap();
|
||||
auto batchDimMap = op.getBatchDimMap();
|
||||
|
||||
// Verify at least one contracting dimension pair was specified.
|
||||
if (contractingDimMap.empty())
|
||||
return op.emitOpError("expected at least one contracting dimension pair");
|
||||
|
||||
// Verify contracting dimension map was properly constructed.
|
||||
if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
|
||||
return op.emitOpError("invalid contracting dimension map");
|
||||
|
||||
// Verify batch dimension map was properly constructed.
|
||||
if (!verifyDimMap(lhsType, rhsType, batchDimMap))
|
||||
return op.emitOpError("invalid batch dimension map");
|
||||
|
||||
// Verify 'accType' and 'resType' shape.
|
||||
if (!verifyOutputShape(lhsType, rhsType, accType, resType, contractingDimMap,
|
||||
batchDimMap))
|
||||
return op.emitOpError("invalid accumulator/result vector shape");
|
||||
|
||||
// Verify that either two vector masks are set or none are set.
|
||||
auto lhsMaskType = op.getLHSVectorMaskType();
|
||||
auto rhsMaskType = op.getRHSVectorMaskType();
|
||||
if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
|
||||
return op.emitOpError("invalid number of vector masks specified");
|
||||
if (lhsMaskType && rhsMaskType) {
|
||||
// Verify tuple element size is != rank.
|
||||
if (lhsMaskType.getTypes().size() != lhsType.getShape().size() ||
|
||||
rhsMaskType.getTypes().size() != rhsType.getShape().size())
|
||||
return op.emitOpError("invalid number of vector mask elements");
|
||||
// Verify all tuple elements are index type.
|
||||
for (auto eltType : lhsMaskType.getTypes()) {
|
||||
if (!eltType.isa<IndexType>())
|
||||
return op.emitOpError("vector mask element must have index type");
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static std::vector<std::pair<int64_t, int64_t>> getDimMap(Attribute attr) {
|
||||
std::vector<std::pair<int64_t, int64_t>> dimMap;
|
||||
auto dimPairs = attr.dyn_cast_or_null<ArrayAttr>();
|
||||
if (!dimPairs)
|
||||
return dimMap;
|
||||
for (auto dimPairAttr : dimPairs) {
|
||||
auto dimPair = dimPairAttr.cast<ArrayAttr>();
|
||||
assert(dimPair.size() == 2);
|
||||
auto lhsDim = dimPair.begin()->cast<IntegerAttr>().getInt();
|
||||
auto rhsDim = std::prev(dimPair.end())->cast<IntegerAttr>().getInt();
|
||||
dimMap.push_back({lhsDim, rhsDim});
|
||||
}
|
||||
return dimMap;
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64_t, int64_t>>
|
||||
VectorContractionOp::getContractingDimMap() {
|
||||
return getDimMap(getAttr(getContractingDimMapAttrName()));
|
||||
}
|
||||
|
||||
std::vector<std::pair<int64_t, int64_t>> VectorContractionOp::getBatchDimMap() {
|
||||
return getDimMap(getAttr(getBatchDimMapAttrName()));
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorExtractElementOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -541,6 +720,36 @@ static LogicalResult verify(VectorTypeCastOp &op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// VectorIndexTupleOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
ParseResult parseVectorIndexTupleOp(OpAsmParser &parser,
|
||||
OperationState &result) {
|
||||
auto indexType = parser.getBuilder().getIndexType();
|
||||
Type resultType;
|
||||
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
|
||||
return failure(
|
||||
parser.parseOperandList(operandInfo) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonType(resultType) ||
|
||||
parser.resolveOperands(operandInfo, indexType, result.operands) ||
|
||||
parser.addTypeToList(resultType, result.types));
|
||||
}
|
||||
|
||||
static void print(OpAsmPrinter &p, VectorIndexTupleOp &op) {
|
||||
p << op.getOperationName() << ' ';
|
||||
p.printOperands(op.operands());
|
||||
p << " : " << op.getResult()->getType();
|
||||
}
|
||||
|
||||
static LogicalResult verify(VectorIndexTupleOp &op) {
|
||||
for (auto operand : op.getOperands())
|
||||
if (!operand->getType().isa<IndexType>())
|
||||
return op.emitOpError("all operands must be of index type");
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace mlir {
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
|
|
|
@ -303,3 +303,71 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) {
|
|||
// expected-error@+1 {{op expected result type to be 'vector<2x8x16xf32>'}}
|
||||
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||
%arg4 : index) {
|
||||
// expected-error@+1 {{op expected at least one contracting dimension pair}}
|
||||
%0 = vector.contract %arg0, %arg1, %arg2
|
||||
{ batch_dim_map = [[1, 0]] }
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||
%arg4 : index) {
|
||||
// expected-error@+1 {{invalid contracting dimension map}}
|
||||
%0 = vector.contract %arg0, %arg1, %arg2
|
||||
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[1, 2], [2, 1]] }
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||
%arg4 : index) {
|
||||
// expected-error@+1 {{invalid batch dimension map}}
|
||||
%0 = vector.contract %arg0, %arg1, %arg2
|
||||
{ batch_dim_map = [[1, 2]], contracting_dim_map = [[0, 2], [2, 1]] }
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||
%arg2: vector<8x15x88xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||
%arg4 : index) {
|
||||
// expected-error@+1 {{invalid accumulator/result vector shape}}
|
||||
%0 = vector.contract %arg0, %arg1, %arg2
|
||||
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x88xf32>
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||
%arg4 : index) {
|
||||
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||
: tuple<index, index, index, index>
|
||||
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||
: tuple<index, index, index, index>
|
||||
// expected-error@+1 {{expected zero or exactly 2 vector mask operands}}
|
||||
%0 = vector.contract %arg0, %arg1, %arg2, %lhs_mask
|
||||
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -48,3 +48,33 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
|
|||
%1 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
|
||||
return %1: vector<2x2x16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: contraction
|
||||
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
|
||||
%arg2 : vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||
%arg4 : index) {
|
||||
// Test contraction with batch and contracting dims.
|
||||
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}} {batch_dim_map = {{.*}}1, 0{{.*}}, contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||
%0 = vector.contract %arg0, %arg1, %arg2
|
||||
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||
|
||||
// Test contraction with only contracting dims. In this case the lhs/rhs
|
||||
// dimension of size 8 will be considered a free dim for lhs/rhs and will
|
||||
// appear twice in the output.
|
||||
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}} {contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||
%1 = vector.contract %arg0, %arg1, %arg3
|
||||
{ contracting_dim_map = [[0, 2], [2, 1]] }
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||
|
||||
// Test contraction with optional vector mask arguments.
|
||||
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||
: tuple<index, index, index, index>
|
||||
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||
: tuple<index, index, index, index>
|
||||
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||
%2 = vector.contract %arg0, %arg1, %arg3, %lhs_mask, %rhs_mask
|
||||
{ contracting_dim_map = [[0, 2], [2, 1]] }
|
||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue