[MLIR] Remove LLVM_AnyInteger type constraint

LLVM Dialect uses builtin-integer types. The existing LLVM_AnyInteger
type constraint is a dupe of AnyInteger. This patch removes LLVM_AnyInteger
and replaces all usage with AnyInteger.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D103839
This commit is contained in:
Kiran Chandramohan 2021-06-08 16:48:57 +01:00
parent d2eccf9bb7
commit cd73af9231
5 changed files with 45 additions and 51 deletions

View File

@ -239,7 +239,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
//
def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
Arguments<(ins LLVM_AnyInteger, LLVM_AnyInteger)>;
Arguments<(ins AnyInteger, AnyInteger)>;
//
// Tile memory operations. Parameters define the tile size,
@ -248,12 +248,12 @@ def LLVM_x86_amx_tilezero : AMX_IntrOp<"tilezero", 1>,
//
def LLVM_x86_amx_tileloadd64 : AMX_IntrOp<"tileloadd64", 1>,
Arguments<(ins LLVM_AnyInteger,
LLVM_AnyInteger, LLVM_AnyPointer, LLVM_AnyInteger)>;
Arguments<(ins AnyInteger,
AnyInteger, LLVM_AnyPointer, AnyInteger)>;
def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
Arguments<(ins LLVM_AnyInteger,
LLVM_AnyInteger, LLVM_AnyPointer, LLVM_AnyInteger, LLVM_Type)>;
Arguments<(ins AnyInteger,
AnyInteger, LLVM_AnyPointer, AnyInteger, LLVM_Type)>;
//
// Tile multiplication operations (series of dot products). Parameters
@ -263,32 +263,32 @@ def LLVM_x86_amx_tilestored64 : AMX_IntrOp<"tilestored64", 0>,
// Dot product of bf16 tiles into f32 tile.
def LLVM_x86_amx_tdpbf16ps : AMX_IntrOp<"tdpbf16ps", 1>,
Arguments<(ins LLVM_AnyInteger,
LLVM_AnyInteger,
LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
// Dot product of i8 tiles into i32 tile (with sign/sign extension).
def LLVM_x86_amx_tdpbssd : AMX_IntrOp<"tdpbssd", 1>,
Arguments<(ins LLVM_AnyInteger,
LLVM_AnyInteger,
LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
// Dot product of i8 tiles into i32 tile (with sign/zero extension).
def LLVM_x86_amx_tdpbsud : AMX_IntrOp<"tdpbsud", 1>,
Arguments<(ins LLVM_AnyInteger,
LLVM_AnyInteger,
LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
// Dot product of i8 tiles into i32 tile (with zero/sign extension).
def LLVM_x86_amx_tdpbusd : AMX_IntrOp<"tdpbusd", 1>,
Arguments<(ins LLVM_AnyInteger,
LLVM_AnyInteger,
LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
// Dot product of i8 tiles into i32 tile (with zero/zero extension).
def LLVM_x86_amx_tdpbuud : AMX_IntrOp<"tdpbuud", 1>,
Arguments<(ins LLVM_AnyInteger,
LLVM_AnyInteger,
LLVM_AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
Arguments<(ins AnyInteger,
AnyInteger,
AnyInteger, LLVM_Type, LLVM_Type, LLVM_Type)>;
#endif // AMX

View File

@ -62,11 +62,6 @@ def LLVM_TokenType : Type<
"LLVM token type">,
BuildableType<"::mlir::LLVM::LLVMTokenType::get($_builder.getContext())">;
// Type constraint accepting LLVM integer types.
def LLVM_AnyInteger : Type<
CPred<"$_self.isa<::mlir::IntegerType>()">,
"LLVM integer type">;
// Type constraint accepting LLVM primitive types, i.e. all types except void
// and function.
def LLVM_PrimitiveType : Type<

View File

@ -129,7 +129,7 @@ class LLVM_ArithmeticOpBase<Type type, string mnemonic,
}
class LLVM_IntArithmeticOp<string mnemonic, string builderFunc,
list<OpTrait> traits = []> :
LLVM_ArithmeticOpBase<LLVM_AnyInteger, mnemonic, builderFunc, traits> {
LLVM_ArithmeticOpBase<AnyInteger, mnemonic, builderFunc, traits> {
let arguments = commonArgs;
}
class LLVM_FloatArithmeticOp<string mnemonic, string builderFunc,
@ -190,8 +190,8 @@ def ICmpPredicate : I64EnumAttr<
// Other integer operations.
def LLVM_ICmpOp : LLVM_Op<"icmp", [NoSideEffect]> {
let arguments = (ins ICmpPredicate:$predicate,
AnyTypeOf<[LLVM_ScalarOrVectorOf<LLVM_AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs,
AnyTypeOf<[LLVM_ScalarOrVectorOf<LLVM_AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs);
AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$lhs,
AnyTypeOf<[LLVM_ScalarOrVectorOf<AnyInteger>, LLVM_ScalarOrVectorOf<LLVM_AnyPointer>]>:$rhs);
let results = (outs LLVM_ScalarOrVectorOf<I1>:$res);
let llvmBuilder = [{
$res = builder.CreateICmp(getLLVMCmpPredicate($predicate), $lhs, $rhs);
@ -290,7 +290,7 @@ class MemoryOpWithAlignmentAndAttributes : MemoryOpWithAlignmentBase {
// Memory-related operations.
def LLVM_AllocaOp : LLVM_Op<"alloca">, MemoryOpWithAlignmentBase {
let arguments = (ins LLVM_AnyInteger:$arraySize,
let arguments = (ins AnyInteger:$arraySize,
OptionalAttr<I64Attr>:$alignment);
let results = (outs LLVM_AnyPointer:$res);
string llvmBuilder = [{
@ -318,7 +318,7 @@ def LLVM_GEPOp
"$res = builder.CreateGEP("
" $base->getType()->getPointerElementType(), $base, $indices);"> {
let arguments = (ins LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$base,
Variadic<LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>:$indices);
Variadic<LLVM_ScalarOrVectorOf<AnyInteger>>:$indices);
let results = (outs LLVM_ScalarOrVectorOf<LLVM_AnyPointer>:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = [{
@ -389,32 +389,32 @@ def LLVM_AddrSpaceCastOp : LLVM_CastOp<"addrspacecast", "CreateAddrSpaceCast",
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>,
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
def LLVM_IntToPtrOp : LLVM_CastOp<"inttoptr", "CreateIntToPtr",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
LLVM_ScalarOrVectorOf<AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>>;
def LLVM_PtrToIntOp : LLVM_CastOp<"ptrtoint", "CreatePtrToInt",
LLVM_ScalarOrVectorOf<LLVM_AnyPointer>,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
LLVM_ScalarOrVectorOf<AnyInteger>>;
def LLVM_SExtOp : LLVM_CastOp<"sext", "CreateSExt",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
LLVM_ScalarOrVectorOf<AnyInteger>,
LLVM_ScalarOrVectorOf<AnyInteger>>;
def LLVM_ZExtOp : LLVM_CastOp<"zext", "CreateZExt",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
LLVM_ScalarOrVectorOf<AnyInteger>,
LLVM_ScalarOrVectorOf<AnyInteger>>;
def LLVM_TruncOp : LLVM_CastOp<"trunc", "CreateTrunc",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
LLVM_ScalarOrVectorOf<AnyInteger>,
LLVM_ScalarOrVectorOf<AnyInteger>>;
def LLVM_SIToFPOp : LLVM_CastOp<"sitofp", "CreateSIToFP",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
LLVM_ScalarOrVectorOf<AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
def LLVM_UIToFPOp : LLVM_CastOp<"uitofp", "CreateUIToFP",
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>,
LLVM_ScalarOrVectorOf<AnyInteger>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
def LLVM_FPToSIOp : LLVM_CastOp<"fptosi", "CreateFPToSI",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
LLVM_ScalarOrVectorOf<AnyInteger>>;
def LLVM_FPToUIOp : LLVM_CastOp<"fptoui", "CreateFPToUI",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
LLVM_ScalarOrVectorOf<LLVM_AnyInteger>>;
LLVM_ScalarOrVectorOf<AnyInteger>>;
def LLVM_FPExtOp : LLVM_CastOp<"fpext", "CreateFPExt",
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>,
LLVM_ScalarOrVectorOf<LLVM_AnyFloat>>;
@ -514,7 +514,7 @@ def LLVM_CallOp : LLVM_Op<"call",
let printer = [{ printCallOp(p, *this); }];
}
def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
let arguments = (ins LLVM_AnyVector:$vector, LLVM_AnyInteger:$position);
let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position);
let results = (outs LLVM_Type:$res);
string llvmBuilder = [{
$res = builder.CreateExtractElement($vector, $position);
@ -537,7 +537,7 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
}
def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value,
LLVM_AnyInteger:$position);
AnyInteger:$position);
let results = (outs LLVM_AnyVector:$res);
string llvmBuilder = [{
$res = builder.CreateInsertElement($vector, $value, $position);
@ -1616,7 +1616,7 @@ def AtomicOrdering : I64EnumAttr<
let cppNamespace = "::mlir::LLVM";
}
def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, LLVM_AnyInteger]>;
def LLVM_AtomicRMWType : AnyTypeOf<[LLVM_AnyFloat, AnyInteger]>;
// FIXME: Need to add alignment attribute to MLIR atomicrmw operation.
def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> {
@ -1634,7 +1634,7 @@ def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> {
let verifier = "return ::verify(*this);";
}
def LLVM_AtomicCmpXchgType : AnyTypeOf<[LLVM_AnyInteger, LLVM_AnyPointer]>;
def LLVM_AtomicCmpXchgType : AnyTypeOf<[AnyInteger, LLVM_AnyPointer]>;
def LLVM_AtomicCmpXchgResultType : Type<And<[
LLVM_AnyStruct.predicate,
CPred<"$_self.cast<::mlir::LLVM::LLVMStructType>().getBody().size() == 2">,

View File

@ -28,9 +28,8 @@ def OpenMP_Dialect : Dialect {
class OpenMP_Op<string mnemonic, list<OpTrait> traits = []> :
Op<OpenMP_Dialect, mnemonic, traits>;
// Type which can be constraint accepting standard integers, indices and
// LLVM integer types.
def IntLikeType : AnyTypeOf<[AnyInteger, Index, LLVM_AnyInteger]>;
// Type which can be constraint accepting standard integers and indices.
def IntLikeType : AnyTypeOf<[AnyInteger, Index]>;
//===----------------------------------------------------------------------===//
// 2.6 parallel Construct

View File

@ -539,7 +539,7 @@ func @nvvm_invalid_mma_7(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
// -----
func @atomicrmw_expected_ptr(%f32 : f32) {
// expected-error@+1 {{operand #0 must be LLVM pointer to floating point LLVM type or LLVM integer type}}
// expected-error@+1 {{operand #0 must be LLVM pointer to floating point LLVM type or integer}}
%0 = "llvm.atomicrmw"(%f32, %f32) {bin_op=11, ordering=1} : (f32, f32) -> f32
llvm.return
}
@ -587,7 +587,7 @@ func @atomicrmw_expected_int(%f32_ptr : !llvm.ptr<f32>, %f32 : f32) {
// -----
func @cmpxchg_expected_ptr(%f32_ptr : !llvm.ptr<f32>, %f32 : f32) {
// expected-error@+1 {{op operand #0 must be LLVM pointer to LLVM integer type or LLVM pointer type}}
// expected-error@+1 {{op operand #0 must be LLVM pointer to integer or LLVM pointer type}}
%0 = "llvm.cmpxchg"(%f32, %f32, %f32) {success_ordering=2,failure_ordering=2} : (f32, f32, f32) -> !llvm.struct<(f32, i1)>
llvm.return
}