forked from OSchip/llvm-project
Update VectorContractionOp to take iterator types and index mapping attributes compatible with linalg ops.
PiperOrigin-RevId: 282412311
This commit is contained in:
parent
d60133f89b
commit
8fc44a4d13
|
@ -46,12 +46,11 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||||
|
|
||||||
// TODO(andydavis, ntv) Add an attribute to specify a different algebra
|
// TODO(andydavis, ntv) Add an attribute to specify a different algebra
|
||||||
// with operators other than the current set: {*, +}.
|
// with operators other than the current set: {*, +}.
|
||||||
// TODO(andydavis) Consider using AffineMaps to express contracting, batch
|
|
||||||
// and free dimension pairs.
|
|
||||||
def Vector_ContractionOp :
|
def Vector_ContractionOp :
|
||||||
Vector_Op<"contract", [NoSideEffect]>,
|
Vector_Op<"contract", [NoSideEffect]>,
|
||||||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
|
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
|
||||||
Variadic<TupleOf<[Index]>>:$masks)>,
|
Variadic<TupleOf<[Index]>>:$masks,
|
||||||
|
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
|
||||||
Results<(outs AnyVector)> {
|
Results<(outs AnyVector)> {
|
||||||
let summary = "vector contraction operation";
|
let summary = "vector contraction operation";
|
||||||
let description = [{
|
let description = [{
|
||||||
|
@ -64,39 +63,59 @@ def Vector_ContractionOp :
|
||||||
Optional vector mask arguments specify the dynamic dimension sizes of
|
Optional vector mask arguments specify the dynamic dimension sizes of
|
||||||
valid data within the lhs/rhs vector arguments.
|
valid data within the lhs/rhs vector arguments.
|
||||||
|
|
||||||
Dimensions for the arguments and result type fall into three categories:
|
An iterator type attribute list must be specified, where each element of
|
||||||
*) Contracting: contracting dimensions are present in the lhs and rhs
|
the list represents an iterator with one of the following types:
|
||||||
|
|
||||||
|
*) "reduction": reduction dimensions are present in the lhs and rhs
|
||||||
arguments but not in the output (or optional accumulator
|
arguments but not in the output (or optional accumulator
|
||||||
argument). These are the dimensions along which the vector
|
argument). These are the dimensions along which the vector
|
||||||
contraction op computes the sum of products, and contracting
|
contraction op computes the sum of products, and
|
||||||
dimension pair dimension sizes must match between lhs/rhs.
|
contracting dimension pair dimension sizes must match
|
||||||
*) Batch: batch dimensions are non-contracting dimensions and so are
|
between lhs/rhs.
|
||||||
present in the output and in the accumulator argument. The lhs
|
*) "parallel": Batch dimensions are iterator type "parallel", and
|
||||||
and rhs co-iterate along the batch dimension and so dimension
|
are non-contracting dimensions present in the lhs, rhs and
|
||||||
sizes must match across all arguments and result.
|
output. The lhs/rhs co-iterate along the batch dimensions,
|
||||||
*) Free: free dimensions are non-contraction, non-batch dimensions and
|
which should be expressed in their indexing maps.
|
||||||
are present in the output and accumulator argument. The lhs and
|
|
||||||
rhs free dimensions are unrelated to each other and do not
|
|
||||||
co-iterate.
|
|
||||||
|
|
||||||
Contracting and batch dimensions are specified as dimension pairs
|
Free dimensions are iterator type "parallel", and are
|
||||||
of logical dimension numbers: the first in the pair represents the lhs
|
non-contraction, non-batch dimensions accessed by either the
|
||||||
logical dimension number and the second in the pair represents the
|
lhs or rhs (but not both). The lhs and rhs free dimensions
|
||||||
associated rhs logical dimension number. A dimension pair binds together
|
are unrelated to each other and do not co-iterate, which
|
||||||
logical dimension numbers from the lhs/rhs which co-iterate together, either
|
should be expressed in their indexing maps.
|
||||||
as contracting or batch dimensions.
|
|
||||||
|
An indexing map attribute list must be specified with an entry for lhs, rhs
|
||||||
|
and acc arguments. An indexing map attribute specifies a mapping from each
|
||||||
|
iterator in the iterator type list, to each dimension of an N-D vector.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
|
|
||||||
// 2D vector contraction with one contracting dimension (matmul).
|
// 2D vector contraction with one contracting dimension (matmul).
|
||||||
%3 = vector.contract %0, %1, %2
|
#contraction_accesses = [
|
||||||
{ contracting_dim_map = [[1, 0]] }
|
(i, j, k) -> (i, k),
|
||||||
|
(i, j, k) -> (k, j),
|
||||||
|
(i, j, k) -> (i, j)
|
||||||
|
]
|
||||||
|
#contraction_trait = {
|
||||||
|
indexing_maps = #contraction_accesses,
|
||||||
|
iterator_types = [parallel, parallel, reduction]
|
||||||
|
}
|
||||||
|
|
||||||
|
%3 = vector.contract #contraction_trait %0, %1, %2
|
||||||
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
|
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
|
||||||
|
|
||||||
// 4D to 3D vector contraction with two contracting dimensions and
|
// 4D to 3D vector contraction with two contracting dimensions and
|
||||||
// one batch dimension.
|
// one batch dimension.
|
||||||
%4 = vector.contract %0, %1, %2
|
#contraction_accesses = [
|
||||||
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
|
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
|
||||||
|
]
|
||||||
|
#contraction_trait = {
|
||||||
|
indexing_maps = #contraction_accesses,
|
||||||
|
iterator_types = [parallel, parallel, parallel reduction, reduction]
|
||||||
|
}
|
||||||
|
|
||||||
|
%4 = vector.contract #contraction_trait %0, %1, %2
|
||||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
|
||||||
// 4D vector contraction with two contracting dimensions and optional
|
// 4D vector contraction with two contracting dimensions and optional
|
||||||
|
@ -106,8 +125,7 @@ def Vector_ContractionOp :
|
||||||
%rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7
|
%rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7
|
||||||
: tuple<index, index, index, index>
|
: tuple<index, index, index, index>
|
||||||
|
|
||||||
%5 = vector.contract %0, %1, %2, %lhs_mask, %rhs_mask
|
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
|
||||||
{ contracting_dim_map = [[0, 2], [2, 1]] }
|
|
||||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||||
}];
|
}];
|
||||||
let extraClassDeclaration = [{
|
let extraClassDeclaration = [{
|
||||||
|
@ -131,11 +149,13 @@ def Vector_ContractionOp :
|
||||||
VectorType getResultType() {
|
VectorType getResultType() {
|
||||||
return getResult()->getType().cast<VectorType>();
|
return getResult()->getType().cast<VectorType>();
|
||||||
}
|
}
|
||||||
static StringRef getContractingDimMapAttrName() {
|
SmallVector<StringRef, 2> getTraitAttrNames();
|
||||||
return "contracting_dim_map";
|
SmallVector<AffineMap, 4> getIndexingMaps();
|
||||||
|
static StringRef getReductionIteratorTypeName() {
|
||||||
|
return "reduction";
|
||||||
}
|
}
|
||||||
static StringRef getBatchDimMapAttrName() {
|
static StringRef getParallelIteratorTypeName() {
|
||||||
return "batch_dim_map";
|
return "parallel";
|
||||||
}
|
}
|
||||||
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
|
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
|
||||||
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
|
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
|
||||||
|
|
|
@ -27,6 +27,7 @@
|
||||||
#include "mlir/IR/OpImplementation.h"
|
#include "mlir/IR/OpImplementation.h"
|
||||||
#include "mlir/IR/TypeUtilities.h"
|
#include "mlir/IR/TypeUtilities.h"
|
||||||
#include "mlir/Support/LLVM.h"
|
#include "mlir/Support/LLVM.h"
|
||||||
|
#include "llvm/ADT/StringSet.h"
|
||||||
|
|
||||||
using namespace mlir;
|
using namespace mlir;
|
||||||
using namespace mlir::vector;
|
using namespace mlir::vector;
|
||||||
|
@ -56,7 +57,10 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
|
||||||
SmallVector<Type, 2> types;
|
SmallVector<Type, 2> types;
|
||||||
Type resultVectorType;
|
Type resultVectorType;
|
||||||
auto loc = parser.getCurrentLocation();
|
auto loc = parser.getCurrentLocation();
|
||||||
if (parser.parseOperand(lhsInfo) || parser.parseComma() ||
|
DictionaryAttr dictAttr;
|
||||||
|
// TODO(andydavis, ntv) Unify linalg op attribute parsing.
|
||||||
|
if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
|
||||||
|
parser.parseOperand(lhsInfo) || parser.parseComma() ||
|
||||||
parser.parseOperand(rhsInfo) || parser.parseComma() ||
|
parser.parseOperand(rhsInfo) || parser.parseComma() ||
|
||||||
parser.parseOperand(accInfo) ||
|
parser.parseOperand(accInfo) ||
|
||||||
parser.parseTrailingOperandList(masksInfo) ||
|
parser.parseTrailingOperandList(masksInfo) ||
|
||||||
|
@ -68,7 +72,8 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
|
||||||
parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
|
parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
|
||||||
parser.addTypeToList(resultVectorType, result.types))
|
parser.addTypeToList(resultVectorType, result.types))
|
||||||
return failure();
|
return failure();
|
||||||
|
result.attributes.assign(dictAttr.getValue().begin(),
|
||||||
|
dictAttr.getValue().end());
|
||||||
if (masksInfo.empty())
|
if (masksInfo.empty())
|
||||||
return success();
|
return success();
|
||||||
if (masksInfo.size() != 2)
|
if (masksInfo.size() != 2)
|
||||||
|
@ -90,13 +95,23 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
|
||||||
}
|
}
|
||||||
|
|
||||||
static void print(OpAsmPrinter &p, ContractionOp op) {
|
static void print(OpAsmPrinter &p, ContractionOp op) {
|
||||||
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
|
// TODO(andydavis, ntv) Unify printing code with linalg ops.
|
||||||
p << ", " << *op.acc();
|
auto attrNames = op.getTraitAttrNames();
|
||||||
|
llvm::StringSet<> traitAttrsSet;
|
||||||
|
traitAttrsSet.insert(attrNames.begin(), attrNames.end());
|
||||||
|
SmallVector<NamedAttribute, 8> attrs;
|
||||||
|
for (auto attr : op.getAttrs()) {
|
||||||
|
if (traitAttrsSet.count(attr.first.strref()) > 0)
|
||||||
|
attrs.push_back(attr);
|
||||||
|
}
|
||||||
|
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
|
||||||
|
p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", ";
|
||||||
|
p << *op.rhs() << ", " << *op.acc();
|
||||||
if (llvm::size(op.masks()) == 2) {
|
if (llvm::size(op.masks()) == 2) {
|
||||||
p << ", " << **op.masks().begin();
|
p << ", " << **op.masks().begin();
|
||||||
p << ", " << **(op.masks().begin() + 1);
|
p << ", " << **(op.masks().begin() + 1);
|
||||||
}
|
}
|
||||||
p.printOptionalAttrDict(op.getAttrs());
|
p.printOptionalAttrDict(op.getAttrs(), attrNames);
|
||||||
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
|
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
|
||||||
<< op.getResultType();
|
<< op.getResultType();
|
||||||
}
|
}
|
||||||
|
@ -159,6 +174,34 @@ static LogicalResult verify(ContractionOp op) {
|
||||||
auto rhsType = op.getRhsType();
|
auto rhsType = op.getRhsType();
|
||||||
auto accType = op.getAccType();
|
auto accType = op.getAccType();
|
||||||
auto resType = op.getResultType();
|
auto resType = op.getResultType();
|
||||||
|
|
||||||
|
// Verify that an indexing map was specified for each vector operand.
|
||||||
|
if (op.indexing_maps().size() != 3)
|
||||||
|
return op.emitOpError("expected an indexing map for each vector operand");
|
||||||
|
|
||||||
|
// Verify that each index map has 'numIterators' inputs, no symbols, and
|
||||||
|
// that the number of map outputs equals the rank of its associated
|
||||||
|
// vector operand.
|
||||||
|
unsigned numIterators = op.iterator_types().getValue().size();
|
||||||
|
for (auto it : llvm::enumerate(op.indexing_maps())) {
|
||||||
|
auto index = it.index();
|
||||||
|
auto map = it.value().cast<AffineMapAttr>().getValue();
|
||||||
|
if (map.getNumSymbols() != 0)
|
||||||
|
return op.emitOpError("expected indexing map ")
|
||||||
|
<< index << " to have no symbols";
|
||||||
|
if (map.getNumDims() != numIterators)
|
||||||
|
return op.emitOpError("expected indexing map ")
|
||||||
|
<< index << " to have " << numIterators << " number of inputs";
|
||||||
|
auto operandType = op.getOperand(index)->getType().cast<VectorType>();
|
||||||
|
unsigned rank = operandType.getShape().size();
|
||||||
|
if (map.getNumResults() != rank)
|
||||||
|
return op.emitOpError("expected indexing map ")
|
||||||
|
<< index << " to have " << rank << " number of outputs";
|
||||||
|
if (!map.isProjectedPermutation())
|
||||||
|
return op.emitOpError("expected indexing map ")
|
||||||
|
<< index << " to be a projected permutation of its inputs";
|
||||||
|
}
|
||||||
|
|
||||||
auto contractingDimMap = op.getContractingDimMap();
|
auto contractingDimMap = op.getContractingDimMap();
|
||||||
auto batchDimMap = op.getBatchDimMap();
|
auto batchDimMap = op.getBatchDimMap();
|
||||||
|
|
||||||
|
@ -198,27 +241,54 @@ static LogicalResult verify(ContractionOp op) {
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
static std::vector<std::pair<int64_t, int64_t>> getDimMap(Attribute attr) {
|
SmallVector<StringRef, 2> ContractionOp::getTraitAttrNames() {
|
||||||
|
return SmallVector<StringRef, 2>{"indexing_maps", "iterator_types"};
|
||||||
|
}
|
||||||
|
|
||||||
|
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
|
||||||
|
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
|
||||||
|
if (targetExpr == map.getResult(i))
|
||||||
|
return i;
|
||||||
|
return -1;
|
||||||
|
}
|
||||||
|
|
||||||
|
static std::vector<std::pair<int64_t, int64_t>>
|
||||||
|
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
|
||||||
|
StringRef targetIteratorTypeName, MLIRContext *context) {
|
||||||
std::vector<std::pair<int64_t, int64_t>> dimMap;
|
std::vector<std::pair<int64_t, int64_t>> dimMap;
|
||||||
auto dimPairs = attr.dyn_cast_or_null<ArrayAttr>();
|
for (auto it : llvm::enumerate(iteratorTypes)) {
|
||||||
if (!dimPairs)
|
auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
|
||||||
return dimMap;
|
if (iteratorTypeName != targetIteratorTypeName)
|
||||||
for (auto dimPairAttr : dimPairs) {
|
continue;
|
||||||
auto dimPair = dimPairAttr.cast<ArrayAttr>();
|
// Search lhs/rhs map results for 'targetExpr'.
|
||||||
assert(dimPair.size() == 2);
|
auto targetExpr = getAffineDimExpr(it.index(), context);
|
||||||
auto lhsDim = dimPair.begin()->cast<IntegerAttr>().getInt();
|
int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
|
||||||
auto rhsDim = std::prev(dimPair.end())->cast<IntegerAttr>().getInt();
|
int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
|
||||||
|
if (lhsDim >= 0 && rhsDim >= 0)
|
||||||
dimMap.push_back({lhsDim, rhsDim});
|
dimMap.push_back({lhsDim, rhsDim});
|
||||||
}
|
}
|
||||||
return dimMap;
|
return dimMap;
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
|
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
|
||||||
return getDimMap(getAttr(getContractingDimMapAttrName()));
|
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
|
||||||
|
return getDimMap(indexingMaps, iterator_types(),
|
||||||
|
getReductionIteratorTypeName(), getContext());
|
||||||
}
|
}
|
||||||
|
|
||||||
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
|
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
|
||||||
return getDimMap(getAttr(getBatchDimMapAttrName()));
|
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
|
||||||
|
return getDimMap(indexingMaps, iterator_types(),
|
||||||
|
getParallelIteratorTypeName(), getContext());
|
||||||
|
}
|
||||||
|
|
||||||
|
SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
|
||||||
|
SmallVector<AffineMap, 4> res;
|
||||||
|
auto mapAttrs = indexing_maps().getValue();
|
||||||
|
res.reserve(mapAttrs.size());
|
||||||
|
for (auto mapAttr : mapAttrs)
|
||||||
|
res.push_back(mapAttr.cast<AffineMapAttr>().getValue());
|
||||||
|
return res;
|
||||||
}
|
}
|
||||||
|
|
||||||
//===----------------------------------------------------------------------===//
|
//===----------------------------------------------------------------------===//
|
||||||
|
|
|
@ -341,56 +341,196 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) {
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
#contraction_accesses = [
|
||||||
|
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, f0, f1),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
|
||||||
|
]
|
||||||
|
#contraction_trait = {
|
||||||
|
indexing_maps = #contraction_accesses,
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
|
||||||
|
}
|
||||||
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
|
%arg4 : index) {
|
||||||
|
// expected-error@+1 {{expected an indexing map for each vector operand}}
|
||||||
|
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
#contraction_accesses = [
|
||||||
|
(b0, f0, f1, c0, c1) -> (c0, c0, c1, f0),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
|
||||||
|
]
|
||||||
|
#contraction_trait = {
|
||||||
|
indexing_maps = #contraction_accesses,
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
|
||||||
|
}
|
||||||
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
|
%arg4 : index) {
|
||||||
|
// expected-error@+1 {{expected indexing map 0 to be a projected permutation of its inputs}}
|
||||||
|
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
#contraction_accesses = [
|
||||||
|
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
|
||||||
|
(b0, f0, f1, c0, c1)[s0] -> (b0, s0, c0, f1),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
|
||||||
|
]
|
||||||
|
#contraction_trait = {
|
||||||
|
indexing_maps = #contraction_accesses,
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
|
||||||
|
}
|
||||||
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
|
%arg4 : index) {
|
||||||
|
// expected-error@+1 {{op expected indexing map 1 to have no symbols}}
|
||||||
|
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
#contraction_accesses = [
|
||||||
|
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
|
||||||
|
(b0, f0, f1, c1) -> (b0, f0, f1)
|
||||||
|
]
|
||||||
|
#contraction_trait = {
|
||||||
|
indexing_maps = #contraction_accesses,
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
|
||||||
|
}
|
||||||
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
|
%arg4 : index) {
|
||||||
|
// expected-error@+1 {{expected indexing map 2 to have 5 number of inputs}}
|
||||||
|
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
#contraction_accesses = [
|
||||||
|
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, c1, f1),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
|
||||||
|
]
|
||||||
|
#contraction_trait = {
|
||||||
|
indexing_maps = #contraction_accesses,
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
|
||||||
|
}
|
||||||
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
|
%arg4 : index) {
|
||||||
|
// expected-error@+1 {{expected indexing map 1 to have 4 number of outputs}}
|
||||||
|
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
|
||||||
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
#contraction_accesses = [
|
||||||
|
(b0, f0, f1, b1, b2) -> (b1, b0, b2, f0),
|
||||||
|
(b0, f0, f1, b1, b2) -> (b0, b2, b1, f1),
|
||||||
|
(b0, f0, f1, b1, b2) -> (b0, f0, f1)
|
||||||
|
]
|
||||||
|
#contraction_trait = {
|
||||||
|
indexing_maps = #contraction_accesses,
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
|
||||||
|
}
|
||||||
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
%arg4 : index) {
|
%arg4 : index) {
|
||||||
// expected-error@+1 {{op expected at least one contracting dimension pair}}
|
// expected-error@+1 {{op expected at least one contracting dimension pair}}
|
||||||
%0 = vector.contract %arg0, %arg1, %arg2
|
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
|
||||||
{ batch_dim_map = [[1, 0]] }
|
|
||||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
#contraction_accesses = [
|
||||||
|
(b0, f0, f1, c0, c1) -> (c1, b0, c0, f0),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
|
||||||
|
]
|
||||||
|
#contraction_trait = {
|
||||||
|
indexing_maps = #contraction_accesses,
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
|
||||||
|
}
|
||||||
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
%arg4 : index) {
|
%arg4 : index) {
|
||||||
// expected-error@+1 {{invalid contracting dimension map}}
|
// expected-error@+1 {{invalid contracting dimension map}}
|
||||||
%0 = vector.contract %arg0, %arg1, %arg2
|
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
|
||||||
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[1, 2], [2, 1]] }
|
|
||||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
#contraction_accesses = [
|
||||||
|
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
|
||||||
|
(b0, f0, f1, c0, c1) -> (f1, c1, c0, b0),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
|
||||||
|
]
|
||||||
|
#contraction_trait = {
|
||||||
|
indexing_maps = #contraction_accesses,
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
|
||||||
|
}
|
||||||
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
%arg4 : index) {
|
%arg4 : index) {
|
||||||
// expected-error@+1 {{invalid batch dimension map}}
|
// expected-error@+1 {{invalid batch dimension map}}
|
||||||
%0 = vector.contract %arg0, %arg1, %arg2
|
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
|
||||||
{ batch_dim_map = [[1, 2]], contracting_dim_map = [[0, 2], [2, 1]] }
|
|
||||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
#contraction_accesses = [
|
||||||
|
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
|
||||||
|
]
|
||||||
|
#contraction_trait = {
|
||||||
|
indexing_maps = #contraction_accesses,
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
|
||||||
|
}
|
||||||
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
%arg2: vector<8x15x88xf32>, %arg3 : vector<8x15x8x5xf32>,
|
%arg2: vector<88x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
%arg4 : index) {
|
%arg4 : index) {
|
||||||
// expected-error@+1 {{invalid accumulator/result vector shape}}
|
// expected-error@+1 {{invalid accumulator/result vector shape}}
|
||||||
%0 = vector.contract %arg0, %arg1, %arg2
|
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
|
||||||
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<88x15x5xf32>
|
||||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x88xf32>
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
|
|
||||||
|
#contraction_accesses = [
|
||||||
|
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
|
||||||
|
]
|
||||||
|
#contraction_trait = {
|
||||||
|
indexing_maps = #contraction_accesses,
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
|
||||||
|
}
|
||||||
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
%arg4 : index) {
|
%arg4 : index) {
|
||||||
|
@ -399,8 +539,7 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
||||||
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||||
: tuple<index, index, index, index>
|
: tuple<index, index, index, index>
|
||||||
// expected-error@+1 {{expected zero or exactly 2 vector mask operands}}
|
// expected-error@+1 {{expected zero or exactly 2 vector mask operands}}
|
||||||
%0 = vector.contract %arg0, %arg1, %arg2, %lhs_mask
|
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2, %lhs_mask
|
||||||
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
|
|
||||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,32 +60,47 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
|
||||||
return %1: vector<2x2x16xf32>
|
return %1: vector<2x2x16xf32>
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#contraction_accesses0 = [
|
||||||
|
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
|
||||||
|
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
|
||||||
|
]
|
||||||
|
#contraction_trait0 = {
|
||||||
|
indexing_maps = #contraction_accesses0,
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
|
||||||
|
}
|
||||||
|
#contraction_accesses1 = [
|
||||||
|
(f0, f1, f2, f3, c0, c1) -> (c0, f0, c1, f2),
|
||||||
|
(f0, f1, f2, f3, c0, c1) -> (f1, c1, c0, f3),
|
||||||
|
(f0, f1, f2, f3, c0, c1) -> (f0, f1, f2, f3)
|
||||||
|
]
|
||||||
|
#contraction_trait1 = {
|
||||||
|
indexing_maps = #contraction_accesses1,
|
||||||
|
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction",
|
||||||
|
"reduction"]
|
||||||
|
}
|
||||||
// CHECK-LABEL: contraction
|
// CHECK-LABEL: contraction
|
||||||
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
|
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
|
||||||
%arg2 : vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
%arg2 : vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
|
||||||
%arg4 : index) {
|
%arg4 : index) {
|
||||||
// Test contraction with batch and contracting dims.
|
// Test contraction with batch and contracting dims.
|
||||||
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}} {batch_dim_map = {{.*}}1, 0{{.*}}, contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
%0 = vector.contract %arg0, %arg1, %arg2
|
%0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2
|
||||||
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
|
|
||||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
|
||||||
|
|
||||||
// Test contraction with only contracting dims. In this case the lhs/rhs
|
// Test contraction with only contracting dims. In this case the lhs/rhs
|
||||||
// dimension of size 8 will be considered a free dim for lhs/rhs and will
|
// dimension of size 8 will be considered a parallel dim for lhs/rhs and will
|
||||||
// appear twice in the output.
|
// appear twice in the output.
|
||||||
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}} {contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||||
%1 = vector.contract %arg0, %arg1, %arg3
|
%1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3
|
||||||
{ contracting_dim_map = [[0, 2], [2, 1]] }
|
|
||||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||||
|
|
||||||
// Test contraction with optional vector mask arguments.
|
// Test contraction with optional vector mask arguments.
|
||||||
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||||
: tuple<index, index, index, index>
|
: tuple<index, index, index, index>
|
||||||
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
|
||||||
: tuple<index, index, index, index>
|
: tuple<index, index, index, index>
|
||||||
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
// CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||||
%2 = vector.contract %arg0, %arg1, %arg3, %lhs_mask, %rhs_mask
|
%2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
|
||||||
{ contracting_dim_map = [[0, 2], [2, 1]] }
|
%rhs_mask
|
||||||
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue