forked from OSchip/llvm-project
Add vp2intersect to AVX512 dialect.
Adds vp2intersect to the AVX512 dialect and defines a lowering to the LLVM dialect.

Author: Matthias Springer <springerm@google.com>

Differential Revision: https://reviews.llvm.org/D95301
This commit is contained in:
parent
d705c2fbd4
commit
90ebc489de
|
@ -96,4 +96,41 @@ def MaskScaleFOp : AVX512_Op<"mask.scalef", [NoSideEffect,
|
|||
"$src `,` $a `,` $b `,` $k `,` $rounding attr-dict `:` type($dst)";
|
||||
}
|
||||
|
||||
// AVX512 `vp2intersect`: takes two integer vectors and yields two integer
// masks.
//
// Declared constraints:
//   * `a` and `b` have the same type;
//   * `k1` and `k2` are integers whose bit width equals the size of the first
//     dimension of `a`.
def Vp2IntersectOp : AVX512_Op<"vp2intersect", [
    NoSideEffect,
    AllTypesMatch<["a", "b"]>,
    TypesMatchWith<"k1 has the same number of bits as elements in a",
                   "a", "k1",
                   "IntegerType::get($_self.getContext(), "
                   "($_self.cast<VectorType>().getShape()[0]))">,
    TypesMatchWith<"k2 has the same number of bits as elements in b",
                   // Derived from `a` rather than `b`: deriving it from `b`
                   // would require adding `type($b)` to assemblyFormat.
                   "a", "k2",
                   "IntegerType::get($_self.getContext(), "
                   "($_self.cast<VectorType>().getShape()[0]))">
  ]> {
  let summary = "Vp2Intersect op";
  let description = [{
    The `vp2intersect` op is an AVX512 specific op that can lower to the proper
    LLVMAVX512 operation: `llvm.vp2intersect.d.512` or
    `llvm.vp2intersect.q.512` depending on the type of MLIR vectors it is
    applied to.

    #### From the Intel Intrinsics Guide:

    Compute intersection of packed integer vectors `a` and `b`, and store
    indication of match in the corresponding bit of two mask registers
    specified by `k1` and `k2`. A match in corresponding elements of `a` and
    `b` is indicated by a set bit in the corresponding bit of the mask
    registers.
  }];

  // Operands: two vectors drawn from the same length/element-type sets.
  let arguments = (ins
    VectorOfLengthAndType<[16, 8], [I32, I64]>:$a,
    VectorOfLengthAndType<[16, 8], [I32, I64]>:$b
  );

  // Results: one mask per operand.
  let results = (outs
    AnyTypeOf<[I16, I8]>:$k1,
    AnyTypeOf<[I16, I8]>:$k2
  );

  // Only the type of `a` is spelled out; the remaining types are inferred
  // through the traits above.
  let assemblyFormat =
    "$a `,` $b attr-dict `:` type($a)";
}
|
||||
|
||||
#endif // AVX512_OPS
|
||||
|
|
|
@ -16,6 +16,7 @@
|
|||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
#include "mlir/Dialect/AVX512/AVX512Dialect.h.inc"
|
||||
|
|
|
@ -28,25 +28,33 @@ def LLVMAVX512_Dialect : Dialect {
|
|||
// MLIR LLVM AVX512 intrinsics using the MLIR LLVM Dialect type system
|
||||
//----------------------------------------------------------------------------//
|
||||
|
||||
class LLVMAVX512_IntrOp<string mnemonic, list<OpTrait> traits = []> :
|
||||
class LLVMAVX512_IntrOp<string mnemonic, int numResults, list<OpTrait> traits = []> :
|
||||
LLVM_IntrOpBase<LLVMAVX512_Dialect, mnemonic,
|
||||
"x86_avx512_" # !subst(".", "_", mnemonic),
|
||||
[], [], traits, 1>;
|
||||
[], [], traits, numResults>;
|
||||
|
||||
def LLVM_x86_avx512_mask_rndscale_ps_512 :
|
||||
LLVMAVX512_IntrOp<"mask.rndscale.ps.512">,
|
||||
LLVMAVX512_IntrOp<"mask.rndscale.ps.512", 1>,
|
||||
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
|
||||
|
||||
def LLVM_x86_avx512_mask_rndscale_pd_512 :
|
||||
LLVMAVX512_IntrOp<"mask.rndscale.pd.512">,
|
||||
LLVMAVX512_IntrOp<"mask.rndscale.pd.512", 1>,
|
||||
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
|
||||
|
||||
def LLVM_x86_avx512_mask_scalef_ps_512 :
|
||||
LLVMAVX512_IntrOp<"mask.scalef.ps.512">,
|
||||
LLVMAVX512_IntrOp<"mask.scalef.ps.512", 1>,
|
||||
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
|
||||
|
||||
def LLVM_x86_avx512_mask_scalef_pd_512 :
|
||||
LLVMAVX512_IntrOp<"mask.scalef.pd.512">,
|
||||
LLVMAVX512_IntrOp<"mask.scalef.pd.512", 1>,
|
||||
Arguments<(ins LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type, LLVM_Type)>;
|
||||
|
||||
// Intrinsic op for `vp2intersect.d.512`: two operands, two results.
def LLVM_x86_avx512_vp2intersect_d_512
    : LLVMAVX512_IntrOp<"vp2intersect.d.512", 2>,
      Arguments<(ins LLVM_Type, LLVM_Type)>;
|
||||
|
||||
// Intrinsic op for `vp2intersect.q.512`: two operands, two results.
def LLVM_x86_avx512_vp2intersect_q_512
    : LLVMAVX512_IntrOp<"vp2intersect.q.512", 2>,
      Arguments<(ins LLVM_Type, LLVM_Type)>;
|
||||
|
||||
#endif // AVX512_OPS
|
||||
|
|
|
@ -77,6 +77,29 @@ struct ScaleFOp512Conversion : public ConvertToLLVMPattern {
|
|||
return failure();
|
||||
}
|
||||
};
|
||||
|
||||
struct Vp2IntersectOp512Conversion
|
||||
: public ConvertOpToLLVMPattern<Vp2IntersectOp> {
|
||||
explicit Vp2IntersectOp512Conversion(MLIRContext *context,
|
||||
LLVMTypeConverter &typeConverter)
|
||||
: ConvertOpToLLVMPattern<Vp2IntersectOp>(typeConverter) {}
|
||||
|
||||
LogicalResult
|
||||
matchAndRewrite(Vp2IntersectOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
Type elementType =
|
||||
op.a().getType().template cast<VectorType>().getElementType();
|
||||
if (elementType.isInteger(32))
|
||||
return LLVM::detail::oneToOneRewrite(
|
||||
op, LLVM::x86_avx512_vp2intersect_d_512::getOperationName(), operands,
|
||||
*getTypeConverter(), rewriter);
|
||||
if (elementType.isInteger(64))
|
||||
return LLVM::detail::oneToOneRewrite(
|
||||
op, LLVM::x86_avx512_vp2intersect_q_512::getOperationName(), operands,
|
||||
*getTypeConverter(), rewriter);
|
||||
return failure();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
/// Populate the given list with patterns that convert from AVX512 to LLVM.
|
||||
|
@ -84,6 +107,8 @@ void mlir::populateAVX512ToLLVMConversionPatterns(
|
|||
LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
|
||||
// clang-format off
|
||||
patterns.insert<MaskRndScaleOp512Conversion,
|
||||
ScaleFOp512Conversion>(&converter.getContext(), converter);
|
||||
ScaleFOp512Conversion,
|
||||
Vp2IntersectOp512Conversion>(&converter.getContext(),
|
||||
converter);
|
||||
// clang-format on
|
||||
}
|
||||
|
|
|
@ -16,3 +16,13 @@ func @avx512_mask_rndscale(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i1
|
|||
// Keep results alive.
|
||||
return %0, %1, %2, %3 : vector<16xf32>, vector<8xf64>, vector<16xf32>, vector<8xf64>
|
||||
}
|
||||
|
||||
// Expects the output to contain the `.d.512` intrinsic op for the i32 case and
// the `.q.512` intrinsic op for the i64 case.
func @avx512_vp2intersect(%lanes32: vector<16xi32>, %lanes64: vector<8xi64>)
  -> (i16, i16, i8, i8)
{
  // CHECK: llvm_avx512.vp2intersect.d.512
  %d_k1, %d_k2 = avx512.vp2intersect %lanes32, %lanes32 : vector<16xi32>
  // CHECK: llvm_avx512.vp2intersect.q.512
  %q_k1, %q_k2 = avx512.vp2intersect %lanes64, %lanes64 : vector<8xi64>
  // Return all masks so that every result has a use.
  return %d_k1, %d_k2, %q_k1, %q_k2 : i16, i16, i8, i8
}
|
||||
|
|
|
@ -19,3 +19,13 @@ func @avx512_scalef(%a: vector<16xf32>, %b: vector<8xf64>, %i32: i32, %i16: i16,
|
|||
%1 = avx512.mask.scalef %b, %b, %b, %i8, %i32 : vector<8xf64>
|
||||
return %0, %1: vector<16xf32>, vector<8xf64>
|
||||
}
|
||||
|
||||
// Expects `avx512.vp2intersect` to appear in the output in its custom assembly
// form for both operand types used below.
func @avx512_vp2intersect(%lanes32: vector<16xi32>, %lanes64: vector<8xi64>)
  -> (i16, i16, i8, i8)
{
  // CHECK: avx512.vp2intersect {{.*}} : vector<16xi32>
  %d_k1, %d_k2 = avx512.vp2intersect %lanes32, %lanes32 : vector<16xi32>
  // CHECK: avx512.vp2intersect {{.*}} : vector<8xi64>
  %q_k1, %q_k2 = avx512.vp2intersect %lanes64, %lanes64 : vector<8xi64>
  return %d_k1, %d_k2, %q_k1, %q_k2 : i16, i16, i8, i8
}
|
||||
|
|
|
@ -29,3 +29,23 @@ llvm.func @LLVM_x86_avx512_mask_pd_512(%a: vector<8xf64>,
|
|||
(vector<8xf64>, vector<8xf64>, vector<8xf64>, i8, i32) -> vector<8xf64>
|
||||
llvm.return %1: vector<8xf64>
|
||||
}
|
||||
|
||||
// Expects the `.d.512` intrinsic op to be emitted as a call to
// `@llvm.x86.avx512.vp2intersect.d.512`.
// CHECK-LABEL: define <{ i16, i16 }> @LLVM_x86_vp2intersect_d_512
llvm.func @LLVM_x86_vp2intersect_d_512(%lhs: vector<16xi32>, %rhs: vector<16xi32>)
  -> !llvm.struct<packed (i16, i16)>
{
  // CHECK: call { <16 x i1>, <16 x i1> } @llvm.x86.avx512.vp2intersect.d.512(<16 x i32>
  %masks = "llvm_avx512.vp2intersect.d.512"(%lhs, %rhs) :
    (vector<16xi32>, vector<16xi32>) -> !llvm.struct<packed (i16, i16)>
  llvm.return %masks : !llvm.struct<packed (i16, i16)>
}
|
||||
|
||||
// Expects the `.q.512` intrinsic op to be emitted as a call to
// `@llvm.x86.avx512.vp2intersect.q.512`.
// CHECK-LABEL: define <{ i8, i8 }> @LLVM_x86_vp2intersect_q_512
llvm.func @LLVM_x86_vp2intersect_q_512(%lhs: vector<8xi64>, %rhs: vector<8xi64>)
  -> !llvm.struct<packed (i8, i8)>
{
  // CHECK: call { <8 x i1>, <8 x i1> } @llvm.x86.avx512.vp2intersect.q.512(<8 x i64>
  %masks = "llvm_avx512.vp2intersect.q.512"(%lhs, %rhs) :
    (vector<8xi64>, vector<8xi64>) -> !llvm.struct<packed (i8, i8)>
  llvm.return %masks : !llvm.struct<packed (i8, i8)>
}
|
||||
|
|
Loading…
Reference in New Issue