forked from OSchip/llvm-project
[mlir][AVX512] Add mask.compress to AVX512 dialect.
Adds mask.compress to the AVX512 dialect and defines a lowering to the LLVM dialect. Differential Revision: https://reviews.llvm.org/D97611
This commit is contained in:
parent
f4ad7a1a15
commit
acce0ea70c
|
@ -31,6 +31,42 @@ def AVX512_Dialect : Dialect {
|
|||
class AVX512_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<AVX512_Dialect, mnemonic, traits> {}
|
||||
|
||||
def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect,
|
||||
// TODO: Support optional arguments in `AllTypesMatch`. "type($src)" could
|
||||
// then be removed from assemblyFormat.
|
||||
AllTypesMatch<["a", "dst"]>,
|
||||
TypesMatchWith<"`k` has the same number of bits as elements in `dst`",
|
||||
"dst", "k",
|
||||
"VectorType::get({$_self.cast<VectorType>().getShape()[0]}, "
|
||||
"IntegerType::get($_self.getContext(), 1))">]> {
|
||||
let summary = "Masked compress op";
|
||||
let description = [{
|
||||
The mask.compress op is an AVX512 specific op that can lower to the
|
||||
`llvm.mask.compress` instruction. Instead of `src`, a constant vector
|
||||
attribute `constant_src` may be specified. If neither `src` nor
|
||||
`constant_src` is specified, the remaining elements in the result vector are
|
||||
set to zero.
|
||||
|
||||
#### From the Intel Intrinsics Guide:
|
||||
|
||||
Contiguously store the active integer/floating-point elements in `a` (those
|
||||
with their respective bit set in writemask `k`) to `dst`, and pass through the
|
||||
remaining elements from `src`.
|
||||
}];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let arguments = (ins VectorOfLengthAndType<[16, 16, 8, 8],
|
||||
[I1, I1, I1, I1]>:$k,
|
||||
VectorOfLengthAndType<[16, 16, 8, 8],
|
||||
[F32, I32, F64, I64]>:$a,
|
||||
Optional<VectorOfLengthAndType<[16, 16, 8, 8],
|
||||
[F32, I32, F64, I64]>>:$src,
|
||||
OptionalAttr<ElementsAttr>:$constant_src);
|
||||
let results = (outs VectorOfLengthAndType<[16, 16, 8, 8],
|
||||
[F32, I32, F64, I64]>:$dst);
|
||||
let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
|
||||
" `:` type($dst) (`,` type($src)^)?";
|
||||
}
|
||||
|
||||
def MaskRndScaleOp : AVX512_Op<"mask.rndscale", [NoSideEffect,
|
||||
AllTypesMatch<["src", "a", "dst"]>,
|
||||
TypesMatchWith<"imm has the same number of bits as elements in dst",
|
||||
|
|
|
@ -33,6 +33,16 @@ class LLVMAVX512_IntrOp<string mnemonic, int numResults, list<OpTrait> traits =
|
|||
"x86_avx512_" # !subst(".", "_", mnemonic),
|
||||
[], [], traits, numResults>;
|
||||
|
||||
// Defined by first result overload. May have to be extended for other
|
||||
// instructions in the future.
|
||||
class LLVMAVX512_IntrOverloadedOp<string mnemonic,
|
||||
list<OpTrait> traits = []> :
|
||||
LLVM_IntrOpBase<LLVMAVX512_Dialect, mnemonic,
|
||||
"x86_avx512_" # !subst(".", "_", mnemonic),
|
||||
/*list<int> overloadedResults=*/[0],
|
||||
/*list<int> overloadedOperands=*/[],
|
||||
traits, /*numResults=*/1>;
|
||||
|
||||
def LLVM_x86_avx512_mask_rndscale_ps_512 :
|
||||
LLVMAVX512_IntrOp<"mask.rndscale.ps.512", 1>,
|
||||
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
|
||||
|
@ -49,6 +59,10 @@ def LLVM_x86_avx512_mask_scalef_pd_512 :
|
|||
LLVMAVX512_IntrOp<"mask.scalef.pd.512", 1>,
|
||||
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
|
||||
|
||||
def LLVM_x86_avx512_mask_compress :
|
||||
LLVMAVX512_IntrOverloadedOp<"mask.compress">,
|
||||
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type)>;
|
||||
|
||||
def LLVM_x86_avx512_vp2intersect_d_512 :
|
||||
LLVMAVX512_IntrOp<"vp2intersect.d.512", 2>,
|
||||
Arguments<(ins LLVM_Type, LLVM_Type)>;
|
||||
|
|
|
@ -56,6 +56,34 @@ struct MaskRndScaleOp512Conversion : public ConvertToLLVMPattern {
|
|||
}
|
||||
};
|
||||
|
||||
struct MaskCompressOpConversion
|
||||
: public ConvertOpToLLVMPattern<MaskCompressOp> {
|
||||
using ConvertOpToLLVMPattern<MaskCompressOp>::ConvertOpToLLVMPattern;
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(MaskCompressOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
MaskCompressOp::Adaptor adaptor(operands);
|
||||
auto opType = adaptor.a().getType();
|
||||
|
||||
Value src;
|
||||
if (op.src()) {
|
||||
src = adaptor.src();
|
||||
} else if (op.constant_src()) {
|
||||
src = rewriter.create<ConstantOp>(op.getLoc(), opType,
|
||||
op.constant_srcAttr());
|
||||
} else {
|
||||
Attribute zeroAttr = rewriter.getZeroAttr(opType);
|
||||
src = rewriter.create<ConstantOp>(op->getLoc(), opType, zeroAttr);
|
||||
}
|
||||
|
||||
rewriter.replaceOpWithNewOp<LLVM::x86_avx512_mask_compress>(
|
||||
op, opType, adaptor.a(), src, adaptor.k());
|
||||
|
||||
return success();
|
||||
}
|
||||
};
|
||||
|
||||
struct ScaleFOp512Conversion : public ConvertToLLVMPattern {
|
||||
explicit ScaleFOp512Conversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
|
@ -110,5 +138,6 @@ void mlir::populateAVX512ToLLVMConversionPatterns(
|
|||
ScaleFOp512Conversion,
|
||||
Vp2IntersectOp512Conversion>(&converter.getContext(),
|
||||
converter);
|
||||
patterns.insert<MaskCompressOpConversion>(converter);
|
||||
// clang-format on
|
||||
}
|
||||
|
|
|
@ -25,5 +25,21 @@ void avx512::AVX512Dialect::initialize() {
|
|||
>();
|
||||
}
|
||||
|
||||
/// Verifies `avx512.mask.compress`.
///
/// Rejects the op when:
///   * both the `src` operand and the `constant_src` attribute are present;
///   * `src` is present and its type differs from the type of `dst`;
///   * `constant_src` is present and its type differs from the type of `dst`.
/// Returns success otherwise.
static LogicalResult verify(avx512::MaskCompressOp op) {
  const auto loc = op.getLoc();
  const auto dstType = op.dst().getType();

  if (op.src()) {
    // `src` and `constant_src` are mutually exclusive pass-through sources.
    if (op.constant_src())
      return emitError(loc, "cannot use both src and constant_src");

    if (op.src().getType() != dstType)
      return emitError(loc,
                       "failed to verify that src and dst have same type");

    return success();
  }

  if (op.constant_src() && op.constant_src()->getType() != dstType)
    return emitError(
        loc, "failed to verify that constant_src and dst have same type");

  return success();
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/AVX512/AVX512.cpp.inc"
|
||||
|
|
|
@ -17,6 +17,19 @@ func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i1
|
|||
return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
|
||||
}
|
||||
|
||||
func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
|
||||
%k2: vector<8xi1>, %a2: vector<8xi64>)
|
||||
-> (vector<16xf32>, vector<16xf32>, vector<8xi64>)
|
||||
{
|
||||
// CHECK: llvm_avx512.mask.compress
|
||||
%0 = avx512.mask.compress %k1, %a1 : vector<16xf32>
|
||||
// CHECK: llvm_avx512.mask.compress
|
||||
%1 = avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
|
||||
// CHECK: llvm_avx512.mask.compress
|
||||
%2 = avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
|
||||
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
|
||||
}
|
||||
|
||||
func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
|
||||
-> (vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>)
|
||||
{
|
||||
|
|
|
@ -29,3 +29,16 @@ func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
|
|||
%2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
|
||||
return %0, %1, %2, %3 : vector<16xi1>, vector<16xi1>, vector<8xi1>, vector<8xi1>
|
||||
}
|
||||
|
||||
func @avx512_mask_compress(%k1: vector<16xi1>, %a1: vector<16xf32>,
|
||||
%k2: vector<8xi1>, %a2: vector<8xi64>)
|
||||
-> (vector<16xf32>, vector<16xf32>, vector<8xi64>)
|
||||
{
|
||||
// CHECK: avx512.mask.compress {{.*}} : vector<16xf32>
|
||||
%0 = avx512.mask.compress %k1, %a1 : vector<16xf32>
|
||||
// CHECK: avx512.mask.compress {{.*}} : vector<16xf32>
|
||||
%1 = avx512.mask.compress %k1, %a1 {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
|
||||
// CHECK: avx512.mask.compress {{.*}} : vector<8xi64>
|
||||
%2 = avx512.mask.compress %k2, %a2, %a2 : vector<8xi64>, vector<8xi64>
|
||||
return %0, %1, %2 : vector<16xf32>, vector<16xf32>, vector<8xi64>
|
||||
}
|
||||
|
|
|
@ -0,0 +1,27 @@
|
|||
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-avx512" -convert-std-to-llvm | \
|
||||
// RUN: mlir-translate --mlir-to-llvmir | \
|
||||
// RUN: %lli --entry-function=entry --mattr="avx512bw" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
|
||||
// RUN: FileCheck %s
|
||||
|
||||
func @entry() -> i32 {
|
||||
%i0 = constant 0 : i32
|
||||
|
||||
%a = std.constant dense<[1., 0., 0., 2., 4., 3., 5., 7., 8., 1., 5., 5., 3., 1., 0., 7.]> : vector<16xf32>
|
||||
%k = std.constant dense<[1, 0, 1, 1, 1, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 0]> : vector<16xi1>
|
||||
%r1 = avx512.mask.compress %k, %a : vector<16xf32>
|
||||
%r2 = avx512.mask.compress %k, %a {constant_src = dense<5.0> : vector<16xf32>} : vector<16xf32>
|
||||
|
||||
vector.print %r1 : vector<16xf32>
|
||||
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 0, 0, 0, 0, 0, 0, 0 )
|
||||
|
||||
vector.print %r2 : vector<16xf32>
|
||||
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 5, 5, 5, 5, 5, 5, 5 )
|
||||
|
||||
%src = std.constant dense<[0., 2., 1., 8., 6., 4., 4., 3., 2., 8., 5., 6., 3., 7., 6., 9.]> : vector<16xf32>
|
||||
%r3 = avx512.mask.compress %k, %a, %src : vector<16xf32>, vector<16xf32>
|
||||
|
||||
vector.print %r3 : vector<16xf32>
|
||||
// CHECK: ( 1, 0, 2, 4, 5, 5, 3, 1, 0, 8, 5, 6, 3, 7, 6, 9 )
|
||||
|
||||
return %i0 : i32
|
||||
}
|
|
@ -1,5 +1,5 @@
|
|||
// RUN: mlir-opt %s -convert-scf-to-std -convert-vector-to-llvm="enable-avx512" -convert-std-to-llvm | \
|
||||
// RUN: mlir-translate --avx512-mlir-to-llvmir | \
|
||||
// RUN: mlir-translate --mlir-to-llvmir | \
|
||||
// RUN: %lli --entry-function=entry --mattr="avx512bw,avx512vp2intersect" --dlopen=%mlir_integration_test_dir/libmlir_c_runner_utils%shlibext | \
|
||||
// RUN: FileCheck %s
|
||||
|
||||
|
|
|
@ -30,6 +30,16 @@ llvm.func @LLVM_x86_avx512_mask_pd_512(%a: vector<8xf64>,
|
|||
llvm.return %1: vector<8xf64>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define <16 x float> @LLVM_x86_mask_compress
|
||||
llvm.func @LLVM_x86_mask_compress(%k: vector<16xi1>, %a: vector<16xf32>)
|
||||
-> vector<16xf32>
|
||||
{
|
||||
// CHECK: call <16 x float> @llvm.x86.avx512.mask.compress.v16f32(
|
||||
%0 = "llvm_avx512.mask.compress"(%a, %a, %k) :
|
||||
(vector<16xf32>, vector<16xf32>, vector<16xi1>) -> vector<16xf32>
|
||||
llvm.return %0 : vector<16xf32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: define { <16 x i1>, <16 x i1> } @LLVM_x86_vp2intersect_d_512
|
||||
llvm.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>)
|
||||
-> !llvm.struct<(vector<16 x i1>, vector<16 x i1>)>
|
||||
|
|
Loading…
Reference in New Issue