Add VectorContractionOp to the VectorOps dialect.

PiperOrigin-RevId: 281605471
This commit is contained in:
Andy Davis 2019-11-20 14:43:15 -08:00 committed by A. Unique TensorFlower
parent 1145cebdab
commit d6a70b31be
4 changed files with 422 additions and 0 deletions

View File

@ -49,6 +49,104 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
let parser = [{ return ::parse$cppClass(parser, result); }];
}
// TODO(andydavis, ntv) Add an attribute to specify a different algebra
// with operators other than the current set: {*, +}.
// TODO(andydavis) Consider using AffineMaps to express contracting, batch
// and free dimension pairs.
def VectorContractionOp :
Vector_Op<"contract", [NoSideEffect]>,
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
Variadic<TupleOf<[Index]>>:$masks)>,
Results<(outs AnyVector)> {
let summary = "vector contraction operation";
let description = [{
Computes the sum of products of vector elements along contracting
dimension pairs from 2 vectors of rank M and N respectively, adds this
intermediate result to the accumulator argument of rank K, and returns a
vector result of rank K (where K = num_lhs_free_dims + num_rhs_free_dims +
num_batch_dims (see dimension type descriptions below)).
Optional vector mask arguments specify the dynamic dimension sizes of
valid data within the lhs/rhs vector arguments.
Dimensions for the arguments and result type fall into three categories:
*) Contracting: contracting dimensions are present in the lhs and rhs
arguments but not in the output (or optional accumulator
argument). These are the dimensions along which the vector
contraction op computes the sum of products, and contracting
dimension pair dimension sizes must match between lhs/rhs.
*) Batch: batch dimensions are non-contracting dimensions and so are
present in the output and in the accumulator argument. The lhs
and rhs co-iterate along the batch dimension and so dimension
sizes must match across all arguments and result.
*) Free: free dimensions are non-contraction, non-batch dimensions and
are present in the output and accumulator argument. The lhs and
rhs free dimensions are unrelated to each other and do not
co-iterate.
Contracting and batch dimensions are specified as dimension pairs
of logical dimension numbers: the first in the pair represents the lhs
logical dimension number and the second in the pair represents the
associated rhs logical dimension number. A dimension pair binds together
logical dimension numbers from the lhs/rhs which co-iterate together, either
as contracting or batch dimensions.
Examples:
// 2D vector contraction with one contracting dimension (matmul).
%3 = vector.contract %0, %1, %2
{ contracting_dim_map = [[1, 0]] }
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
// 4D to 3D vector contraction with two contracting dimensions and
// one batch dimension.
%4 = vector.contract %0, %1, %2
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
// 4D vector contraction with two contracting dimensions and optional
// vector mask arguments.
%lhs_mask = vector.make_tuple %size0, %size1, %size2, %size3
: tuple<index, index, index, index>
%rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7
: tuple<index, index, index, index>
%5 = vector.contract %0, %1, %2, %lhs_mask, %rhs_mask
{ contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
}];
let extraClassDeclaration = [{
VectorType getLhsType() {
return lhs()->getType().cast<VectorType>();
}
VectorType getRhsType() {
return rhs()->getType().cast<VectorType>();
}
VectorType getAccType() {
return acc()->getType().cast<VectorType>();
}
TupleType getLHSVectorMaskType() {
if (llvm::size(masks()) != 2) return TupleType();
return getOperand(3)->getType().cast<TupleType>();
}
TupleType getRHSVectorMaskType() {
if (llvm::size(masks()) != 2) return TupleType();
return getOperand(4)->getType().cast<TupleType>();
}
VectorType getResultType() {
return getResult()->getType().cast<VectorType>();
}
static StringRef getContractingDimMapAttrName() {
return "contracting_dim_map";
}
static StringRef getBatchDimMapAttrName() {
return "batch_dim_map";
}
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();
}];
}
def VectorExtractElementOp :
Vector_Op<"extractelement", [NoSideEffect,
PredOpTrait<"operand and result have same element type",
@ -391,4 +489,21 @@ def VectorTypeCastOp :
}
}];
}
// TODO(andydavis) Morph this operation into a VectorMaskOp.
def VectorIndexTupleOp :
Vector_Op<"make_index_tuple", [NoSideEffect]>,
Arguments<(ins Variadic<Index>:$operands)>,
Results<(outs TupleOf<[Index]>)> {
let summary = "creates a tuple of operand values";
let description = [{
Creates and returns a tuple of its operands which must be of index type.
Example:
%1 = vector.make_index_tuple %size0, %size1, %size2
: tuple<index, index, index>
}];
}
#endif // VECTOR_OPS

View File

@ -43,6 +43,185 @@ mlir::vector::VectorOpsDialect::VectorOpsDialect(MLIRContext *context)
>();
}
//===----------------------------------------------------------------------===//
// VectorContractionOp
//===----------------------------------------------------------------------===//
static ParseResult parseVectorContractionOp(OpAsmParser &parser,
OperationState &result) {
OpAsmParser::OperandType lhsInfo;
OpAsmParser::OperandType rhsInfo;
OpAsmParser::OperandType accInfo;
SmallVector<OpAsmParser::OperandType, 2> masksInfo;
SmallVector<Type, 2> types;
Type resultVectorType;
auto loc = parser.getCurrentLocation();
if (parser.parseOperand(lhsInfo) || parser.parseComma() ||
parser.parseOperand(rhsInfo) || parser.parseComma() ||
parser.parseOperand(accInfo) ||
parser.parseTrailingOperandList(masksInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonTypeList(types) ||
parser.parseKeywordType("into", resultVectorType) ||
parser.resolveOperand(lhsInfo, types[0], result.operands) ||
parser.resolveOperand(rhsInfo, types[1], result.operands) ||
parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
parser.addTypeToList(resultVectorType, result.types))
return failure();
if (masksInfo.empty())
return success();
if (masksInfo.size() != 2)
return parser.emitError(parser.getNameLoc(),
"expected zero or exactly 2 vector mask operands");
auto indexType = parser.getBuilder().getIndexType();
auto lhsType = types[0].cast<VectorType>();
auto rhsType = types[1].cast<VectorType>();
SmallVector<Type, 2> maskTypes;
SmallVector<Type, 4> lhsMaskElementTypes(lhsType.getRank(), indexType);
maskTypes.push_back(
TupleType::get(lhsMaskElementTypes, parser.getBuilder().getContext()));
SmallVector<Type, 4> rhsMaskElementTypes(rhsType.getRank(), indexType);
maskTypes.push_back(
TupleType::get(rhsMaskElementTypes, parser.getBuilder().getContext()));
if (parser.resolveOperands(masksInfo, maskTypes, loc, result.operands))
return failure();
return success();
}
static void print(OpAsmPrinter &p, VectorContractionOp op) {
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
p << ", " << *op.acc();
if (llvm::size(op.masks()) == 2) {
p << ", " << **op.masks().begin();
p << ", " << **(op.masks().begin() + 1);
}
p.printOptionalAttrDict(op.getAttrs());
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
<< op.getResultType();
}
static bool verifyDimMap(VectorType lhsType, VectorType rhsType,
const std::vector<std::pair<int64_t, int64_t>> &map) {
for (auto &dimPair : map) {
if (dimPair.first < 0 || dimPair.first >= lhsType.getRank() ||
dimPair.second < 0 || dimPair.second >= rhsType.getRank() ||
lhsType.getDimSize(dimPair.first) != rhsType.getDimSize(dimPair.second))
return false;
}
return true;
}
static bool verifyOutputShape(
VectorType lhsType, VectorType rhsType, VectorType accType,
VectorType resType,
const std::vector<std::pair<int64_t, int64_t>> &contractingDimMap,
const std::vector<std::pair<int64_t, int64_t>> &batchDimMap) {
DenseSet<int64_t> lhsContractingDimSet;
DenseSet<int64_t> rhsContractingDimSet;
for (auto &dimPair : contractingDimMap) {
lhsContractingDimSet.insert(dimPair.first);
rhsContractingDimSet.insert(dimPair.second);
}
DenseSet<int64_t> rhsBatchDimSet;
for (auto &dimPair : batchDimMap)
rhsBatchDimSet.insert(dimPair.second);
// Add free and batch dimensions from 'lhsType' to 'expectedResultDims'.
SmallVector<int64_t, 4> expectedResultDims;
for (int64_t i = 0, e = lhsType.getRank(); i < e; ++i) {
if (lhsContractingDimSet.count(i) > 0)
continue;
expectedResultDims.push_back(lhsType.getDimSize(i));
}
// Add free dimensions from 'rhsType' to 'expectedResultDims'.
for (int64_t i = 0, e = rhsType.getRank(); i < e; ++i) {
if (rhsContractingDimSet.count(i) > 0 || rhsBatchDimSet.count(i) > 0)
continue;
expectedResultDims.push_back(rhsType.getDimSize(i));
}
// Verify dimension from 'resType' against 'expectedResultDims'.
if (resType.getShape().size() != expectedResultDims.size() ||
accType.getShape().size() != expectedResultDims.size())
return false;
for (int64_t i = 0, e = resType.getRank(); i < e; ++i) {
if (resType.getDimSize(i) != expectedResultDims[i] ||
accType.getDimSize(i) != expectedResultDims[i])
return false;
}
return true;
}
static LogicalResult verify(VectorContractionOp op) {
auto lhsType = op.getLhsType();
auto rhsType = op.getRhsType();
auto accType = op.getAccType();
auto resType = op.getResultType();
auto contractingDimMap = op.getContractingDimMap();
auto batchDimMap = op.getBatchDimMap();
// Verify at least one contracting dimension pair was specified.
if (contractingDimMap.empty())
return op.emitOpError("expected at least one contracting dimension pair");
// Verify contracting dimension map was properly constructed.
if (!verifyDimMap(lhsType, rhsType, contractingDimMap))
return op.emitOpError("invalid contracting dimension map");
// Verify batch dimension map was properly constructed.
if (!verifyDimMap(lhsType, rhsType, batchDimMap))
return op.emitOpError("invalid batch dimension map");
// Verify 'accType' and 'resType' shape.
if (!verifyOutputShape(lhsType, rhsType, accType, resType, contractingDimMap,
batchDimMap))
return op.emitOpError("invalid accumulator/result vector shape");
// Verify that either two vector masks are set or none are set.
auto lhsMaskType = op.getLHSVectorMaskType();
auto rhsMaskType = op.getRHSVectorMaskType();
if ((lhsMaskType && !rhsMaskType) || (!lhsMaskType && rhsMaskType))
return op.emitOpError("invalid number of vector masks specified");
if (lhsMaskType && rhsMaskType) {
// Verify tuple element size is != rank.
if (lhsMaskType.getTypes().size() != lhsType.getShape().size() ||
rhsMaskType.getTypes().size() != rhsType.getShape().size())
return op.emitOpError("invalid number of vector mask elements");
// Verify all tuple elements are index type.
for (auto eltType : lhsMaskType.getTypes()) {
if (!eltType.isa<IndexType>())
return op.emitOpError("vector mask element must have index type");
}
}
return success();
}
static std::vector<std::pair<int64_t, int64_t>> getDimMap(Attribute attr) {
std::vector<std::pair<int64_t, int64_t>> dimMap;
auto dimPairs = attr.dyn_cast_or_null<ArrayAttr>();
if (!dimPairs)
return dimMap;
for (auto dimPairAttr : dimPairs) {
auto dimPair = dimPairAttr.cast<ArrayAttr>();
assert(dimPair.size() == 2);
auto lhsDim = dimPair.begin()->cast<IntegerAttr>().getInt();
auto rhsDim = std::prev(dimPair.end())->cast<IntegerAttr>().getInt();
dimMap.push_back({lhsDim, rhsDim});
}
return dimMap;
}
std::vector<std::pair<int64_t, int64_t>>
VectorContractionOp::getContractingDimMap() {
return getDimMap(getAttr(getContractingDimMapAttrName()));
}
std::vector<std::pair<int64_t, int64_t>> VectorContractionOp::getBatchDimMap() {
return getDimMap(getAttr(getBatchDimMapAttrName()));
}
//===----------------------------------------------------------------------===//
// VectorExtractElementOp
//===----------------------------------------------------------------------===//
@ -541,6 +720,36 @@ static LogicalResult verify(VectorTypeCastOp &op) {
return success();
}
//===----------------------------------------------------------------------===//
// VectorIndexTupleOp
//===----------------------------------------------------------------------===//
ParseResult parseVectorIndexTupleOp(OpAsmParser &parser,
OperationState &result) {
auto indexType = parser.getBuilder().getIndexType();
Type resultType;
SmallVector<OpAsmParser::OperandType, 4> operandInfo;
return failure(
parser.parseOperandList(operandInfo) ||
parser.parseOptionalAttrDict(result.attributes) ||
parser.parseColonType(resultType) ||
parser.resolveOperands(operandInfo, indexType, result.operands) ||
parser.addTypeToList(resultType, result.types));
}
static void print(OpAsmPrinter &p, VectorIndexTupleOp &op) {
p << op.getOperationName() << ' ';
p.printOperands(op.operands());
p << " : " << op.getResult()->getType();
}
static LogicalResult verify(VectorIndexTupleOp &op) {
for (auto operand : op.getOperands())
if (!operand->getType().isa<IndexType>())
return op.emitOpError("all operands must be of index type");
return success();
}
namespace mlir {
#define GET_OP_CLASSES

View File

@ -303,3 +303,71 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) {
// expected-error@+1 {{op expected result type to be 'vector<2x8x16xf32>'}}
%1 = vector.strided_slice %arg0 {offsets = [2], sizes = [2], strides = [1]} : vector<4x8x16xf32> to vector<3x1xf32>
}
// -----
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) {
// expected-error@+1 {{op expected at least one contracting dimension pair}}
%0 = vector.contract %arg0, %arg1, %arg2
{ batch_dim_map = [[1, 0]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return
}
// -----
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) {
// expected-error@+1 {{invalid contracting dimension map}}
%0 = vector.contract %arg0, %arg1, %arg2
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[1, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return
}
// -----
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) {
// expected-error@+1 {{invalid batch dimension map}}
%0 = vector.contract %arg0, %arg1, %arg2
{ batch_dim_map = [[1, 2]], contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return
}
// -----
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x88xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) {
// expected-error@+1 {{invalid accumulator/result vector shape}}
%0 = vector.contract %arg0, %arg1, %arg2
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x88xf32>
return
}
// -----
func @contraction(%arg0: vector<7x8x16x15xf32>, %arg1: vector<8x16x7x5xf32>,
%arg2: vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) {
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index>
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index>
// expected-error@+1 {{expected zero or exactly 2 vector mask operands}}
%0 = vector.contract %arg0, %arg1, %arg2, %lhs_mask
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
return
}

View File

@ -48,3 +48,33 @@ func @strided_slice(%arg0: vector<4x8x16xf32>) -> vector<2x2x16xf32> {
%1 = vector.strided_slice %arg0 {offsets = [2, 2], sizes = [2, 2], strides = [1, 1]} : vector<4x8x16xf32> to vector<2x2x16xf32>
return %1: vector<2x2x16xf32>
}
// CHECK-LABEL: contraction
func @contraction(%arg0 : vector<7x8x16x15xf32>, %arg1 : vector<8x16x7x5xf32>,
%arg2 : vector<8x15x5xf32>, %arg3 : vector<8x15x8x5xf32>,
%arg4 : index) {
// Test contraction with batch and contracting dims.
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}} {batch_dim_map = {{.*}}1, 0{{.*}}, contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
%0 = vector.contract %arg0, %arg1, %arg2
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
// Test contraction with only contracting dims. In this case the lhs/rhs
// dimension of size 8 will be considered a free dim for lhs/rhs and will
// appear twice in the output.
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}} {contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
%1 = vector.contract %arg0, %arg1, %arg3
{ contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
// Test contraction with optional vector mask arguments.
%lhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index>
%rhs_mask = vector.make_index_tuple %arg4, %arg4, %arg4, %arg4
: tuple<index, index, index, index>
// CHECK: vector.contract {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}} {contracting_dim_map = {{.*}}0, 2{{.*}}, {{.*}}2, 1{{.*}} : vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
%2 = vector.contract %arg0, %arg1, %arg3, %lhs_mask, %rhs_mask
{ contracting_dim_map = [[0, 2], [2, 1]] }
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
return
}