Add vp2intersect to AVX512 dialect.

Adds vp2intersect to the AVX512 dialect and defines a lowering to the
LLVM dialect.

Author: Matthias Springer <springerm@google.com>

Differential Revision: https://reviews.llvm.org/D95301
This commit is contained in:
Matthias Springer 2021-01-26 07:31:20 +00:00 committed by Nicolas Vasilache
parent d705c2fbd4
commit 90ebc489de
7 changed files with 118 additions and 7 deletions

View File

@ -96,4 +96,41 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
"$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
}
def Vp2IntersectOp : AVX512_Op<"vp2intersect", [NoSideEffect,
AllTypesMatch<["a", "b"]>,
TypesMatchWith<"k1 has the same number of bits as elements in a",
"a", "k1",
"IntegerType::get($_self.getContext(), "
"($_self.cast<VectorType>().getShape()[0]))">,
TypesMatchWith<"k2 has the same number of bits as elements in b",
// Should use `b` instead of `a`, but that would require
// adding `type($b)` to assemblyFormat.
"a", "k2",
"IntegerType::get($_self.getContext(), "
"($_self.cast<VectorType>().getShape()[0]))">]> {
let summary = "Vp2Intersect op";
let description = [{
The `vp2intersect` op is an AVX512 specific op that can lower to the proper
LLVMAVX512 operation: `llvm.vp2intersect.d.512` or
`llvm.vp2intersect.q.512` depending on the type of MLIR vectors it is
applied to.
#### From the Intel Intrinsics Guide:
Compute intersection of packed integer vectors `a` and `b`, and store
indication of match in the corresponding bit of two mask registers
specified by `k1` and `k2`. A match in corresponding elements of `a` and
`b` is indicated by a set bit in the corresponding bit of the mask
registers.
}];
let arguments = (ins VectorOfLengthAndType<[16, 8], [I32, I64]>:$a,
VectorOfLengthAndType<[16, 8], [I32, I64]>:$b
);
let results = (outs AnyTypeOf<[I16, I8]>:$k1,
AnyTypeOf<[I16, I8]>:$k2
);
let assemblyFormat =
"$a `,` $b attr-dict `:` type($a)";
}
#endif // AVX512_OPS

View File

@ -16,6 +16,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Dialect/AVX512/AVX512Dialect.h.inc"

View File

@ -28,25 +28,33 @@ def LLVMAVX512_Dialect : Dialect {
// MLIR LLVM AVX512 intrinsics using the MLIR LLVM Dialect type system
//----------------------------------------------------------------------------//
class LLVMAVX512_IntrOp<string mnemonic, list<OpTrait> traits = []> :
class LLVMAVX512_IntrOp<string mnemonic, int numResults, list<OpTrait> traits = []> :
LLVM_IntrOpBase<LLVMAVX512_Dialect, mnemonic,
"x86_avx512_" # !subst(".", "_", mnemonic),
[], [], traits, 1>;
[], [], traits, numResults>;
def LLVM_x86_avx512_mask_rndscale_ps_512 :
LLVMAVX512_IntrOp<"mask.rndscale.ps.512">,
LLVMAVX512_IntrOp<"mask.rndscale.ps.512", 1>,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_x86_avx512_mask_rndscale_pd_512 :
LLVMAVX512_IntrOp<"mask.rndscale.pd.512">,
LLVMAVX512_IntrOp<"mask.rndscale.pd.512", 1>,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_x86_avx512_mask_scalef_ps_512 :
LLVMAVX512_IntrOp<"mask.scalef.ps.512">,
LLVMAVX512_IntrOp<"mask.scalef.ps.512", 1>,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_x86_avx512_mask_scalef_pd_512 :
LLVMAVX512_IntrOp<"mask.scalef.pd.512">,
LLVMAVX512_IntrOp<"mask.scalef.pd.512", 1>,
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
def LLVM_x86_avx512_vp2intersect_d_512 :
LLVMAVX512_IntrOp<"vp2intersect.d.512", 2>,
Arguments<(ins LLVM_Type, LLVM_Type)>;
def LLVM_x86_avx512_vp2intersect_q_512 :
LLVMAVX512_IntrOp<"vp2intersect.q.512", 2>,
Arguments<(ins LLVM_Type, LLVM_Type)>;
#endif // AVX512_OPS

View File

@ -77,6 +77,29 @@ struct ScaleFOp512Conversion : public ConvertToLLVMPattern {
return failure();
}
};
struct Vp2IntersectOp512Conversion
: public ConvertOpToLLVMPattern<Vp2IntersectOp> {
explicit Vp2IntersectOp512Conversion(MLIRContext *context,
LLVMTypeConverter &typeConverter)
: ConvertOpToLLVMPattern<Vp2IntersectOp>(typeConverter) {}
LogicalResult
matchAndRewrite(Vp2IntersectOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Type elementType =
op.a().getType().template cast<VectorType>().getElementType();
if (elementType.isInteger(32))
return LLVM::detail::oneToOneRewrite(
op, LLVM::x86_avx512_vp2intersect_d_512::getOperationName(), operands,
*getTypeConverter(), rewriter);
if (elementType.isInteger(64))
return LLVM::detail::oneToOneRewrite(
op, LLVM::x86_avx512_vp2intersect_q_512::getOperationName(), operands,
*getTypeConverter(), rewriter);
return failure();
}
};
} // namespace
/// Populate the given list with patterns that convert from AVX512 to LLVM.
@ -84,6 +107,8 @@ void mlir::populateAVX512ToLLVMConversionPatterns(
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
// clang-format off
patterns.insert<MaskRndScaleOp512Conversion,
ScaleFOp512Conversion>(&converter.getContext(), converter);
ScaleFOp512Conversion,
Vp2IntersectOp512Conversion>(&converter.getContext(),
converter);
// clang-format on
}

View File

@ -16,3 +16,13 @@ func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i1
// Keep results alive.
return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
}
func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
-> (i16, i16, i8, i8)
{
// CHECK: llvm_avx512.vp2intersect.d.512
%0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
// CHECK: llvm_avx512.vp2intersect.q.512
%2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
return %0, %1, %2, %3 : i16, i16, i8, i8
}

View File

@ -19,3 +19,13 @@ func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16,
%1 = avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64>
return %0, %1: vector<16xf32>, vector<8xf64>
}
func @avx512_vp2intersect(%a: vector<16xi32>, %b: vector<8xi64>)
-> (i16, i16, i8, i8)
{
// CHECK: avx512.vp2intersect {{.*}} : vector<16xi32>
%0, %1 = avx512.vp2intersect %a, %a : vector<16xi32>
// CHECK: avx512.vp2intersect {{.*}} : vector<8xi64>
%2, %3 = avx512.vp2intersect %b, %b : vector<8xi64>
return %0, %1, %2, %3 : i16, i16, i8, i8
}

View File

@ -29,3 +29,23 @@ llvm.func @LLVM_x86_avx512_mask_pd_512(%a: vector<8xf64>,
(vector<8xf64>, vector<8xf64>, vector<8xf64>, i8, i32) -> vector<8xf64>
llvm.return %1: vector<8xf64>
}
// CHECK-LABEL: define <{ i16, i16 }> @LLVM_x86_vp2intersect_d_512
llvm.func @LLVM_x86_vp2intersect_d_512(%a: vector<16xi32>, %b: vector<16xi32>)
-> !llvm.struct<packed (i16, i16)>
{
// CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32>
%0 = "llvm_avx512.vp2intersect.d.512"(%a, %b) :
(vector<16xi32>, vector<16xi32>) -> !llvm.struct<packed (i16, i16)>
llvm.return %0 : !llvm.struct<packed (i16, i16)>
}
// CHECK-LABEL: define <{ i8, i8 }> @LLVM_x86_vp2intersect_q_512
llvm.func @LLVM_x86_vp2intersect_q_512(%a: vector<8xi64>, %b: vector<8xi64>)
-> !llvm.struct<packed (i8, i8)>
{
// CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64>
%0 = "llvm_avx512.vp2intersect.q.512"(%a, %b) :
(vector<8xi64>, vector<8xi64>) -> !llvm.struct<packed (i8, i8)>
llvm.return %0 : !llvm.struct<packed (i8, i8)>
}