forked from OSchip/llvm-project
[mlir][Vector] Add a vector.matrix_multiply op on 1-D vectors
Summary: This op mirrors the llvm.intr counterpart and allows lowering + type conversions in a progressive fashion. Differential Revision: https://reviews.llvm.org/D75775
This commit is contained in:
parent
47caa69120
commit
63b683a816
|
@ -15,6 +15,12 @@ class LLVMTypeConverter;
|
|||
class ModuleOp;
|
||||
template <typename T> class OpPassBase;
|
||||
|
||||
/// Collect a set of patterns to convert from Vector contractions to LLVM Matrix
|
||||
/// Intrinsics. To lower to assembly, the LLVM flag -lower-matrix-intrinsics
|
||||
/// will be needed when invoking LLVM.
|
||||
void populateVectorToLLVMMatrixConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns);
|
||||
|
||||
/// Collect a set of patterns to convert from the Vector dialect to LLVM.
|
||||
void populateVectorToLLVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
OwningRewritePatternList &patterns);
|
||||
|
|
|
@ -836,12 +836,12 @@ def LLVM_MatrixMultiplyOp
|
|||
: LLVM_OneResultOp<"intr.matrix.multiply">,
|
||||
Arguments<(
|
||||
ins LLVM_Type:$lhs, LLVM_Type:$rhs,
|
||||
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_rows)> {
|
||||
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)> {
|
||||
string llvmBuilder = [{
|
||||
llvm::MatrixBuilder<decltype(builder)> mb(builder);
|
||||
$res = mb.CreateMatrixMultiply(
|
||||
$lhs, $rhs, $lhs_rows.getZExtValue(), $lhs_columns.getZExtValue(),
|
||||
$rhs_rows.getZExtValue());
|
||||
$rhs_columns.getZExtValue());
|
||||
}];
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict "
|
||||
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
|
||||
|
|
|
@ -1336,4 +1336,65 @@ def Vector_PrintOp :
|
|||
let assemblyFormat = "$source attr-dict `:` type($source)";
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Ops used for supporting progressive lowering and conversion type changes.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Vector dialect matrix multiplication op that operates on flattened 1-D
|
||||
/// MLIR vectors. This is the counterpart of llvm.matrix.multiply in MLIR.
|
||||
/// This may seem redundant with vector.contract but it serves the purposes of
|
||||
/// more progressive lowering and localized type conversion on the path:
|
||||
/// `vector<...x...xf32> -> vector<...xf32> -> !llvm<... x float>`.
|
||||
def Vector_MatmulOp : Vector_Op<"matrix_multiply", [NoSideEffect,
|
||||
PredOpTrait<"lhs operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 0>>,
|
||||
PredOpTrait<"rhs operand and result have same element type",
|
||||
TCresVTEtIsSameAsOpBase<0, 1>>]>,
|
||||
Arguments<(
|
||||
// TODO(ntv, fhahn): tighten vector element types that make sense.
|
||||
ins VectorOfRankAndType<[1],
|
||||
[AnySignlessInteger, AnySignedInteger, AnyFloat]>:$lhs,
|
||||
VectorOfRankAndType<[1],
|
||||
[AnySignlessInteger, AnySignedInteger, AnyFloat]>:$rhs,
|
||||
I32Attr:$lhs_rows, I32Attr:$lhs_columns, I32Attr:$rhs_columns)>,
|
||||
Results<(
|
||||
outs VectorOfRankAndType<[1],
|
||||
[AnySignlessInteger, AnySignedInteger, AnyFloat]>:$res)>
|
||||
{
|
||||
let summary = "Vector matrix multiplication op that operates on flattened 1-D"
|
||||
" MLIR vectors";
|
||||
let description = [{
|
||||
This is the counterpart of llvm.matrix.multiply in MLIR. It serves the
|
||||
purposes of more progressive lowering and localized type conversion.
|
||||
|
||||
The ‘vector.matrix_multiply’ op treats `lhs` as matrix with <lhs_rows> rows
|
||||
and <lhs_columns> columns, `rhs` as matrix with <lhs_columns> rows and
|
||||
<rhs_columns> and multiplies them. The result matrix is returned embedded in
|
||||
the result vector.
|
||||
|
||||
Example:
|
||||
|
||||
```
|
||||
%C = vector.matrix_multiply %A, %B
|
||||
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
|
||||
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
|
||||
```
|
||||
}];
|
||||
let builders = [
|
||||
OpBuilder<"Builder *builder, OperationState &result, Value lhs, Value rhs, "
|
||||
"unsigned lhsRows, unsigned lhsColumns, unsigned rhsColumns",
|
||||
[{
|
||||
result.addOperands({lhs, rhs});
|
||||
result.addAttribute("lhs_rows", builder->getI32IntegerAttr(lhsRows));
|
||||
result.addAttribute("lhs_columns", builder->getI32IntegerAttr(lhsColumns));
|
||||
result.addAttribute("rhs_columns", builder->getI32IntegerAttr(rhsColumns));
|
||||
result.addTypes(VectorType::get(lhsRows * lhsColumns,
|
||||
lhs.getType().cast<VectorType>().getElementType()));
|
||||
}]>,
|
||||
];
|
||||
let verifier = ?;
|
||||
let assemblyFormat = "$lhs `,` $rhs attr-dict "
|
||||
"`:` `(` type($lhs) `,` type($rhs) `)` `->` type($res)";
|
||||
}
|
||||
|
||||
#endif // VECTOR_OPS
|
||||
|
|
|
@ -275,6 +275,28 @@ private:
|
|||
}
|
||||
};
|
||||
|
||||
/// Conversion pattern for a vector.matrix_multiply.
|
||||
/// This is lowered directly to the proper llvm.intr.matrix.multiply.
|
||||
class VectorMatmulOpConversion : public ConvertToLLVMPattern {
|
||||
public:
|
||||
explicit VectorMatmulOpConversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertToLLVMPattern(vector::MatmulOp::getOperationName(), context,
|
||||
typeConverter) {}
|
||||
|
||||
PatternMatchResult
|
||||
matchAndRewrite(Operation *op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
auto matmulOp = cast<vector::MatmulOp>(op);
|
||||
auto adaptor = vector::MatmulOpOperandAdaptor(operands);
|
||||
rewriter.replaceOpWithNewOp<LLVM::MatrixMultiplyOp>(
|
||||
op, typeConverter.convertType(matmulOp.res().getType()), adaptor.lhs(),
|
||||
adaptor.rhs(), matmulOp.lhs_rows(), matmulOp.lhs_columns(),
|
||||
matmulOp.rhs_columns());
|
||||
return matchSuccess();
|
||||
}
|
||||
};
|
||||
|
||||
class VectorReductionOpConversion : public ConvertToLLVMPattern {
|
||||
public:
|
||||
explicit VectorReductionOpConversion(MLIRContext *context,
|
||||
|
@ -1141,6 +1163,12 @@ void mlir::populateVectorToLLVMConversionPatterns(
|
|||
VectorPrintOpConversion>(ctx, converter);
|
||||
}
|
||||
|
||||
void mlir::populateVectorToLLVMMatrixConversionPatterns(
|
||||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
MLIRContext *ctx = converter.getDialect()->getContext();
|
||||
patterns.insert<VectorMatmulOpConversion>(ctx, converter);
|
||||
}
|
||||
|
||||
namespace {
|
||||
struct LowerVectorToLLVMPass : public ModulePass<LowerVectorToLLVMPass> {
|
||||
void runOnModule() override;
|
||||
|
@ -1160,6 +1188,7 @@ void LowerVectorToLLVMPass::runOnModule() {
|
|||
// Convert to the LLVM IR dialect.
|
||||
LLVMTypeConverter converter(&getContext());
|
||||
OwningRewritePatternList patterns;
|
||||
populateVectorToLLVMMatrixConversionPatterns(converter, patterns);
|
||||
populateVectorToLLVMConversionPatterns(converter, patterns);
|
||||
populateStdToLLVMConversionPatterns(converter, patterns);
|
||||
|
||||
|
|
|
@ -701,3 +701,15 @@ func @reduce_i64(%arg0: vector<16xi64>) -> i64 {
|
|||
// CHECK: %[[V:.*]] = "llvm.intr.experimental.vector.reduce.add"(%[[A]])
|
||||
// CHECK: llvm.return %[[V]] : !llvm.i64
|
||||
|
||||
|
||||
// 4x16 16x3 4x3
|
||||
func @matrix_ops(%A: vector<64xf64>, %B: vector<48xf64>) -> vector<12xf64> {
|
||||
%C = vector.matrix_multiply %A, %B
|
||||
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32 } :
|
||||
(vector<64xf64>, vector<48xf64>) -> vector<12xf64>
|
||||
return %C: vector<12xf64>
|
||||
}
|
||||
// CHECK-LABEL: llvm.func @matrix_ops
|
||||
// CHECK: llvm.intr.matrix.multiply %{{.*}}, %{{.*}} {
|
||||
// CHECK-SAME: lhs_columns = 16 : i32, lhs_rows = 4 : i32, rhs_columns = 3 : i32
|
||||
// CHECK-SAME: } : (!llvm<"<64 x double>">, !llvm<"<48 x double>">) -> !llvm<"<12 x double>">
|
||||
|
|
|
@ -136,7 +136,7 @@ llvm.func @matrix_intrinsics(%A: !llvm<"<64 x float>">, %B: !llvm<"<48 x float>"
|
|||
%ptr: !llvm<"float*">, %stride: !llvm.i32) {
|
||||
// CHECK: call <12 x float> @llvm.matrix.multiply.v12f32.v64f32.v48f32(<64 x float> %0, <48 x float> %1, i32 4, i32 16, i32 3)
|
||||
%C = llvm.intr.matrix.multiply %A, %B
|
||||
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_rows = 3: i32} :
|
||||
{ lhs_rows = 4: i32, lhs_columns = 16: i32 , rhs_columns = 3: i32} :
|
||||
(!llvm<"<64 x float>">, !llvm<"<48 x float>">) -> !llvm<"<12 x float>">
|
||||
// CHECK: call <48 x float> @llvm.matrix.transpose.v48f32(<48 x float> %1, i32 3, i32 16)
|
||||
%D = llvm.intr.matrix.transpose %B { rows = 3: i32, columns = 16: i32} :
|
||||
|
|
Loading…
Reference in New Issue