[mlir] [VectorOps] generalized vector.contract semantics

Summary:
Previously, vector.contract did not allow an empty set of
free or batch dimensions (K = 0) which defines a basic
reduction into a scalar (like a dot product). This CL
relaxes that restriction. Also adds constraints on
element type of operands and results. With tests.

Reviewers: nicolasvasilache, andydavis1, rriddle

Reviewed By: andydavis1

Subscribers: merge_guards_bot, mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74014
This commit is contained in:
aartbik 2020-02-05 16:45:39 -08:00
parent 0c3b2986ac
commit 6e2309d7fa
4 changed files with 107 additions and 31 deletions

View File

@ -38,18 +38,25 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
// TODO(andydavis, ntv) Add an attribute to specify a different algebra
// with operators other than the current set: {*, +}.
def Vector_ContractionOp :
Vector_Op<"contract", [NoSideEffect]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
Vector_Op<"contract", [NoSideEffect,
PredOpTrait<"first operand lhs and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"second operand rhs and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>,
PredOpTrait<"third operand acc and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyType:$acc,
Variadic<VectorOf<[I1]>>:$masks,
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
Results<(outs AnyVector)> {
Results<(outs AnyType)> {
let summary = "vector contraction operation";
let description = [{
Computes the sum of products of vector elements along contracting
dimension pairs from 2 vectors of rank M and N respectively, adds this
intermediate result to the accumulator argument of rank K, and returns a
vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims +
num_batch_dims (see dimension type descriptions below)).
num_batch_dims (see dimension type descriptions below)). For K = 0 (no
free or batch dimensions), the accumulator and output are a scalar.
Optional vector mask arguments (produced by CreateMaskOp or ConstantMaskOp)
specify the dynamic dimension sizes of valid data within the lhs/rhs vector
@ -59,7 +66,7 @@ def Vector_ContractionOp :
the list represents an iterator with one of the following types:
*) "reduction": reduction dimensions are present in the lhs and rhs
arguments but not in the output (or optional accumulator
arguments but not in the output (and accumulator
argument). These are the dimensions along which the vector
contraction op computes the sum of products, and
contracting dimension pair dimension sizes must match
@ -81,7 +88,20 @@ def Vector_ContractionOp :
Examples:
// 2D vector contraction with one contracting dimension (matmul).
// Simple dot product (K = 0).
#contraction_accesses = [
affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>,
affine_map<(i) -> ()>
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["reduction"]
}
%3 = vector.contract #contraction_trait %0, %1, %2
: vector<10xf32>, vector<10xf32> into f32
// 2D vector contraction with one contracting dimension (matmul, K = 2).
#contraction_accesses = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
@ -96,7 +116,7 @@ def Vector_ContractionOp :
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
// 4D to 3D vector contraction with two contracting dimensions and
// one batch dimension.
// one batch dimension (K = 3).
#contraction_accesses = [
affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>,
affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>,
@ -105,7 +125,7 @@ def Vector_ContractionOp :
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "parallel",
"reduction", "reduction"]
"reduction", "reduction"]
}
%4 = vector.contract #contraction_trait %0, %1, %2
@ -129,9 +149,7 @@ def Vector_ContractionOp :
VectorType getRhsType() {
return rhs().getType().cast<VectorType>();
}
VectorType getAccType() {
return acc().getType().cast<VectorType>();
}
Type getAccType() { return acc().getType(); }
VectorType getLHSVectorMaskType() {
if (llvm::size(masks()) != 2) return VectorType();
return getOperand(3).getType().cast<VectorType>();
@ -140,9 +158,7 @@ def Vector_ContractionOp :
if (llvm::size(masks()) != 2) return VectorType();
return getOperand(4).getType().cast<VectorType>();
}
VectorType getResultType() {
return getResult().getType().cast<VectorType>();
}
Type getResultType() { return getResult().getType(); }
ArrayRef<StringRef> getTraitAttrNames();
SmallVector<AffineMap, 4> getIndexingMaps();
static unsigned getAccOperandIndex() { return 2; }

View File

@ -81,7 +81,7 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
OpAsmParser::OperandType accInfo;
SmallVector<OpAsmParser::OperandType, 2> masksInfo;
SmallVector<Type, 2> types;
Type resultVectorType;
Type resultType;
auto loc = parser.getCurrentLocation();
DictionaryAttr dictAttr;
// TODO(andydavis, ntv) Unify linalg op attribute parsing.
@ -92,11 +92,11 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
parser.parseTrailingOperandList(masksInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonTypeList(types) ||
parser.parseKeywordType("into", resultVectorType) ||
parser.parseKeywordType("into", resultType) ||
parser.resolveOperand(lhsInfo, types[0], result.operands) ||
parser.resolveOperand(rhsInfo, types[1], result.operands) ||
parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
parser.addTypeToList(resultVectorType, result.types))
parser.resolveOperand(accInfo, resultType, result.operands) ||
parser.addTypeToList(resultType, result.types))
return failure();
result.attributes.assign(dictAttr.getValue().begin(),
dictAttr.getValue().end());
@ -149,8 +149,7 @@ static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
}
static bool verifyOutputShape(
VectorType lhsType, VectorType rhsType, VectorType accType,
VectorType resType,
VectorType lhsType, VectorType rhsType, Type accType, Type resType,
const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
DenseSet<int64_t> lhsContractingDimSet;
@ -178,14 +177,28 @@ static bool verifyOutputShape(
expectedResultDims.push_back(rhsType.getDimSize(i));
}
// Verify dimension from 'resType' against 'expectedResultDims'.
if (resType.getShape().size() != expectedResultDims.size() ||
accType.getShape().size() != expectedResultDims.size())
return false;
for (int64_t i = 0, e = resType.getRank(); i < e; ++i) {
if (resType.getDimSize(i) != expectedResultDims[i] ||
accType.getDimSize(i) != expectedResultDims[i])
// Verify 'expectedResultDims'.
if (expectedResultDims.size() == 0) {
// No batch or free dimension implies a scalar result.
if (resType.isa<VectorType>() || accType.isa<VectorType>())
return false;
} else {
// At least one batch or free dimension implies a vector result.
auto resVectorType = resType.dyn_cast<VectorType>();
auto accVectorType = accType.dyn_cast<VectorType>();
if (!resVectorType || !accVectorType)
return false;
// Verify dimension from 'resType' against 'expectedResultDims'.
if (resVectorType.getShape().size() != expectedResultDims.size() ||
accVectorType.getShape().size() != expectedResultDims.size())
return false;
for (int64_t i = 0, e = resVectorType.getRank(); i < e; ++i) {
if (resVectorType.getDimSize(i) != expectedResultDims[i] ||
accVectorType.getDimSize(i) != expectedResultDims[i])
return false;
}
}
return true;
}
@ -210,11 +223,18 @@ static LogicalResult verify(ContractionOp op) {
if (map.getNumSymbols() != 0)
return op.emitOpError("expected indexing map ")
<< index << " to have no symbols";
auto vectorType = op.getOperand(index).getType().dyn_cast<VectorType>();
unsigned rank = vectorType ? vectorType.getShape().size() : 0;
// Since (...) -> () is parsed into an empty map, we need to add
// a special case for this situation: continue the verification
// of an empty map if the resulting rank is indeed zero, i.e. this
// is a reduction into a scalar.
if (map.getNumDims() == 0 && map.getNumResults() == 0 && rank == 0)
continue;
// Verify that the map has the right number of inputs, outputs, and indices.
if (map.getNumDims() != numIterators)
return op.emitOpError("expected indexing map ")
<< index << " to have " << numIterators << " number of inputs";
auto operandType = op.getOperand(index).getType().cast<VectorType>();
unsigned rank = operandType.getShape().size();
if (map.getNumResults() != rank)
return op.emitOpError("expected indexing map ")
<< index << " to have " << rank << " number of outputs";
@ -292,7 +312,7 @@ getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
void ContractionOp::getIterationBounds(
SmallVectorImpl<int64_t> &iterationBounds) {
auto lhsShape = getLhsType().getShape();
auto resShape = getResultType().getShape();
auto resVectorType = getResultType().dyn_cast<VectorType>();
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
SmallVector<int64_t, 2> iterationShape;
for (auto it : llvm::enumerate(iterator_types())) {
@ -309,7 +329,8 @@ void ContractionOp::getIterationBounds(
// Get parallel dimension size from result shape.
int64_t resDimIndex = getResultIndex(indexingMaps[2], targetExpr);
assert(resDimIndex >= 0);
iterationBounds.push_back(resShape[resDimIndex]);
assert(resVectorType != nullptr);
iterationBounds.push_back(resVectorType.getShape()[resDimIndex]);
}
}

View File

@ -707,6 +707,25 @@ func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
// -----
#contraction_accesses = [
affine_map<(i, j, k) -> (i, k)>,
affine_map<(i, j, k) -> (k, j)>,
affine_map<(i, j, k) -> (i, j)>
]
#contraction_trait = {
indexing_maps = #contraction_accesses,
iterator_types = ["parallel", "parallel", "reduction"]
}
func @contraction(%arg0: vector<4x3xi32>,
%arg1: vector<3x7xf32>,
%arg2: vector<4x7xf32>) -> vector<4x7xf32> {
// expected-error@+1 {{'vector.contract' op failed to verify that first operand lhs and result have same element type}}
%0 = vector.contract #contraction_trait %arg0, %arg1, %arg2
: vector<4x3xi32>, vector<3x7xf32> into vector<4x7xf32>
}
// -----
func @create_mask() {
%c2 = constant 2 : index
%c3 = constant 3 : index

View File

@ -127,6 +127,26 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
return %1: vector<2x2x16xf32>
}
#contraction_to_scalar_accesses = [
affine_map<(i) -> (i)>,
affine_map<(i) -> (i)>,
affine_map<(i) -> ()>
]
#contraction_to_scalar_trait = {
indexing_maps = #contraction_to_scalar_accesses,
iterator_types = ["reduction"]
}
// CHECK-LABEL: contraction_to_scalar
func @contraction_to_scalar(%arg0: vector<10xf32>, %arg1: vector<10xf32>) -> f32 {
// CHECK: %[[C0:.*]] = constant 0.000000e+00 : f32
%f0 = constant 0.0: f32
// CHECK: %[[X:.*]] = vector.contract {indexing_maps = [#{{.*}}, #{{.*}}, #{{.*}}], iterator_types = ["reduction"]} %{{.*}}, %{{.*}}, %[[C0]] : vector<10xf32>, vector<10xf32> into f32
%0 = vector.contract #contraction_to_scalar_trait %arg0, %arg1, %f0
: vector<10xf32>, vector<10xf32> into f32
// CHECK: return %[[X]] : f32
return %0 : f32
}
#contraction_accesses0 = [
affine_map<(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0)>,
affine_map<(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1)>,