forked from OSchip/llvm-project
[mlir] [VectorOps] generalized vector.contract semantics
Summary: Previously, vector.contract did not allow an empty set of free or batch dimensions (K = 0), which defines a basic reduction into a scalar (like a dot product). This CL relaxes that restriction. It also adds constraints on the element types of the operands and results, with tests. Reviewers: nicolasvasilache, andydavis1, rriddle Reviewed By: andydavis1 Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits Tags: #llvm Differential Revision: https://reviews.llvm.org/D74014
This commit is contained in:
parent
0c3b2986ac
commit
6e2309d7fa
|
@ -38,18 +38,25 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
|
|||
// TODO(andydavis, ntv) Add an attribute to specify a different algebra
|
||||
// with operators other than the current set: {*, +}.
|
||||
def Vector_ContractionOp :
|
||||
Vector_Op<"contract", [NoSideEffect]>,
|
||||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
|
||||
Vector_Op<"contract", [NoSideEffect,
|
||||
PredOpTrait<"first operand lhs and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
PredOpTrait<"second operand rhs and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 1>>,
|
||||
PredOpTrait<"third operand acc and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 2>>]>, // operand index 2 is $acc (0 = $lhs, 1 = $rhs)
|
||||
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
|
||||
Variadic<VectorOf<[I1]>>:$masks,
|
||||
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
|
||||
Results<(outs AnyVector)> {
|
||||
Results<(outs AnyType)> {
|
||||
let summary = "vector contraction operation";
|
||||
let description = [{
|
||||
Computes the sum of products of vector elements along contracting
|
||||
dimension pairs from 2 vectors of rank M and N respectively, adds this
|
||||
intermediate result to the accumulator argument of rank K, and returns a
|
||||
vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims +
|
||||
num_batch_dims (see dimension type descriptions below)).
|
||||
num_batch_dims (see dimension type descriptions below)). For K = 0 (no
|
||||
free or batch dimensions), the accumulator and output are a scalar.
|
||||
|
||||
Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp)
|
||||
specify the dynamic dimension sizes of valid data within the lhs/rhs vector
|
||||
|
@ -59,7 +66,7 @@ def Vector_ContractionOp :
|
|||
the list represents an iterator with one of the following types:
|
||||
|
||||
*) "reduction": reduction dimensions are present in the lhs and rhs
|
||||
arguments but not in the output (or optional accumulator
|
||||
arguments but not in the output (and accumulator
|
||||
argument). These are the dimensions along which the vector
|
||||
contraction op computes the sum of products, and
|
||||
contracting dimension pair dimension sizes must match
|
||||
|
@ -81,7 +88,20 @@ def Vector_ContractionOp :
|
|||
|
||||
Examples:
|
||||
|
||||
// 2D vector contraction with one contracting dimension (matmul).
|
||||
// Simple dot product (K = 0).
|
||||
#contraction_accesses = [
|
||||
affine_map<(i) -> (i)>,
|
||||
affine_map<(i) -> (i)>,
|
||||
affine_map<(i) -> ()>
|
||||
]
|
||||
#contraction_trait = {
|
||||
indexing_maps = #contraction_accesses,
|
||||
iterator_types = ["reduction"]
|
||||
}
|
||||
%3 = vector.contract #contraction_trait %0, %1, %2
|
||||
: vector<10xf32>, vector<10xf32> into f32
|
||||
|
||||
// 2D vector contraction with one contracting dimension (matmul, K = 2).
|
||||
#contraction_accesses = [
|
||||
affine_map<(i, j, k) -> (i, k)>,
|
||||
affine_map<(i, j, k) -> (k, j)>,
|
||||
|
@ -96,7 +116,7 @@ def Vector_ContractionOp :
|
|||
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
|
||||
|
||||
// 4D to 3D vector contraction with two contracting dimensions and
|
||||
// one batch dimension.
|
||||
// one batch dimension (K = 3).
|
||||
#contraction_accesses = [
|
||||
affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>,
|
||||
affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>,
|
||||
|
@ -105,7 +125,7 @@ def Vector_ContractionOp :
|
|||
#contraction_trait = {
|
||||
indexing_maps = #contraction_accesses,
|
||||
iterator_types = ["parallel", "parallel", "parallel",
|
||||
"reduction", "reduction"]
|
||||
"reduction", "reduction"]
|
||||
}
|
||||
|
||||
%4 = vector.contract #contraction_trait %0, %1, %2
|
||||
|
@ -129,9 +149,7 @@ def Vector_ContractionOp :
|
|||
VectorType getRhsType() {
|
||||
return rhs().getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getAccType() {
|
||||
return acc().getType().cast<VectorType>();
|
||||
}
|
||||
Type getAccType() { return acc().getType(); }
|
||||
VectorType getLHSVectorMaskType() {
|
||||
if (llvm::size(masks()) != 2) return VectorType();
|
||||
return getOperand(3).getType().cast<VectorType>();
|
||||
|
@ -140,9 +158,7 @@ def Vector_ContractionOp :
|
|||
if (llvm::size(masks()) != 2) return VectorType();
|
||||
return getOperand(4).getType().cast<VectorType>();
|
||||
}
|
||||
VectorType getResultType() {
|
||||
return getResult().getType().cast<VectorType>();
|
||||
}
|
||||
Type getResultType() { return getResult().getType(); }
|
||||
ArrayRef<StringRef> getTraitAttrNames();
|
||||
SmallVector<AffineMap, 4> getIndexingMaps();
|
||||
static unsigned getAccOperandIndex() { return 2; }
|
||||
|
|
|
@ -81,7 +81,7 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
|
|||
OpAsmParser::OperandType accInfo;
|
||||
SmallVector<OpAsmParser::OperandType, 2> masksInfo;
|
||||
SmallVector<Type, 2> types;
|
||||
Type resultVectorType;
|
||||
Type resultType;
|
||||
auto loc = parser.getCurrentLocation();
|
||||
DictionaryAttr dictAttr;
|
||||
// TODO(andydavis, ntv) Unify linalg op attribute parsing.
|
||||
|
@ -92,11 +92,11 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
|
|||
parser.parseTrailingOperandList(masksInfo) ||
|
||||
parser.parseOptionalAttrDict(result.attributes) ||
|
||||
parser.parseColonTypeList(types) ||
|
||||
parser.parseKeywordType("into", resultVectorType) ||
|
||||
parser.parseKeywordType("into", resultType) ||
|
||||
parser.resolveOperand(lhsInfo, types[0], result.operands) ||
|
||||
parser.resolveOperand(rhsInfo, types[1], result.operands) ||
|
||||
parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
|
||||
parser.addTypeToList(resultVectorType, result.types))
|
||||
parser.resolveOperand(accInfo, resultType, result.operands) ||
|
||||
parser.addTypeToList(resultType, result.types))
|
||||
return failure();
|
||||
result.attributes.assign(dictAttr.getValue().begin(),
|
||||
dictAttr.getValue().end());
|
||||
|
@ -149,8 +149,7 @@ static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
|
|||
}
|
||||
|
||||
static bool verifyOutputShape(
|
||||
VectorType lhsType, VectorType rhsType, VectorType accType,
|
||||
VectorType resType,
|
||||
VectorType lhsType, VectorType rhsType, Type accType, Type resType,
|
||||
const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
|
||||
const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
|
||||
DenseSet<int64_t> lhsContractingDimSet;
|
||||
|
@ -178,14 +177,28 @@ static bool verifyOutputShape(
|
|||
expectedResultDims.push_back(rhsType.getDimSize(i));
|
||||
}
|
||||
|
||||
// Verify dimension from 'resType' against 'expectedResultDims'.
|
||||
if (resType.getShape().size() != expectedResultDims.size() ||
|
||||
accType.getShape().size() != expectedResultDims.size())
|
||||
return false;
|
||||
for (int64_t i = 0, e = resType.getRank(); i < e; ++i) {
|
||||
if (resType.getDimSize(i) != expectedResultDims[i] ||
|
||||
accType.getDimSize(i) != expectedResultDims[i])
|
||||
// Verify 'expectedResultDims'.
|
||||
if (expectedResultDims.size() == 0) {
|
||||
// No batch or free dimension implies a scalar result.
|
||||
if (resType.isa<VectorType>() || accType.isa<VectorType>())
|
||||
return false;
|
||||
|
||||
} else {
|
||||
// At least one batch or free dimension implies a vector result.
|
||||
auto resVectorType = resType.dyn_cast<VectorType>();
|
||||
auto accVectorType = accType.dyn_cast<VectorType>();
|
||||
if (!resVectorType || !accVectorType)
|
||||
return false;
|
||||
|
||||
// Verify dimension from 'resType' against 'expectedResultDims'.
|
||||
if (resVectorType.getShape().size() != expectedResultDims.size() ||
|
||||
accVectorType.getShape().size() != expectedResultDims.size())
|
||||
return false;
|
||||
for (int64_t i = 0, e = resVectorType.getRank(); i < e; ++i) {
|
||||
if (resVectorType.getDimSize(i) != expectedResultDims[i] ||
|
||||
accVectorType.getDimSize(i) != expectedResultDims[i])
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
@ -210,11 +223,18 @@ static LogicalResult verify(ContractionOp op) {
|
|||
if (map.getNumSymbols() != 0)
|
||||
return op.emitOpError("expected indexing map ")
|
||||
<< index << " to have no symbols";
|
||||
auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
|
||||
unsigned rank = vectorType ? vectorType.getShape().size() : 0;
|
||||
// Since (...) -> () is parsed into an empty map, we need to add
|
||||
// a special case for this situation: continue the verification
|
||||
// of an empty map if the resulting rank is indeed zero, i.e. this
|
||||
// is a reduction into a scalar.
|
||||
if (map.getNumDims() == 0 && map.getNumResults() == 0 && rank == 0)
|
||||
continue;
|
||||
// Verify that the map has the right number of inputs, outputs, and indices.
|
||||
if (map.getNumDims() != numIterators)
|
||||
return op.emitOpError("expected indexing map ")
|
||||
<< index << " to have " << numIterators << " number of inputs";
|
||||
auto operandType = op.getOperand(index).getType().cast<VectorType>();
|
||||
unsigned rank = operandType.getShape().size();
|
||||
if (map.getNumResults() != rank)
|
||||
return op.emitOpError("expected indexing map ")
|
||||
<< index << " to have " << rank << " number of outputs";
|
||||
|
@ -292,7 +312,7 @@ getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
|
|||
void ContractionOp::getIterationBounds(
|
||||
SmallVectorImpl<int64_t> &iterationBounds) {
|
||||
auto lhsShape = getLhsType().getShape();
|
||||
auto resShape = getResultType().getShape();
|
||||
auto resVectorType = getResultType().dyn_cast<VectorType>();
|
||||
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
|
||||
SmallVector<int64_t, 2> iterationShape;
|
||||
for (auto it : llvm::enumerate(iterator_types())) {
|
||||
|
@ -309,7 +329,8 @@ void ContractionOp::getIterationBounds(
|
|||
// Get parallel dimension size from result shape.
|
||||
int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
|
||||
assert(resDimIndex >= 0);
|
||||
iterationBounds.push_back(resShape[resDimIndex]);
|
||||
assert(resVectorType != nullptr);
|
||||
iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -707,6 +707,25 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
|
|||
|
||||
// -----
|
||||
|
||||
#contraction_accesses = [
|
||||
affine_map<(i, j, k) -> (i, k)>,
|
||||
affine_map<(i, j, k) -> (k, j)>,
|
||||
affine_map<(i, j, k) -> (i, j)>
|
||||
]
|
||||
#contraction_trait = {
|
||||
indexing_maps = #contraction_accesses,
|
||||
iterator_types = ["parallel", "parallel", "reduction"]
|
||||
}
|
||||
func @contraction(%arg0: vector<4x3xi32>,
|
||||
%arg1: vector<3x7xf32>,
|
||||
%arg2: vector<4x7xf32>) -> vector<4x7xf32> {
|
||||
// expected-error@+1 {{'vector.contract' op failed to verify that first operand lhs and result have same element type}}
|
||||
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
|
||||
: vector<4x3xi32>, vector<3x7xf32> into vector<4x7xf32>
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @create_mask() {
|
||||
%c2 = constant 2 : index
|
||||
%c3 = constant 3 : index
|
||||
|
|
|
@ -127,6 +127,26 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
|
|||
return %1: vector<2x2x16xf32>
|
||||
}
|
||||
|
||||
#contraction_to_scalar_accesses = [
|
||||
affine_map<(i) -> (i)>,
|
||||
affine_map<(i) -> (i)>,
|
||||
affine_map<(i) -> ()>
|
||||
]
|
||||
#contraction_to_scalar_trait = {
|
||||
indexing_maps = #contraction_to_scalar_accesses,
|
||||
iterator_types = ["reduction"]
|
||||
}
|
||||
// CHECK-LABEL: contraction_to_scalar
|
||||
func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
|
||||
// CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32
|
||||
%f0 = constant 0.0: f32
|
||||
// CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"]} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32
|
||||
%0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0
|
||||
: vector<10xf32>, vector<10xf32> into f32
|
||||
// CHECK: return %[[X]] : f32
|
||||
return %0 : f32
|
||||
}
|
||||
|
||||
#contraction_accesses0 = [
|
||||
affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>,
|
||||
affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>,
|
||||
|
|
Loading…
Reference in New Issue