forked from OSchip/llvm-project
Add VectorContractionOp to the VectorOps dialect.
PiperOrigin-RevId: 281605471
This commit is contained in:
parent
1145cebdab
commit
d6a70b31be
|
@ -49,6 +49,104 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(andydavis, ntv) Add an attribute to specify a different algebra
|
||||||
|
// with operators other than the current set: {*, +}.
|
||||||
|
// TODO(andydavis) Consider using AffineMaps to express contracting, batch
|
||||||
|
// and free dimension pairs.
|
||||||
|
def VectorContractionOp :
|
||||||
|
Vector_Op<"contract", [NoSideEffect]>,
|
||||||
|
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
|
||||||
|
Variadic<TupleOf<[Index]>>:$masks)>,
|
||||||
|
Results<(outs AnyVector)> {
|
||||||
|
let summary = "vector contraction operation";
|
||||||
|
let description = [{
|
||||||
|
Computes the sum of products of vector elements along contracting
|
||||||
|
dimension pairs from 2 vectors of rank M and N respectively, adds this
|
||||||
|
intermediate result to the accumulator argument of rank K, and returns a
|
||||||
|
vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims +
|
||||||
|
num_batch_dims (see dimension type descriptions below)).
|
||||||
|
|
||||||
|
Optional vector mask arguments specify the dynamic dimension sizes of
|
||||||
|
valid data within the lhs/rhs vector arguments.
|
||||||
|
|
||||||
|
Dimensions for the arguments and result type fall into three categories:
|
||||||
|
*) Contracting: contracting dimensions are present in the lhs and rhs
|
||||||
|
arguments but not in the output (or optional accumulator
|
||||||
|
argument). These are the dimensions along which the vector
|
||||||
|
contraction op computes the sum of products, and contracting
|
||||||
|
dimension pair dimension sizes must match between lhs/rhs.
|
||||||
|
*) Batch: batch dimensions are non-contracting dimensions and so are
|
||||||
|
present in the output and in the accumulator argument. The lhs
|
||||||
|
and rhs co-iterate along the batch dimension and so dimension
|
||||||
|
sizes must match across all arguments and result.
|
||||||
|
*) Free: free dimensions are non-contraction, non-batch dimensions and
|
||||||
|
are present in the output and accumulator argument. The lhs and
|
||||||
|
rhs free dimensions are unrelated to each other and do not
|
||||||
|
co-iterate.
|
||||||
|
|
||||||
|
Contracting and batch dimensions are specified as dimension pairs
|
||||||
|
of logical dimension numbers: the first in the pair represents the lhs
|
||||||
|
logical dimension number and the second in the pair represents the
|
||||||
|
associated rhs logical dimension number. A dimension pair binds together
|
||||||
|
logical dimension numbers from the lhs/rhs which co-iterate together, either
|
||||||
|
as contracting or batch dimensions.
|
||||||
|
|
||||||
|
Examples:
|
||||||
|
|
||||||
|
// 2D vector contraction with one contracting dimension (matmul).
|
||||||
|
%3 = vector.contract %0, %1, %2
|
||||||
|
{ contracting_dim_map = [[1, 0]] }
|
||||||
|
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
|
||||||
|
|
||||||
|
// 4D to 3D vector contraction with two contracting dimensions and
|
||||||
|
// one batch dimension.
|
||||||
|
%4 = vector.contract %0, %1, %2
|
||||||
|
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
|
||||||
|
// 4D vector contraction with two contracting dimensions and optional
|
||||||
|
// vector mask arguments.
|
||||||
|
%lhs_mask = vector.make_tuple %size0, %size1, %size2, %size3
|
||||||
|
: tuple<index, index, index, index>
|
||||||
|
%rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7
|
||||||
|
: tuple<index, index, index, index>
|
||||||
|
|
||||||
|
%5 = vector.contract %0, %1, %2, %lhs_mask, %rhs_mask
|
||||||
|
{ contracting_dim_map = [[0, 2], [2, 1]] }
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||||
|
}];
|
||||||
|
let extraClassDeclaration = [{
|
||||||
|
VectorType getLhsType() {
|
||||||
|
return lhs()->getType().cast<VectorType>();
|
||||||
|
}
|
||||||
|
VectorType getRhsType() {
|
||||||
|
return rhs()->getType().cast<VectorType>();
|
||||||
|
}
|
||||||
|
VectorType getAccType() {
|
||||||
|
return acc()->getType().cast<VectorType>();
|
||||||
|
}
|
||||||
|
TupleType getLHSVectorMaskType() {
|
||||||
|
if (llvm::size(masks()) != 2) return TupleType();
|
||||||
|
return getOperand(3)->getType().cast<TupleType>();
|
||||||
|
}
|
||||||
|
TupleType getRHSVectorMaskType() {
|
||||||
|
if (llvm::size(masks()) != 2) return TupleType();
|
||||||
|
return getOperand(4)->getType().cast<TupleType>();
|
||||||
|
}
|
||||||
|
VectorType getResultType() {
|
||||||
|
return getResult()->getType().cast<VectorType>();
|
||||||
|
}
|
||||||
|
static StringRef getContractingDimMapAttrName() {
|
||||||
|
return "contracting_dim_map";
|
||||||
|
}
|
||||||
|
static StringRef getBatchDimMapAttrName() {
|
||||||
|
return "batch_dim_map";
|
||||||
|
}
|
||||||
|
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
|
||||||
|
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
|
||||||
|
}];
|
||||||
|
}
|
||||||
|
|
||||||
def VectorExtractElementOp :
|
def VectorExtractElementOp :
|
||||||
Vector_Op<"extractelement", [NoSideEffect,
|
Vector_Op<"extractelement", [NoSideEffect,
|
||||||
PredOpTrait<"operand and result have same element type",
|
PredOpTrait<"operand and result have same element type",
|
||||||
|
@ -391,4 +489,21 @@ def VectorTypeCastOp :
|
||||||
}
|
}
|
||||||
}];
|
}];
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO(andydavis) Morph this operation into a VectorMaskOp.
|
||||||
|
def VectorIndexTupleOp :
|
||||||
|
Vector_Op<"make_index_tuple", [NoSideEffect]>,
|
||||||
|
Arguments<(ins Variadic<Index>:$operands)>,
|
||||||
|
Results<(outs TupleOf<[Index]>)> {
|
||||||
|
let summary = "creates a tuple of operand values";
|
||||||
|
let description = [{
|
||||||
|
Creates and returns a tuple of its operands which must be of index type.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
|
||||||
|
%1 = vector.make_index_tuple %size0, %size1, %size2
|
||||||
|
: tuple<index, index, index>
|
||||||
|
|
||||||
|
}];
|
||||||
|
}
|
||||||
#endif // VECTOR_OPS
|
#endif // VECTOR_OPS
|
||||||
|
|
|
@ -43,6 +43,185 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
|
||||||
>();
|
>();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// VectorContractionOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
static ParseResult parseVectorContractionOp(OpAsmParser &parser,
|
||||||
|
OperationState &result) {
|
||||||
|
OpAsmParser::OperandType lhsInfo;
|
||||||
|
OpAsmParser::OperandType rhsInfo;
|
||||||
|
OpAsmParser::OperandType accInfo;
|
||||||
|
SmallVector<OpAsmParser::OperandType, 2> masksInfo;
|
||||||
|
SmallVector<Type, 2> types;
|
||||||
|
Type resultVectorType;
|
||||||
|
auto loc = parser.getCurrentLocation();
|
||||||
|
if (parser.parseOperand(lhsInfo) || parser.parseComma() ||
|
||||||
|
parser.parseOperand(rhsInfo) || parser.parseComma() ||
|
||||||
|
parser.parseOperand(accInfo) ||
|
||||||
|
parser.parseTrailingOperandList(masksInfo) ||
|
||||||
|
parser.parseOptionalAttrDict(result.attributes) ||
|
||||||
|
parser.parseColonTypeList(types) ||
|
||||||
|
parser.parseKeywordType("into", resultVectorType) ||
|
||||||
|
parser.resolveOperand(lhsInfo, types[0], result.operands) ||
|
||||||
|
parser.resolveOperand(rhsInfo, types[1], result.operands) ||
|
||||||
|
parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
|
||||||
|
parser.addTypeToList(resultVectorType, result.types))
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
if (masksInfo.empty())
|
||||||
|
return success();
|
||||||
|
if (masksInfo.size() != 2)
|
||||||
|
return parser.emitError(parser.getNameLoc(),
|
||||||
|
"expected zero or exactly 2 vector mask operands");
|
||||||
|
auto indexType = parser.getBuilder().getIndexType();
|
||||||
|
auto lhsType = types[0].cast<VectorType>();
|
||||||
|
auto rhsType = types[1].cast<VectorType>();
|
||||||
|
SmallVector<Type, 2> maskTypes;
|
||||||
|
SmallVector<Type, 4> lhsMaskElementTypes(lhsType.getRank(), indexType);
|
||||||
|
maskTypes.push_back(
|
||||||
|
TupleType::get(lhsMaskElementTypes, parser.getBuilder().getContext()));
|
||||||
|
SmallVector<Type, 4> rhsMaskElementTypes(rhsType.getRank(), indexType);
|
||||||
|
maskTypes.push_back(
|
||||||
|
TupleType::get(rhsMaskElementTypes, parser.getBuilder().getContext()));
|
||||||
|
if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
|
||||||
|
return failure();
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print(OpAsmPrinter &p, VectorContractionOp op) {
|
||||||
|
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
|
||||||
|
p << ", " << *op.acc();
|
||||||
|
if (llvm::size(op.masks()) == 2) {
|
||||||
|
p << ", " << **op.masks().begin();
|
||||||
|
p << ", " << **(op.masks().begin() + 1);
|
||||||
|
}
|
||||||
|
p.printOptionalAttrDict(op.getAttrs());
|
||||||
|
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
|
||||||
|
<< op.getResultType();
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
|
||||||
|
const std::vector<std::pair<int64_t, int64_t>> &map) {
|
||||||
|
for (auto &dimPair : map) {
|
||||||
|
if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
|
||||||
|
dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
|
||||||
|
lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static bool verifyOutputShape(
|
||||||
|
VectorType lhsType, VectorType rhsType, VectorType accType,
|
||||||
|
VectorType resType,
|
||||||
|
const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
|
||||||
|
const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
|
||||||
|
DenseSet<int64_t> lhsContractingDimSet;
|
||||||
|
DenseSet<int64_t> rhsContractingDimSet;
|
||||||
|
for (auto &dimPair : contractingDimMap) {
|
||||||
|
lhsContractingDimSet.insert(dimPair.first);
|
||||||
|
rhsContractingDimSet.insert(dimPair.second);
|
||||||
|
}
|
||||||
|
DenseSet<int64_t> rhsBatchDimSet;
|
||||||
|
for (auto &dimPair : batchDimMap)
|
||||||
|
rhsBatchDimSet.insert(dimPair.second);
|
||||||
|
|
||||||
|
// Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
|
||||||
|
SmallVector<int64_t, 4> expectedResultDims;
|
||||||
|
for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
|
||||||
|
if (lhsContractingDimSet.count(i) > 0)
|
||||||
|
continue;
|
||||||
|
expectedResultDims.push_back(lhsType.getDimSize(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add free dimensions from 'rhsType' to 'expectedResultDims'.
|
||||||
|
for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
|
||||||
|
if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
|
||||||
|
continue;
|
||||||
|
expectedResultDims.push_back(rhsType.getDimSize(i));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify dimension from 'resType' against 'expectedResultDims'.
|
||||||
|
if (resType.getShape().size() != expectedResultDims.size() ||
|
||||||
|
accType.getShape().size() != expectedResultDims.size())
|
||||||
|
return false;
|
||||||
|
for (int64_t i = 0, e = resType.getRank(); i < e; ++i) {
|
||||||
|
if (resType.getDimSize(i) != expectedResultDims[i] ||
|
||||||
|
accType.getDimSize(i) != expectedResultDims[i])
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
return true;
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verify(VectorContractionOp op) {
|
||||||
|
auto lhsType = op.getLhsType();
|
||||||
|
auto rhsType = op.getRhsType();
|
||||||
|
auto accType = op.getAccType();
|
||||||
|
auto resType = op.getResultType();
|
||||||
|
auto contractingDimMap = op.getContractingDimMap();
|
||||||
|
auto batchDimMap = op.getBatchDimMap();
|
||||||
|
|
||||||
|
// Verify at least one contracting dimension pair was specified.
|
||||||
|
if (contractingDimMap.empty())
|
||||||
|
return op.emitOpError("expected at least one contracting dimension pair");
|
||||||
|
|
||||||
|
// Verify contracting dimension map was properly constructed.
|
||||||
|
if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
|
||||||
|
return op.emitOpError("invalid contracting dimension map");
|
||||||
|
|
||||||
|
// Verify batch dimension map was properly constructed.
|
||||||
|
if (!verifyDimMap(lhsType, rhsType, batchDimMap))
|
||||||
|
return op.emitOpError("invalid batch dimension map");
|
||||||
|
|
||||||
|
// Verify 'accType' and 'resType' shape.
|
||||||
|
if (!verifyOutputShape(lhsType, rhsType, accType, resType, contractingDimMap,
|
||||||
|
batchDimMap))
|
||||||
|
return op.emitOpError("invalid accumulator/result vector shape");
|
||||||
|
|
||||||
|
// Verify that either two vector masks are set or none are set.
|
||||||
|
auto lhsMaskType = op.getLHSVectorMaskType();
|
||||||
|
auto rhsMaskType = op.getRHSVectorMaskType();
|
||||||
|
if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
|
||||||
|
return op.emitOpError("invalid number of vector masks specified");
|
||||||
|
if (lhsMaskType && rhsMaskType) {
|
||||||
|
// Verify tuple element size is != rank.
|
||||||
|
if (lhsMaskType.getTypes().size() != lhsType.getShape().size() ||
|
||||||
|
rhsMaskType.getTypes().size() != rhsType.getShape().size())
|
||||||
|
return op.emitOpError("invalid number of vector mask elements");
|
||||||
|
// Verify all tuple elements are index type.
|
||||||
|
for (auto eltType : lhsMaskType.getTypes()) {
|
||||||
|
if (!eltType.isa<IndexType>())
|
||||||
|
return op.emitOpError("vector mask element must have index type");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::vector<std::pair<int64_t, int64_t>> getDimMap(Attribute attr) {
|
||||||
|
std::vector<std::pair<int64_t, int64_t>> dimMap;
|
||||||
|
auto dimPairs = attr.dyn_cast_or_null<ArrayAttr>();
|
||||||
|
if (!dimPairs)
|
||||||
|
return dimMap;
|
||||||
|
for (auto dimPairAttr : dimPairs) {
|
||||||
|
auto dimPair = dimPairAttr.cast<ArrayAttr>();
|
||||||
|
assert(dimPair.size() == 2);
|
||||||
|
auto lhsDim = dimPair.begin()->cast<IntegerAttr>().getInt();
|
||||||
|
auto rhsDim = std::prev(dimPair.end())->cast<IntegerAttr>().getInt();
|
||||||
|
dimMap.push_back({lhsDim, rhsDim});
|
||||||
|
}
|
||||||
|
return dimMap;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::pair<int64_t, int64_t>>
|
||||||
|
VectorContractionOp::getContractingDimMap() {
|
||||||
|
return getDimMap(getAttr(getContractingDimMapAttrName()));
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<std::pair<int64_t, int64_t>> VectorContractionOp::getBatchDimMap() {
|
||||||
|
return getDimMap(getAttr(getBatchDimMapAttrName()));
|
||||||
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
// VectorExtractElementOp
|
// VectorExtractElementOp
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
@ -541,6 +720,36 @@ static LogicalResult verify(VectorTypeCastOp &op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
// VectorIndexTupleOp
|
||||||
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
||||||
|
ParseResult parseVectorIndexTupleOp(OpAsmParser &parser,
|
||||||
|
OperationState &result) {
|
||||||
|
auto indexType = parser.getBuilder().getIndexType();
|
||||||
|
Type resultType;
|
||||||
|
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
|
||||||
|
return failure(
|
||||||
|
parser.parseOperandList(operandInfo) ||
|
||||||
|
parser.parseOptionalAttrDict(result.attributes) ||
|
||||||
|
parser.parseColonType(resultType) ||
|
||||||
|
parser.resolveOperands(operandInfo, indexType, result.operands) ||
|
||||||
|
parser.addTypeToList(resultType, result.types));
|
||||||
|
}
|
||||||
|
|
||||||
|
static void print(OpAsmPrinter &p, VectorIndexTupleOp &op) {
|
||||||
|
p << op.getOperationName() << ' ';
|
||||||
|
p.printOperands(op.operands());
|
||||||
|
p << " : " << op.getResult()->getType();
|
||||||
|
}
|
||||||
|
|
||||||
|
static LogicalResult verify(VectorIndexTupleOp &op) {
|
||||||
|
for (auto operand : op.getOperands())
|
||||||
|
if (!operand->getType().isa<IndexType>())
|
||||||
|
return op.emitOpError("all operands must be of index type");
|
||||||
|
return success();
|
||||||
|
}
|
||||||
|
|
||||||
namespace mlir {
|
namespace mlir {
|
||||||
|
|
||||||
#define GET_OP_CLASSES
|
#define GET_OP_CLASSES
|
||||||
|
|
|
@ -303,3 +303,71 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) {
|
||||||
// expected-error@+1 {{op expected result type to be 'vector<2x8x16xf32>'}}
|
// expected-error@+1 {{op expected result type to be 'vector<2x8x16xf32>'}}
|
||||||
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
|
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
|
%arg4 : index) {
|
||||||
|
// expected-error@+1 {{op expected at least one contracting dimension pair}}
|
||||||
|
%0 = vector.contract %arg0, %arg1, %arg2
|
||||||
|
{ batch_dim_map = [[1, 0]] }
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
|
%arg4 : index) {
|
||||||
|
// expected-error@+1 {{invalid contracting dimension map}}
|
||||||
|
%0 = vector.contract %arg0, %arg1, %arg2
|
||||||
|
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[1, 2], [2, 1]] }
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
|
%arg4 : index) {
|
||||||
|
// expected-error@+1 {{invalid batch dimension map}}
|
||||||
|
%0 = vector.contract %arg0, %arg1, %arg2
|
||||||
|
{ batch_dim_map = [[1, 2]], contracting_dim_map = [[0, 2], [2, 1]] }
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
|
%arg2: vector<8x15x88xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
|
%arg4 : index) {
|
||||||
|
// expected-error@+1 {{invalid accumulator/result vector shape}}
|
||||||
|
%0 = vector.contract %arg0, %arg1, %arg2
|
||||||
|
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x88xf32>
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
|
%arg4 : index) {
|
||||||
|
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||||
|
: tuple<index, index, index, index>
|
||||||
|
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||||
|
: tuple<index, index, index, index>
|
||||||
|
// expected-error@+1 {{expected zero or exactly 2 vector mask operands}}
|
||||||
|
%0 = vector.contract %arg0, %arg1, %arg2, %lhs_mask
|
||||||
|
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -48,3 +48,33 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
|
||||||
%1 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
|
%1 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
|
||||||
return %1: vector<2x2x16xf32>
|
return %1: vector<2x2x16xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CHECK-LABEL: contraction
|
||||||
|
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
|
||||||
|
%arg2 : vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
|
%arg4 : index) {
|
||||||
|
// Test contraction with batch and contracting dims.
|
||||||
|
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}} {batch_dim_map = {{.*}}1, 0{{.*}}, contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
%0 = vector.contract %arg0, %arg1, %arg2
|
||||||
|
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
|
||||||
|
// Test contraction with only contracting dims. In this case the lhs/rhs
|
||||||
|
// dimension of size 8 will be considered a free dim for lhs/rhs and will
|
||||||
|
// appear twice in the output.
|
||||||
|
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}} {contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||||
|
%1 = vector.contract %arg0, %arg1, %arg3
|
||||||
|
{ contracting_dim_map = [[0, 2], [2, 1]] }
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||||
|
|
||||||
|
// Test contraction with optional vector mask arguments.
|
||||||
|
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||||
|
: tuple<index, index, index, index>
|
||||||
|
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||||
|
: tuple<index, index, index, index>
|
||||||
|
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||||
|
%2 = vector.contract %arg0, %arg1, %arg3, %lhs_mask, %rhs_mask
|
||||||
|
{ contracting_dim_map = [[0, 2], [2, 1]] }
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue