From acce0ea70c11a215fb8814aa6779953a84eea3e0 Mon Sep 17 00:00:00 2001 From: Matthias Springer Date: Fri, 5 Mar 2021 13:08:05 +0900 Subject: [PATCH] [mlir][AVX512] Add mask.compress to AVX512 dialect. Adds mask.compress to the AVX512 dialect and defines a lowering to the LLVM dialect. Differential Revision: https://reviews.llvm.org/D97611 --- mlir/include/mlir/Dialect/AVX512/AVX512.td | 36 +++++++++++++++++++ .../include/mlir/Dialect/LLVMIR/LLVMAVX512.td | 14 ++++++++ .../AVX512ToLLVM/ConvertAVX512ToLLVM.cpp | 29 +++++++++++++++ mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp | 16 +++++++++ .../AVX512ToLLVM/convert-to-llvm.mlir | 13 +++++++ mlir/test/Dialect/AVX512/roundtrip.mlir | 13 +++++++ .../Vector/CPU/AVX512/test-mask-compress.mlir | 27 ++++++++++++++ .../CPU/AVX512/test-vp2intersect-i32.mlir | 2 +- mlir/test/Target/avx512.mlir | 10 ++++++ 9 files changed, 159 insertions(+), 1 deletion(-) create mode 100644 mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-mask-compress.mlir diff --git a/mlir/include/mlir/Dialect/AVX512/AVX512.td b/mlir/include/mlir/Dialect/AVX512/AVX512.td index 7140b013967a..c2487a021a1d 100644 --- a/mlir/include/mlir/Dialect/AVX512/AVX512.td +++ b/mlir/include/mlir/Dialect/AVX512/AVX512.td @@ -31,6 +31,42 @@ def AVX512_Dialect : Dialect { class AVX512_Op traits = []> : Op {} +def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect, + // TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could + // then be removed from assemblyFormat. + AllTypesMatch<["a", "dst"]>, + TypesMatchWith<"`k` has the same number of bits as elements in `dst`", + "dst", "k", + "VectorType::get({$_self.cast().getShape()[0]}, " + "IntegerType::get($_self.getContext(), 1))">]> { + let summary = "Masked compress op"; + let description = [{ + The mask.compress op is an AVX512 specific op that can lower to the + `llvm.mask.compress` instruction. Instead of `src`, a constant vector + vector attribute `constant_src` may be specified. If neither `src` nor + `constant_src` is specified, the remaining elements in the result vector are + set to zero. + + #### From the Intel Intrinsics Guide: + + Contiguously store the active integer/floating-point elements in `a` (those + with their respective bit set in writemask `k`) to `dst`, and pass through the + remaining elements from `src`. + }]; + let verifier = [{ return ::verify(*this); }]; + let arguments = (ins VectorOfLengthAndType<[16, 16, 8, 8], + [I1, I1, I1, I1]>:$k, + VectorOfLengthAndType<[16, 16, 8, 8], + [F32, I32, F64, I64]>:$a, + Optional>:$src, + OptionalAttr:$constant_src); + let results = (outs VectorOfLengthAndType<[16, 16, 8, 8], + [F32, I32, F64, I64]>:$dst); + let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict" + " `:` type($dst) (`,` type($src)^)?"; +} + def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect, AllTypesMatch<["src", "a", "dst"]>, TypesMatchWith<"imm has the same number of bits as elements in dst", diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td index 9bcbdb5977b6..20fb8030c8b1 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAVX512.td @@ -33,6 +33,16 @@ class LLVMAVX512_IntrOp traits = "x86_avx512_" # !subst(".", "_", mnemonic), [], [], traits, numResults>; +// Defined by first result overload. May have to be extended for other +// instructions in the future. +class LLVMAVX512_IntrOverloadedOp traits = []> : + LLVM_IntrOpBase overloadedResults=*/[0], + /*list overloadedOperands=*/[], + traits, /*numResults=*/1>; + def LLVM_x86_avx512_mask_rndscale_ps_512 : LLVMAVX512_IntrOp<"mask.rndscale.ps.512", 1>, Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; @@ -49,6 +59,10 @@ def LLVM_x86_avx512_mask_scalef_pd_512 : LLVMAVX512_IntrOp<"mask.scalef.pd.512", 1>, Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>; +def LLVM_x86_avx512_mask_compress : + LLVMAVX512_IntrOverloadedOp<"mask.compress">, + Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>; + def LLVM_x86_avx512_vp2intersect_d_512 : LLVMAVX512_IntrOp<"vp2intersect.d.512", 2>, Arguments<(ins LLVM_Type, LLVM_Type)>; diff --git a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp index 3381ad85e5ca..74b919717283 100644 --- a/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp +++ b/mlir/lib/Conversion/AVX512ToLLVM/ConvertAVX512ToLLVM.cpp @@ -56,6 +56,34 @@ struct MaskRndScaleOp512Conversion : public ConvertToLLVMPattern { } }; +struct MaskCompressOpConversion + : public ConvertOpToLLVMPattern { + using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + + LogicalResult + matchAndRewrite(MaskCompressOp op, ArrayRef operands, + ConversionPatternRewriter &rewriter) const override { + MaskCompressOp::Adaptor adaptor(operands); + auto opType = adaptor.a().getType(); + + Value src; + if (op.src()) { + src = adaptor.src(); + } else if (op.constant_src()) { + src = rewriter.create(op.getLoc(), opType, + op.constant_srcAttr()); + } else { + Attribute zeroAttr = rewriter.getZeroAttr(opType); + src = rewriter.create(op->getLoc(), opType, zeroAttr); + } + + rewriter.replaceOpWithNewOp( + op, opType, adaptor.a(), src, adaptor.k()); + + return success(); + } +}; + struct ScaleFOp512Conversion : public ConvertToLLVMPattern { explicit ScaleFOp512Conversion(MLIRContext *context, LLVMTypeConverter &typeConverter) @@ -110,5 +138,6 @@ void mlir::populateAVX512ToLLVMConversionPatterns( ScaleFOp512Conversion, Vp2IntersectOp512Conversion>(&converter.getContext(), converter); + patterns.insert(converter); // clang-format on } diff --git a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp index 697f00864b15..023018af8086 100644 --- a/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp +++ b/mlir/lib/Dialect/AVX512/IR/AVX512Dialect.cpp @@ -25,5 +25,21 @@ void avx512::AVX512Dialect::initialize() { >(); } +static LogicalResult verify(avx512::MaskCompressOp op) { + if (op.src() && op.constant_src()) + return emitError(op.getLoc(), "cannot use both src and constant_src"); + + if (op.src() && (op.src().getType() != op.dst().getType())) + return emitError(op.getLoc(), + "failed to verify that src and dst have same type"); + + if (op.constant_src() && (op.constant_src()->getType() != op.dst().getType())) + return emitError( + op.getLoc(), + "failed to verify that constant_src and dst have same type"); + + return success(); +} + #define GET_OP_CLASSES #include "mlir/Dialect/AVX512/AVX512.cpp.inc" diff --git a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir index b6f7ad8e196e..0d03917d06c3 100644 --- a/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir +++ b/mlir/test/Conversion/AVX512ToLLVM/convert-to-llvm.mlir @@ -17,6 +17,19 @@ func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i1 return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64> } +func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>, + %k2: vector<8xi1>, %a2: vector<8xi64>) + -> (vector<16xf32>, vector<16xf32>, vector<8xi64>) +{ + // CHECK: llvm_avx512.mask.compress + %0 = avx512.mask.compress %k1, %a1 : vector<16xf32> + // CHECK: llvm_avx512.mask.compress + %1 = avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32> + // CHECK: llvm_avx512.mask.compress + %2 = avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64> + return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64> +} + func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>) -> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>) { diff --git a/mlir/test/Dialect/AVX512/roundtrip.mlir b/mlir/test/Dialect/AVX512/roundtrip.mlir index 865f9185b821..dc1a65bbd47b 100644 --- a/mlir/test/Dialect/AVX512/roundtrip.mlir +++ b/mlir/test/Dialect/AVX512/roundtrip.mlir @@ -29,3 +29,16 @@ func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>) %2, %3 = avx512.vp2intersect %b, %b : vector<8xi64> return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1> } + +func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>, + %k2: vector<8xi1>, %a2: vector<8xi64>) + -> (vector<16xf32>, vector<16xf32>, vector<8xi64>) +{ + // CHECK: avx512.mask.compress {{.*}} : vector<16xf32> + %0 = avx512.mask.compress %k1, %a1 : vector<16xf32> + // CHECK: avx512.mask.compress {{.*}} : vector<16xf32> + %1 = avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32> + // CHECK: avx512.mask.compress {{.*}} : vector<8xi64> + %2 = avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64> + return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64> +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-mask-compress.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-mask-compress.mlir new file mode 100644 index 000000000000..ae34524b50c8 --- /dev/null +++ b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-mask-compress.mlir @@ -0,0 +1,27 @@ +// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-avx512" -convert-std-to-llvm | \ +// RUN: mlir-translate --mlir-to-llvmir | \ +// RUN: %lli --entry-function=entry --mattr="avx512bw" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ +// RUN: FileCheck %s + +func @entry() -> i32 { + %i0 = constant 0 : i32 + + %a = std.constant dense<[1., 0., 0., 2., 4., 3., 5., 7., 8., 1., 5., 5., 3., 1., 0., 7.]> : vector<16xf32> + %k = std.constant dense<[1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0]> : vector<16xi1> + %r1 = avx512.mask.compress %k, %a : vector<16xf32> + %r2 = avx512.mask.compress %k, %a {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32> + + vector.print %r1 : vector<16xf32> + // CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0 ) + + vector.print %r2 : vector<16xf32> + // CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 5, 5, 5, 5, 5, 5, 5 ) + + %src = std.constant dense<[0., 2., 1., 8., 6., 4., 4., 3., 2., 8., 5., 6., 3., 7., 6., 9.]> : vector<16xf32> + %r3 = avx512.mask.compress %k, %a, %src : vector<16xf32>, vector<16xf32> + + vector.print %r3 : vector<16xf32> + // CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 8, 5, 6, 3, 7, 6, 9 ) + + return %i0 : i32 +} diff --git a/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-vp2intersect-i32.mlir b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-vp2intersect-i32.mlir index d29789a28067..e291e809f201 100644 --- a/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-vp2intersect-i32.mlir +++ b/mlir/test/Integration/Dialect/Vector/CPU/AVX512/test-vp2intersect-i32.mlir @@ -1,5 +1,5 @@ // RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-avx512" -convert-std-to-llvm | \ -// RUN: mlir-translate --avx512-mlir-to-llvmir | \ +// RUN: mlir-translate --mlir-to-llvmir | \ // RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \ // RUN: FileCheck %s diff --git a/mlir/test/Target/avx512.mlir b/mlir/test/Target/avx512.mlir index 940873bf1592..abf36bd153a1 100644 --- a/mlir/test/Target/avx512.mlir +++ b/mlir/test/Target/avx512.mlir @@ -30,6 +30,16 @@ llvm.func @LLVM_x86_avx512_mask_pd_512(%a: vector<8xf64>, llvm.return %1: vector<8xf64> } +// CHECK-LABEL: define <16 x float> @LLVM_x86_mask_compress +llvm.func @LLVM_x86_mask_compress(%k: vector<16xi1>, %a: vector<16xf32>) + -> vector<16xf32> +{ + // CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32( + %0 = "llvm_avx512.mask.compress"(%a, %a, %k) : + (vector<16xf32>, vector<16xf32>, vector<16xi1>) -> vector<16xf32> + llvm.return %0 : vector<16xf32> +} + // CHECK-LABEL: define { <16 x i1>, <16 x i1> } @LLVM_x86_vp2intersect_d_512 llvm.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>) -> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)>