[mlir][Vector] Add a vector.matrix_multiply op on 1-D vectors

Summary: This op mirrors the llvm.intr counterpart and allows lowering + type conversions in a progressive fashion.

Differential Revision: https://reviews.llvm.org/D75775
This commit is contained in:
Nicolas Vasilache 2020-03-09 13:29:13 -04:00
parent 47caa69120
commit 63b683a816
6 changed files with 111 additions and 3 deletions

View File

@ -15,6 +15,12 @@ class LLVMTypeConverter;
class ModuleOp;
template <typename T> class OpPassBase;
/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
/// will be needed when invoking LLVM.
void populateVectorToLLVMMatrixConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
OwningRewritePatternList &patterns);

View File

@ -836,12 +836,12 @@ def LLVM_MatrixMultiplyOp
: LLVM_OneResultOp<"intr.matrix.multiply">,
Arguments<(
ins LLVM_Type:$lhs, LLVM_Type:$rhs,
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_rows)> {
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)> {
string llvmBuilder = [{
llvm::MatrixBuilder<decltype(builder)> mb(builder);
$res = mb.CreateMatrixMultiply(
$lhs, $rhs, $lhs_rows.getZExtValue(), $lhs_columns.getZExtValue(),
$rhs_rows.getZExtValue());
$rhs_columns.getZExtValue());
}];
let assemblyFormat = "$lhs `,` $rhs attr-dict "
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";

View File

@ -1336,4 +1336,65 @@ def Vector_PrintOp :
let assemblyFormat = "$source attr-dict `:` type($source)";
}
//===----------------------------------------------------------------------===//
// Ops used for supporting progressive lowering and conversion type changes.
//===----------------------------------------------------------------------===//
/// Vector dialect matrix multiplication op that operates on flattened 1-D
/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
/// This may seem redundant with vector.contract but it serves the purposes of
/// more progressive lowering and localized type conversion on the path:
/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
PredOpTrait<"lhs operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 0>>,
PredOpTrait<"rhs operand and result have same element type",
TCresVTEtIsSameAsOpBase<0, 1>>]>,
Arguments<(
// TODO(ntv, fhahn): tighten vector element types that make sense.
ins VectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, AnyFloat]>:$lhs,
VectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, AnyFloat]>:$rhs,
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
Results<(
outs VectorOfRankAndType<[1],
[AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)>
{
let summary = "Vector matrix multiplication op that operates on flattened 1-D"
" MLIR vectors";
let description = [{
This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
purposes of more progressive lowering and localized type conversion.
The vector.matrix_multiply op treats `lhs` as matrix with <lhs_rows> rows
and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
<rhs_columns> and multiplies them. The result matrix is returned embedded in
the result vector.
Example:
```
%C = vector.matrix_multiply %A, %B
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
```
}];
let builders = [
OpBuilder<"Builder *builder, OperationState &result, Value lhs, Value rhs, "
"unsigned lhsRows, unsigned lhsColumns, unsigned rhsColumns",
[{
result.addOperands({lhs, rhs});
result.addAttribute("lhs_rows", builder->getI32IntegerAttr(lhsRows));
result.addAttribute("lhs_columns", builder->getI32IntegerAttr(lhsColumns));
result.addAttribute("rhs_columns", builder->getI32IntegerAttr(rhsColumns));
result.addTypes(VectorType::get(lhsRows * lhsColumns,
lhs.getType().cast<VectorType>().getElementType()));
}]>,
];
let verifier = ?;
let assemblyFormat = "$lhs `,` $rhs attr-dict "
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
}
#endif // VECTOR_OPS

View File

@ -275,6 +275,28 @@ private:
}
};
/// Conversion pattern for a vector.matrix_multiply.
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorMatmulOpConversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
typeConverter) {}
PatternMatchResult
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
auto matmulOp = cast<vector::MatmulOp>(op);
auto adaptor = vector::MatmulOpOperandAdaptor(operands);
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
matmulOp.rhs_columns());
return matchSuccess();
}
};
class VectorReductionOpConversion : public ConvertToLLVMPattern {
public:
explicit VectorReductionOpConversion(MLIRContext *context,
@ -1141,6 +1163,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
VectorPrintOpConversion>(ctx, converter);
}
void mlir::populateVectorToLLVMMatrixConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
MLIRContext *ctx = converter.getDialect()->getContext();
patterns.insert<VectorMatmulOpConversion>(ctx, converter);
}
namespace {
struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
void runOnModule() override;
@ -1160,6 +1188,7 @@ void LowerVectorToLLVMPass::runOnModule() {
// Convert to the LLVM IR dialect.
LLVMTypeConverter converter(&getContext());
OwningRewritePatternList patterns;
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
populateVectorToLLVMConversionPatterns(converter, patterns);
populateStdToLLVMConversionPatterns(converter, patterns);

View File

@ -701,3 +701,15 @@ func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
// CHECK: llvm.return %[[V]] : !llvm.i64
// 4x16 16x3 4x3
func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
%C = vector.matrix_multiply %A, %B
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
return %C: vector<12xf64>
}
// CHECK-LABEL: llvm.func @matrix_ops
// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
// CHECK-SAME: } : (!llvm<"<64 x double>">, !llvm<"<48 x double>">) -> !llvm<"<12 x double>">

View File

@ -136,7 +136,7 @@ llvm.func @matrix_intrinsics(%A: !llvm<"<64 x float>">, %B: !llvm<"<48 x float>"
%ptr: !llvm<"float*">, %stride: !llvm.i32) {
// CHECK: call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3)
%C = llvm.intr.matrix.multiply %A, %B
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_rows = 3: i32} :
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} :
(!llvm<"<64 x float>">, !llvm<"<48 x float>">) -> !llvm<"<12 x float>">
// CHECK: call <48 x float> @llvm.matrix.transpose.v48f32(<48 x float> %1, i32 3, i32 16)
%D = llvm.intr.matrix.transpose %B { rows = 3: i32, columns = 16: i32} :