Update VectorContractionOp to take iterator types and index mapping attributes compatible with linalg ops.

PiperOrigin-RevId: 282412311
This commit is contained in:
Andy Davis 2019-11-25 12:39:30 -08:00 committed by A. Unique TensorFlower
parent d60133f89b
commit 8fc44a4d13
4 changed files with 319 additions and 75 deletions

View File

@ -46,12 +46,11 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
// TODO(andydavis, ntv) Add an attribute to specify a different algebra // TODO(andydavis, ntv) Add an attribute to specify a different algebra
// with operators other than the current set: {*, +}. // with operators other than the current set: {*, +}.
// TODO(andydavis) Consider using AffineMaps to express contracting, batch
// and free dimension pairs.
def Vector_ContractionOp : def Vector_ContractionOp :
Vector_Op<"contract", [NoSideEffect]>, Vector_Op<"contract", [NoSideEffect]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc, Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
Variadic<TupleOf<[Index]>>:$masks)>, Variadic<TupleOf<[Index]>>:$masks,
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
Results<(outs AnyVector)> { Results<(outs AnyVector)> {
let summary = "vector contraction operation"; let summary = "vector contraction operation";
let description = [{ let description = [{
@ -64,39 +63,59 @@ def Vector_ContractionOp :
Optional vector mask arguments specify the dynamic dimension sizes of Optional vector mask arguments specify the dynamic dimension sizes of
valid data within the lhs/rhs vector arguments. valid data within the lhs/rhs vector arguments.
Dimensions for the arguments and result type fall into three categories: An iterator type attribute list must be specified, where each element of
*) Contracting: contracting dimensions are present in the lhs and rhs the list represents an iterator with one of the following types:
*) "reduction": reduction dimensions are present in the lhs and rhs
arguments but not in the output (or optional accumulator arguments but not in the output (or optional accumulator
argument). These are the dimensions along which the vector argument). These are the dimensions along which the vector
contraction op computes the sum of products, and contracting contraction op computes the sum of products, and
dimension pair dimension sizes must match between lhs/rhs. contracting dimension pair dimension sizes must match
*) Batch: batch dimensions are non-contracting dimensions and so are between lhs/rhs.
present in the output and in the accumulator argument. The lhs *) "parallel": Batch dimensions are iterator type "parallel", and
and rhs co-iterate along the batch dimension and so dimension are non-contracting dimensions present in the lhs, rhs and
sizes must match across all arguments and result. output. The lhs/rhs co-iterate along the batch dimensions,
*) Free: free dimensions are non-contraction, non-batch dimensions and which should be expressed in their indexing maps.
are present in the output and accumulator argument. The lhs and
rhs free dimensions are unrelated to each other and do not
co-iterate.
Contracting and batch dimensions are specified as dimension pairs Free dimensions are iterator type "parallel", and are
of logical dimension numbers: the first in the pair represents the lhs non-contraction, non-batch dimensions accessed by either the
logical dimension number and the second in the pair represents the lhs or rhs (but not both). The lhs and rhs free dimensions
associated rhs logical dimension number. A dimension pair binds together are unrelated to each other and do not co-iterate, which
logical dimension numbers from the lhs/rhs which co-iterate together, either should be expressed in their indexing maps.
as contracting or batch dimensions.
An indexing map attribute list must be specified with an entry for lhs, rhs
and acc arguments. An indexing map attribute specifies a mapping from each
iterator in the iterator type list, to each dimension of an N-D vector.
Examples: Examples:
// 2D vector contraction with one contracting dimension (matmul). // 2D vector contraction with one contracting dimension (matmul).
%3 = vector.contract %0, %1, %2 #contraction_accesses = [
{ contracting_dim_map = [[1, 0]] } (i, j, k) -> (i, k),
(i, j, k) -> (k, j),
(i, j, k) -> (i, j)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = [parallel, parallel, reduction]
}
%3 = vector.contract #contraction_trait %0, %1, %2
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32> : vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
// 4D to 3D vector contraction with two contracting dimensions and // 4D to 3D vector contraction with two contracting dimensions and
// one batch dimension. // one batch dimension.
%4 = vector.contract %0, %1, %2 #contraction_accesses = [
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] } (b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = [parallel, parallel, parallel reduction, reduction]
}
%4 = vector.contract #contraction_trait %0, %1, %2
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
// 4D vector contraction with two contracting dimensions and optional // 4D vector contraction with two contracting dimensions and optional
@ -106,8 +125,7 @@ def Vector_ContractionOp :
%rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7 %rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7
: tuple<index, index, index, index> : tuple<index, index, index, index>
%5 = vector.contract %0, %1, %2, %lhs_mask, %rhs_mask %5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
{ contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
}]; }];
let extraClassDeclaration = [{ let extraClassDeclaration = [{
@ -131,11 +149,13 @@ def Vector_ContractionOp :
VectorType getResultType() { VectorType getResultType() {
return getResult()->getType().cast<VectorType>(); return getResult()->getType().cast<VectorType>();
} }
static StringRef getContractingDimMapAttrName() { SmallVector<StringRef, 2> getTraitAttrNames();
return "contracting_dim_map"; SmallVector<AffineMap, 4> getIndexingMaps();
static StringRef getReductionIteratorTypeName() {
return "reduction";
} }
static StringRef getBatchDimMapAttrName() { static StringRef getParallelIteratorTypeName() {
return "batch_dim_map"; return "parallel";
} }
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap(); std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap(); std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();

View File

@ -27,6 +27,7 @@
#include "mlir/IR/OpImplementation.h" #include "mlir/IR/OpImplementation.h"
#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/TypeUtilities.h"
#include "mlir/Support/LLVM.h" #include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringSet.h"
using namespace mlir; using namespace mlir;
using namespace mlir::vector; using namespace mlir::vector;
@ -56,7 +57,10 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
SmallVector<Type, 2> types; SmallVector<Type, 2> types;
Type resultVectorType; Type resultVectorType;
auto loc = parser.getCurrentLocation(); auto loc = parser.getCurrentLocation();
if (parser.parseOperand(lhsInfo) || parser.parseComma() || DictionaryAttr dictAttr;
// TODO(andydavis, ntv) Unify linalg op attribute parsing.
if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
parser.parseOperand(lhsInfo) || parser.parseComma() ||
parser.parseOperand(rhsInfo) || parser.parseComma() || parser.parseOperand(rhsInfo) || parser.parseComma() ||
parser.parseOperand(accInfo) || parser.parseOperand(accInfo) ||
parser.parseTrailingOperandList(masksInfo) || parser.parseTrailingOperandList(masksInfo) ||
@ -68,7 +72,8 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
parser.resolveOperand(accInfo, resultVectorType, result.operands) || parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
parser.addTypeToList(resultVectorType, result.types)) parser.addTypeToList(resultVectorType, result.types))
return failure(); return failure();
result.attributes.assign(dictAttr.getValue().begin(),
dictAttr.getValue().end());
if (masksInfo.empty()) if (masksInfo.empty())
return success(); return success();
if (masksInfo.size() != 2) if (masksInfo.size() != 2)
@ -90,13 +95,23 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
} }
static void print(OpAsmPrinter &p, ContractionOp op) { static void print(OpAsmPrinter &p, ContractionOp op) {
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs(); // TODO(andydavis, ntv) Unify printing code with linalg ops.
p << ", " << *op.acc(); auto attrNames = op.getTraitAttrNames();
llvm::StringSet<> traitAttrsSet;
traitAttrsSet.insert(attrNames.begin(), attrNames.end());
SmallVector<NamedAttribute, 8> attrs;
for (auto attr : op.getAttrs()) {
if (traitAttrsSet.count(attr.first.strref()) > 0)
attrs.push_back(attr);
}
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", ";
p << *op.rhs() << ", " << *op.acc();
if (llvm::size(op.masks()) == 2) { if (llvm::size(op.masks()) == 2) {
p << ", " << **op.masks().begin(); p << ", " << **op.masks().begin();
p << ", " << **(op.masks().begin() + 1); p << ", " << **(op.masks().begin() + 1);
} }
p.printOptionalAttrDict(op.getAttrs()); p.printOptionalAttrDict(op.getAttrs(), attrNames);
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into " p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
<< op.getResultType(); << op.getResultType();
} }
@ -159,6 +174,34 @@ static LogicalResult verify(ContractionOp op) {
auto rhsType = op.getRhsType(); auto rhsType = op.getRhsType();
auto accType = op.getAccType(); auto accType = op.getAccType();
auto resType = op.getResultType(); auto resType = op.getResultType();
// Verify that an indexing map was specified for each vector operand.
if (op.indexing_maps().size() != 3)
return op.emitOpError("expected an indexing map for each vector operand");
// Verify that each index map has 'numIterators' inputs, no symbols, and
// that the number of map outputs equals the rank of its associated
// vector operand.
unsigned numIterators = op.iterator_types().getValue().size();
for (auto it : llvm::enumerate(op.indexing_maps())) {
auto index = it.index();
auto map = it.value().cast<AffineMapAttr>().getValue();
if (map.getNumSymbols() != 0)
return op.emitOpError("expected indexing map ")
<< index << " to have no symbols";
if (map.getNumDims() != numIterators)
return op.emitOpError("expected indexing map ")
<< index << " to have " << numIterators << " number of inputs";
auto operandType = op.getOperand(index)->getType().cast<VectorType>();
unsigned rank = operandType.getShape().size();
if (map.getNumResults() != rank)
return op.emitOpError("expected indexing map ")
<< index << " to have " << rank << " number of outputs";
if (!map.isProjectedPermutation())
return op.emitOpError("expected indexing map ")
<< index << " to be a projected permutation of its inputs";
}
auto contractingDimMap = op.getContractingDimMap(); auto contractingDimMap = op.getContractingDimMap();
auto batchDimMap = op.getBatchDimMap(); auto batchDimMap = op.getBatchDimMap();
@ -198,27 +241,54 @@ static LogicalResult verify(ContractionOp op) {
return success(); return success();
} }
static std::vector<std::pair<int64_t, int64_t>> getDimMap(Attribute attr) { SmallVector<StringRef, 2> ContractionOp::getTraitAttrNames() {
return SmallVector<StringRef, 2>{"indexing_maps", "iterator_types"};
}
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
if (targetExpr == map.getResult(i))
return i;
return -1;
}
static std::vector<std::pair<int64_t, int64_t>>
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
StringRef targetIteratorTypeName, MLIRContext *context) {
std::vector<std::pair<int64_t, int64_t>> dimMap; std::vector<std::pair<int64_t, int64_t>> dimMap;
auto dimPairs = attr.dyn_cast_or_null<ArrayAttr>(); for (auto it : llvm::enumerate(iteratorTypes)) {
if (!dimPairs) auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
return dimMap; if (iteratorTypeName != targetIteratorTypeName)
for (auto dimPairAttr : dimPairs) { continue;
auto dimPair = dimPairAttr.cast<ArrayAttr>(); // Search lhs/rhs map results for 'targetExpr'.
assert(dimPair.size() == 2); auto targetExpr = getAffineDimExpr(it.index(), context);
auto lhsDim = dimPair.begin()->cast<IntegerAttr>().getInt(); int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
auto rhsDim = std::prev(dimPair.end())->cast<IntegerAttr>().getInt(); int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
if (lhsDim >= 0 && rhsDim >= 0)
dimMap.push_back({lhsDim, rhsDim}); dimMap.push_back({lhsDim, rhsDim});
} }
return dimMap; return dimMap;
} }
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() { std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
return getDimMap(getAttr(getContractingDimMapAttrName())); SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
return getDimMap(indexingMaps, iterator_types(),
getReductionIteratorTypeName(), getContext());
} }
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() { std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
return getDimMap(getAttr(getBatchDimMapAttrName())); SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
return getDimMap(indexingMaps, iterator_types(),
getParallelIteratorTypeName(), getContext());
}
SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
SmallVector<AffineMap, 4> res;
auto mapAttrs = indexing_maps().getValue();
res.reserve(mapAttrs.size());
for (auto mapAttr : mapAttrs)
res.push_back(mapAttr.cast<AffineMapAttr>().getValue());
return res;
} }
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//

View File

@ -341,56 +341,196 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) {
// ----- // -----
#contraction_accesses = [
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
(b0, f0, f1, c0, c1) -> (b0, f0, f1),
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
}
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) {
// expected-error@+1 {{expected an indexing map for each vector operand}}
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return
}
// -----
#contraction_accesses = [
(b0, f0, f1, c0, c1) -> (c0, c0, c1, f0),
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
}
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) {
// expected-error@+1 {{expected indexing map 0 to be a projected permutation of its inputs}}
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return
}
// -----
#contraction_accesses = [
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
(b0, f0, f1, c0, c1)[s0] -> (b0, s0, c0, f1),
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
}
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) {
// expected-error@+1 {{op expected indexing map 1 to have no symbols}}
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return
}
// -----
#contraction_accesses = [
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
(b0, f0, f1, c1) -> (b0, f0, f1)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
}
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) {
// expected-error@+1 {{expected indexing map 2 to have 5 number of inputs}}
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return
}
// -----
#contraction_accesses = [
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
(b0, f0, f1, c0, c1) -> (b0, c1, f1),
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
}
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) {
// expected-error@+1 {{expected indexing map 1 to have 4 number of outputs}}
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return
}
// -----
#contraction_accesses = [
(b0, f0, f1, b1, b2) -> (b1, b0, b2, f0),
(b0, f0, f1, b1, b2) -> (b0, b2, b1, f1),
(b0, f0, f1, b1, b2) -> (b0, f0, f1)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "parallel", "parallel", "parallel"]
}
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) { %arg4 : index) {
// expected-error@+1 {{op expected at least one contracting dimension pair}} // expected-error@+1 {{op expected at least one contracting dimension pair}}
%0 = vector.contract %arg0, %arg1, %arg2 %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
{ batch_dim_map = [[1, 0]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return return
} }
// ----- // -----
#contraction_accesses = [
(b0, f0, f1, c0, c1) -> (c1, b0, c0, f0),
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
}
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) { %arg4 : index) {
// expected-error@+1 {{invalid contracting dimension map}} // expected-error@+1 {{invalid contracting dimension map}}
%0 = vector.contract %arg0, %arg1, %arg2 %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[1, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return return
} }
// ----- // -----
#contraction_accesses = [
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
(b0, f0, f1, c0, c1) -> (f1, c1, c0, b0),
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
}
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) { %arg4 : index) {
// expected-error@+1 {{invalid batch dimension map}} // expected-error@+1 {{invalid batch dimension map}}
%0 = vector.contract %arg0, %arg1, %arg2 %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
{ batch_dim_map = [[1, 2]], contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return return
} }
// ----- // -----
#contraction_accesses = [
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
}
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x88xf32>, %arg3 : vector<8x15x8x5xf32>, %arg2: vector<88x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) { %arg4 : index) {
// expected-error@+1 {{invalid accumulator/result vector shape}} // expected-error@+1 {{invalid accumulator/result vector shape}}
%0 = vector.contract %arg0, %arg1, %arg2 %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] } : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<88x15x5xf32>
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x88xf32>
return return
} }
// ----- // -----
#contraction_accesses = [
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
}
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>, func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, %arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) { %arg4 : index) {
@ -399,8 +539,7 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4 %rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index> : tuple<index, index, index, index>
// expected-error@+1 {{expected zero or exactly 2 vector mask operands}} // expected-error@+1 {{expected zero or exactly 2 vector mask operands}}
%0 = vector.contract %arg0, %arg1, %arg2, %lhs_mask %0 = vector.contract #contraction_trait %arg0, %arg1, %arg2, %lhs_mask
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return return
} }

View File

@ -60,32 +60,47 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
return %1: vector<2x2x16xf32> return %1: vector<2x2x16xf32>
} }
#contraction_accesses0 = [
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
]
#contraction_trait0 = {
indexing_maps = #contraction_accesses0,
iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]
}
#contraction_accesses1 = [
(f0, f1, f2, f3, c0, c1) -> (c0, f0, c1, f2),
(f0, f1, f2, f3, c0, c1) -> (f1, c1, c0, f3),
(f0, f1, f2, f3, c0, c1) -> (f0, f1, f2, f3)
]
#contraction_trait1 = {
indexing_maps = #contraction_accesses1,
iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction",
"reduction"]
}
// CHECK-LABEL: contraction // CHECK-LABEL: contraction
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>, func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
%arg2 : vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>, %arg2 : vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) { %arg4 : index) {
// Test contraction with batch and contracting dims. // Test contraction with batch and contracting dims.
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}} {batch_dim_map = {{.*}}1, 0{{.*}}, contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
%0 = vector.contract %arg0, %arg1, %arg2 %0 = vector.contract #contraction_trait0 %arg0, %arg1, %arg2
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
// Test contraction with only contracting dims. In this case the lhs/rhs // Test contraction with only contracting dims. In this case the lhs/rhs
// dimension of size 8 will be considered a free dim for lhs/rhs and will // dimension of size 8 will be considered a parallel dim for lhs/rhs and will
// appear twice in the output. // appear twice in the output.
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}} {contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
%1 = vector.contract %arg0, %arg1, %arg3 %1 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3
{ contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
// Test contraction with optional vector mask arguments. // Test contraction with optional vector mask arguments.
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4 %lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index> : tuple<index, index, index, index>
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4 %rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index> : tuple<index, index, index, index>
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> // CHECK: vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["parallel", "parallel", "parallel", "parallel", "reduction", "reduction"]} {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
%2 = vector.contract %arg0, %arg1, %arg3, %lhs_mask, %rhs_mask %2 = vector.contract #contraction_trait1 %arg0, %arg1, %arg3, %lhs_mask,
{ contracting_dim_map = [[0, 2], [2, 1]] } %rhs_mask
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32> : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
return return
} }