forked from OSchip/llvm-project
[mlir][NFC] Update AMX/LLVM/NVVM/X86 vector operations to use `hasVerifier` instead of `verifier`
The verifier field is deprecated, and slated for removal. Differential Revision: https://reviews.llvm.org/D118819
This commit is contained in:
parent
f7d85f010f
commit
38abdddf6f
|
@ -91,7 +91,6 @@ def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> {
|
|||
%0 = amx.tile_zero : vector<16x16xbf16>
|
||||
```
|
||||
}];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let results = (outs
|
||||
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
|
||||
let extraClassDeclaration = [{
|
||||
|
@ -100,6 +99,7 @@ def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> {
|
|||
}
|
||||
}];
|
||||
let assemblyFormat = "attr-dict `:` type($res)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//
|
||||
|
@ -120,7 +120,6 @@ def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> {
|
|||
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
|
||||
```
|
||||
}];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
|
||||
Variadic<Index>:$indices);
|
||||
let results = (outs
|
||||
|
@ -135,6 +134,7 @@ def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> {
|
|||
}];
|
||||
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
|
||||
"type($base) `into` type($res)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def TileStoreOp : AMX_Op<"tile_store"> {
|
||||
|
@ -151,7 +151,6 @@ def TileStoreOp : AMX_Op<"tile_store"> {
|
|||
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
|
||||
```
|
||||
}];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
|
||||
Variadic<Index>:$indices,
|
||||
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
|
||||
|
@ -165,6 +164,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
|
|||
}];
|
||||
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
|
||||
"type($base) `,` type($val)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//
|
||||
|
@ -186,7 +186,6 @@ def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"]
|
|||
: vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
|
||||
```
|
||||
}];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
|
||||
VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
|
||||
VectorOfRankAndType<[2], [F32, BF16]>:$acc);
|
||||
|
@ -204,6 +203,7 @@ def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"]
|
|||
}];
|
||||
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
|
||||
"type($lhs) `,` type($rhs) `,` type($acc) ";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> {
|
||||
|
@ -224,7 +224,6 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
|
|||
: vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
|
||||
```
|
||||
}];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
|
||||
VectorOfRankAndType<[2], [I32, I8]>:$rhs,
|
||||
VectorOfRankAndType<[2], [I32, I8]>:$acc,
|
||||
|
@ -245,6 +244,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
|
|||
}];
|
||||
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
|
||||
"type($lhs) `,` type($rhs) `,` type($acc) ";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -351,9 +351,7 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
|
|||
constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min();
|
||||
}];
|
||||
let hasFolder = 1;
|
||||
let verifier = [{
|
||||
return ::verify(*this);
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
|
||||
|
@ -386,7 +384,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
|
|||
CArg<"bool", "false">:$isNonTemporal)>];
|
||||
let parser = [{ return parseLoadOp(parser, result); }];
|
||||
let printer = [{ printLoadOp(p, *this); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
|
||||
|
@ -410,7 +408,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
|
|||
];
|
||||
let parser = [{ return parseStoreOp(parser, result); }];
|
||||
let printer = [{ printStoreOp(p, *this); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
// Casts.
|
||||
|
@ -494,18 +492,18 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
|
|||
build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps,
|
||||
unwindOps, normal, unwind);
|
||||
}]>];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let parser = [{ return parseInvokeOp(parser, result); }];
|
||||
let printer = [{ printInvokeOp(p, *this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
|
||||
let arguments = (ins UnitAttr:$cleanup, Variadic<LLVM_Type>);
|
||||
let results = (outs LLVM_Type:$res);
|
||||
let builders = [LLVM_OneResultOpBuilder];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let parser = [{ return parseLandingpadOp(parser, result); }];
|
||||
let printer = [{ printLandingpadOp(p, *this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def LLVM_CallOp : LLVM_Op<"call",
|
||||
|
@ -562,9 +560,9 @@ def LLVM_CallOp : LLVM_Op<"call",
|
|||
build($_builder, $_state, results,
|
||||
StringAttr::get($_builder.getContext(), callee), operands);
|
||||
}]>];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let parser = [{ return parseCallOp(parser, result); }];
|
||||
let printer = [{ printCallOp(p, *this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
|
||||
let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position);
|
||||
|
@ -575,9 +573,9 @@ def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
|
|||
let builders = [
|
||||
OpBuilder<(ins "Value":$vector, "Value":$position,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let parser = [{ return parseExtractElementOp(parser, result); }];
|
||||
let printer = [{ printExtractElementOp(p, *this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
|
||||
let arguments = (ins LLVM_AnyAggregate:$container, ArrayAttr:$position);
|
||||
|
@ -586,10 +584,10 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
|
|||
$res = builder.CreateExtractValue($container, extractPosition($position));
|
||||
}];
|
||||
let builders = [LLVM_OneResultOpBuilder];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let parser = [{ return parseExtractValueOp(parser, result); }];
|
||||
let printer = [{ printExtractValueOp(p, *this); }];
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
|
||||
let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value,
|
||||
|
@ -599,9 +597,9 @@ def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
|
|||
$res = builder.CreateInsertElement($vector, $value, $position);
|
||||
}];
|
||||
let builders = [LLVM_OneResultOpBuilder];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let parser = [{ return parseInsertElementOp(parser, result); }];
|
||||
let printer = [{ printInsertElementOp(p, *this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> {
|
||||
let arguments = (ins LLVM_AnyAggregate:$container, LLVM_PrimitiveType:$value,
|
||||
|
@ -616,9 +614,9 @@ def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> {
|
|||
[{
|
||||
build($_builder, $_state, container.getType(), container, value, position);
|
||||
}]>];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let parser = [{ return parseInsertValueOp(parser, result); }];
|
||||
let printer = [{ printInsertValueOp(p, *this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> {
|
||||
let arguments = (ins LLVM_AnyVector:$v1, LLVM_AnyVector:$v2, ArrayAttr:$mask);
|
||||
|
@ -631,16 +629,9 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> {
|
|||
let builders = [
|
||||
OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayAttr":$mask,
|
||||
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
|
||||
let verifier = [{
|
||||
auto type1 = getV1().getType();
|
||||
auto type2 = getV2().getType();
|
||||
if (::mlir::LLVM::getVectorElementType(type1) !=
|
||||
::mlir::LLVM::getVectorElementType(type2))
|
||||
return emitOpError("expected matching LLVM IR Dialect element types");
|
||||
return success();
|
||||
}];
|
||||
let parser = [{ return parseShuffleVectorOp(parser, result); }];
|
||||
let printer = [{ printShuffleVectorOp(p, *this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
// Misc operations.
|
||||
|
@ -718,27 +709,15 @@ def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [NoSideEffect]> {
|
|||
builder.CreateRetVoid();
|
||||
}];
|
||||
|
||||
let verifier = [{
|
||||
if (getNumOperands() > 1)
|
||||
return emitOpError("expects at most 1 operand");
|
||||
return success();
|
||||
}];
|
||||
|
||||
let parser = [{ return parseReturnOp(parser, result); }];
|
||||
let printer = [{ printReturnOp(p, *this); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
def LLVM_ResumeOp : LLVM_TerminatorOp<"resume", []> {
|
||||
let arguments = (ins LLVM_Type:$value);
|
||||
string llvmBuilder = [{ builder.CreateResume($value); }];
|
||||
let verifier = [{
|
||||
if (!isa_and_nonnull<LandingpadOp>(getValue().getDefiningOp()))
|
||||
return emitOpError("expects landingpad value as operand");
|
||||
// No check for personality of function - landingpad op verifies it.
|
||||
return success();
|
||||
}];
|
||||
|
||||
let assemblyFormat = "$value attr-dict `:` type($value)";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
|
||||
string llvmBuilder = [{ builder.CreateUnreachable(); }];
|
||||
|
@ -761,7 +740,6 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
|
|||
VariadicSuccessor<AnySuccessor>:$caseDestinations
|
||||
);
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let assemblyFormat = [{
|
||||
$value `:` type($value) `,`
|
||||
$defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
|
||||
|
@ -769,6 +747,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
|
|||
$caseOperands, type($caseOperands)) `]`
|
||||
attr-dict
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
|
||||
let builders = [
|
||||
OpBuilder<(ins "Value":$value,
|
||||
|
@ -924,7 +903,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof", [NoSideEffect]> {
|
|||
}];
|
||||
|
||||
let assemblyFormat = "$global_name attr-dict `:` type($res)";
|
||||
let verifier = "return ::verify(*this);";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def LLVM_MetadataOp : LLVM_Op<"metadata", [
|
||||
|
@ -1175,7 +1154,7 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global",
|
|||
|
||||
let printer = "printGlobalOp(p, *this);";
|
||||
let parser = "return parseGlobalOp(parser, result);";
|
||||
let verifier = "return ::verify(*this);";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def LLVM_GlobalCtorsOp : LLVM_Op<"mlir.global_ctors", [
|
||||
|
@ -1205,8 +1184,8 @@ def LLVM_GlobalCtorsOp : LLVM_Op<"mlir.global_ctors", [
|
|||
```
|
||||
|
||||
}];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let assemblyFormat = "attr-dict";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [
|
||||
|
@ -1234,8 +1213,8 @@ def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [
|
|||
```
|
||||
|
||||
}];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let assemblyFormat = "attr-dict";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def LLVM_LLVMFuncOp : LLVM_Op<"func", [
|
||||
|
@ -1310,9 +1289,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
|
|||
LogicalResult verifyType();
|
||||
}];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let printer = [{ printLLVMFuncOp(p, *this); }];
|
||||
let parser = [{ return parseLLVMFuncOp(parser, result); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def LLVM_NullOp
|
||||
|
@ -1402,8 +1381,8 @@ def LLVM_ConstantOp
|
|||
let results = (outs LLVM_Type:$res);
|
||||
let builders = [LLVM_OneResultOpBuilder];
|
||||
let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)";
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let hasFolder = 1;
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
// Operations that correspond to LLVM intrinsics. With MLIR operation set being
|
||||
|
@ -1848,7 +1827,7 @@ def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> {
|
|||
}];
|
||||
let parser = [{ return parseAtomicRMWOp(parser, result); }];
|
||||
let printer = [{ printAtomicRMWOp(p, *this); }];
|
||||
let verifier = "return ::verify(*this);";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def LLVM_AtomicCmpXchgType : AnyTypeOf<[AnyInteger, LLVM_AnyPointer]>;
|
||||
|
@ -1878,7 +1857,7 @@ def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg"> {
|
|||
}];
|
||||
let parser = [{ return parseAtomicCmpXchgOp(parser, result); }];
|
||||
let printer = [{ printAtomicCmpXchgOp(p, *this); }];
|
||||
let verifier = "return ::verify(*this);";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def LLVM_AssumeOp : LLVM_Op<"intr.assume", []> {
|
||||
|
@ -1901,7 +1880,7 @@ def LLVM_FenceOp : LLVM_Op<"fence"> {
|
|||
}];
|
||||
let parser = [{ return parseFenceOp(parser, result); }];
|
||||
let printer = [{ printFenceOp(p, *this); }];
|
||||
let verifier = "return ::verify(*this);";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def AsmATT : LLVM_EnumAttrCase<
|
||||
|
|
|
@ -22,12 +22,16 @@
|
|||
|
||||
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc"
|
||||
|
||||
namespace mlir {
|
||||
namespace NVVM {
|
||||
/// Return the element type and number of elements associated with a wmma matrix
|
||||
/// of given chracteristics. This matches the logic in IntrinsicsNVVM.td
|
||||
/// WMMA_REGS structure.
|
||||
std::pair<mlir::Type, unsigned> inferMMAType(mlir::NVVM::MMATypes type,
|
||||
mlir::NVVM::MMAFrag frag,
|
||||
mlir::MLIRContext *context);
|
||||
} // namespace NVVM
|
||||
} // namespace mlir
|
||||
|
||||
///// Ops /////
|
||||
#define GET_ATTRDEF_CLASSES
|
||||
|
|
|
@ -131,22 +131,11 @@ def NVVM_ShflOp :
|
|||
$res = createIntrinsicCall(builder,
|
||||
intId, {$dst, $val, $offset, $mask_and_clamp});
|
||||
}];
|
||||
let verifier = [{
|
||||
if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
|
||||
return success();
|
||||
auto type = getType().dyn_cast<LLVM::LLVMStructType>();
|
||||
auto elementType = (type && type.getBody().size() == 2)
|
||||
? type.getBody()[1].dyn_cast<IntegerType>()
|
||||
: nullptr;
|
||||
if (!elementType || elementType.getWidth() != 1)
|
||||
return emitError("expected return type to be a two-element struct with "
|
||||
"i1 as the second element");
|
||||
return success();
|
||||
}];
|
||||
let assemblyFormat = [{
|
||||
$kind $dst `,` $val `,` $offset `,` $mask_and_clamp attr-dict
|
||||
`:` type($val) `->` type($res)
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def NVVM_VoteBallotOp :
|
||||
|
@ -183,12 +172,8 @@ def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
|
|||
}
|
||||
createIntrinsicCall(builder, id, {$dst, $src});
|
||||
}];
|
||||
let verifier = [{
|
||||
if (size() != 4 && size() != 8 && size() != 16)
|
||||
return emitError("expected byte size to be either 4, 8 or 16.");
|
||||
return success();
|
||||
}];
|
||||
let assemblyFormat = "$dst `,` $src `,` $size attr-dict";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> {
|
||||
|
@ -220,7 +205,7 @@ def NVVM_MmaOp :
|
|||
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args);
|
||||
}];
|
||||
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
/// Helpers to instantiate different version of wmma intrinsics.
|
||||
|
@ -538,7 +523,7 @@ def NVVM_WMMALoadOp: NVVM_Op<"wmma.load">,
|
|||
}];
|
||||
|
||||
let assemblyFormat = "$ptr `,` $stride attr-dict `:` functional-type($ptr, $res)";
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">,
|
||||
|
@ -593,7 +578,7 @@ def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">,
|
|||
}];
|
||||
|
||||
let assemblyFormat = "$ptr `,` $stride `,` $args attr-dict `:` type($ptr) `,` type($args)";
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
// Base class for all the variants of WMMA mmaOps that may be defined.
|
||||
|
@ -647,7 +632,7 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
|
|||
}];
|
||||
|
||||
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
#endif // NVVMIR_OPS
|
||||
|
|
|
@ -76,7 +76,6 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect,
|
|||
with their respective bit set in writemask `k`) to `dst`, and pass through the
|
||||
remaining elements from `src`.
|
||||
}];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
let arguments = (ins VectorOfLengthAndType<[16, 8],
|
||||
[I1]>:$k,
|
||||
VectorOfLengthAndType<[16, 8],
|
||||
|
@ -88,6 +87,7 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect,
|
|||
[F32, I32, F64, I64]>:$dst);
|
||||
let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
|
||||
" `:` type($dst) (`,` type($src)^)?";
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [
|
||||
|
|
|
@ -358,22 +358,19 @@ struct WmmaElementwiseOpToNVVMLowering
|
|||
|
||||
} // namespace
|
||||
|
||||
namespace mlir {
|
||||
|
||||
/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
|
||||
LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
|
||||
LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
|
||||
NVVM::MMAFrag frag = convertOperand(type.getOperand());
|
||||
NVVM::MMATypes eltType = getElementType(type);
|
||||
std::pair<Type, unsigned> typeInfo =
|
||||
inferMMAType(eltType, frag, type.getContext());
|
||||
NVVM::inferMMAType(eltType, frag, type.getContext());
|
||||
return LLVM::LLVMStructType::getLiteral(
|
||||
type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
|
||||
}
|
||||
|
||||
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
|
||||
RewritePatternSet &patterns) {
|
||||
void mlir::populateGpuWMMAToNVVMConversionPatterns(
|
||||
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
|
||||
patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
|
||||
WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
|
||||
WmmaElementwiseOpToNVVMLowering>(converter);
|
||||
}
|
||||
} // namespace mlir
|
||||
|
|
|
@ -52,53 +52,55 @@ static LogicalResult verifyMultShape(Operation *op, VectorType atp,
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(amx::TileZeroOp op) {
|
||||
return verifyTileSize(op, op.getVectorType());
|
||||
LogicalResult amx::TileZeroOp::verify() {
|
||||
return verifyTileSize(*this, getVectorType());
|
||||
}
|
||||
|
||||
static LogicalResult verify(amx::TileLoadOp op) {
|
||||
unsigned rank = op.getMemRefType().getRank();
|
||||
if (llvm::size(op.indices()) != rank)
|
||||
return op.emitOpError("requires ") << rank << " indices";
|
||||
return verifyTileSize(op, op.getVectorType());
|
||||
LogicalResult amx::TileLoadOp::verify() {
|
||||
unsigned rank = getMemRefType().getRank();
|
||||
if (indices().size() != rank)
|
||||
return emitOpError("requires ") << rank << " indices";
|
||||
return verifyTileSize(*this, getVectorType());
|
||||
}
|
||||
|
||||
static LogicalResult verify(amx::TileStoreOp op) {
|
||||
unsigned rank = op.getMemRefType().getRank();
|
||||
if (llvm::size(op.indices()) != rank)
|
||||
return op.emitOpError("requires ") << rank << " indices";
|
||||
return verifyTileSize(op, op.getVectorType());
|
||||
LogicalResult amx::TileStoreOp::verify() {
|
||||
unsigned rank = getMemRefType().getRank();
|
||||
if (indices().size() != rank)
|
||||
return emitOpError("requires ") << rank << " indices";
|
||||
return verifyTileSize(*this, getVectorType());
|
||||
}
|
||||
|
||||
static LogicalResult verify(amx::TileMulFOp op) {
|
||||
VectorType aType = op.getLhsVectorType();
|
||||
VectorType bType = op.getRhsVectorType();
|
||||
VectorType cType = op.getVectorType();
|
||||
if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
|
||||
failed(verifyTileSize(op, cType)) ||
|
||||
failed(verifyMultShape(op, aType, bType, cType, 1)))
|
||||
LogicalResult amx::TileMulFOp::verify() {
|
||||
VectorType aType = getLhsVectorType();
|
||||
VectorType bType = getRhsVectorType();
|
||||
VectorType cType = getVectorType();
|
||||
if (failed(verifyTileSize(*this, aType)) ||
|
||||
failed(verifyTileSize(*this, bType)) ||
|
||||
failed(verifyTileSize(*this, cType)) ||
|
||||
failed(verifyMultShape(*this, aType, bType, cType, 1)))
|
||||
return failure();
|
||||
Type ta = aType.getElementType();
|
||||
Type tb = bType.getElementType();
|
||||
Type tc = cType.getElementType();
|
||||
if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
|
||||
return op.emitOpError("unsupported type combination");
|
||||
return emitOpError("unsupported type combination");
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(amx::TileMulIOp op) {
|
||||
VectorType aType = op.getLhsVectorType();
|
||||
VectorType bType = op.getRhsVectorType();
|
||||
VectorType cType = op.getVectorType();
|
||||
if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
|
||||
failed(verifyTileSize(op, cType)) ||
|
||||
failed(verifyMultShape(op, aType, bType, cType, 2)))
|
||||
LogicalResult amx::TileMulIOp::verify() {
|
||||
VectorType aType = getLhsVectorType();
|
||||
VectorType bType = getRhsVectorType();
|
||||
VectorType cType = getVectorType();
|
||||
if (failed(verifyTileSize(*this, aType)) ||
|
||||
failed(verifyTileSize(*this, bType)) ||
|
||||
failed(verifyTileSize(*this, cType)) ||
|
||||
failed(verifyMultShape(*this, aType, bType, cType, 2)))
|
||||
return failure();
|
||||
Type ta = aType.getElementType();
|
||||
Type tb = bType.getElementType();
|
||||
Type tc = cType.getElementType();
|
||||
if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
|
||||
return op.emitOpError("unsupported type combination");
|
||||
return emitOpError("unsupported type combination");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -334,18 +334,17 @@ static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
|
|||
p.printNewline();
|
||||
}
|
||||
|
||||
static LogicalResult verify(SwitchOp op) {
|
||||
if ((!op.getCaseValues() && !op.getCaseDestinations().empty()) ||
|
||||
(op.getCaseValues() &&
|
||||
op.getCaseValues()->size() !=
|
||||
static_cast<int64_t>(op.getCaseDestinations().size())))
|
||||
return op.emitOpError("expects number of case values to match number of "
|
||||
"case destinations");
|
||||
if (op.getBranchWeights() &&
|
||||
op.getBranchWeights()->size() != op.getNumSuccessors())
|
||||
return op.emitError("expects number of branch weights to match number of "
|
||||
"successors: ")
|
||||
<< op.getBranchWeights()->size() << " vs " << op.getNumSuccessors();
|
||||
LogicalResult SwitchOp::verify() {
|
||||
if ((!getCaseValues() && !getCaseDestinations().empty()) ||
|
||||
(getCaseValues() &&
|
||||
getCaseValues()->size() !=
|
||||
static_cast<int64_t>(getCaseDestinations().size())))
|
||||
return emitOpError("expects number of case values to match number of "
|
||||
"case destinations");
|
||||
if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
|
||||
return emitError("expects number of branch weights to match number of "
|
||||
"successors: ")
|
||||
<< getBranchWeights()->size() << " vs " << getNumSuccessors();
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -518,11 +517,11 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
|
|||
});
|
||||
}
|
||||
|
||||
LogicalResult verify(LLVM::GEPOp gepOp) {
|
||||
LogicalResult LLVM::GEPOp::verify() {
|
||||
SmallVector<unsigned> indices;
|
||||
SmallVector<unsigned> structSizes;
|
||||
findKnownStructIndices(gepOp.getBase().getType(), indices, &structSizes);
|
||||
DenseIntElementsAttr structIndices = gepOp.getStructIndices();
|
||||
findKnownStructIndices(getBase().getType(), indices, &structSizes);
|
||||
DenseIntElementsAttr structIndices = getStructIndices();
|
||||
for (unsigned i : llvm::seq<unsigned>(0, indices.size())) {
|
||||
unsigned index = indices[i];
|
||||
// GEP may not be indexing as deep as some structs nested in the type.
|
||||
|
@ -531,11 +530,11 @@ LogicalResult verify(LLVM::GEPOp gepOp) {
|
|||
|
||||
int32_t staticIndex = structIndices.getValues<int32_t>()[index];
|
||||
if (staticIndex == LLVM::GEPOp::kDynamicIndex)
|
||||
return gepOp.emitOpError() << "expected index " << index
|
||||
<< " indexing a struct to be constant";
|
||||
return emitOpError() << "expected index " << index
|
||||
<< " indexing a struct to be constant";
|
||||
if (staticIndex < 0 || static_cast<unsigned>(staticIndex) >= structSizes[i])
|
||||
return gepOp.emitOpError()
|
||||
<< "index " << index << " indexing a struct is out of bounds";
|
||||
return emitOpError() << "index " << index
|
||||
<< " indexing a struct is out of bounds";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
@ -613,9 +612,7 @@ static LogicalResult verifyMemoryOpMetadata(Operation *op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(LoadOp op) {
|
||||
return verifyMemoryOpMetadata(op.getOperation());
|
||||
}
|
||||
LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); }
|
||||
|
||||
void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
|
||||
Value addr, unsigned alignment, bool isVolatile,
|
||||
|
@ -675,9 +672,7 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
|
|||
// Builder, printer and parser for LLVM::StoreOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(StoreOp op) {
|
||||
return verifyMemoryOpMetadata(op.getOperation());
|
||||
}
|
||||
LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); }
|
||||
|
||||
void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
|
||||
Value addr, unsigned alignment, bool isVolatile,
|
||||
|
@ -739,19 +734,18 @@ InvokeOp::getMutableSuccessorOperands(unsigned index) {
|
|||
: getUnwindDestOperandsMutable();
|
||||
}
|
||||
|
||||
static LogicalResult verify(InvokeOp op) {
|
||||
if (op.getNumResults() > 1)
|
||||
return op.emitOpError("must have 0 or 1 result");
|
||||
LogicalResult InvokeOp::verify() {
|
||||
if (getNumResults() > 1)
|
||||
return emitOpError("must have 0 or 1 result");
|
||||
|
||||
Block *unwindDest = op.getUnwindDest();
|
||||
Block *unwindDest = getUnwindDest();
|
||||
if (unwindDest->empty())
|
||||
return op.emitError(
|
||||
"must have at least one operation in unwind destination");
|
||||
return emitError("must have at least one operation in unwind destination");
|
||||
|
||||
// In unwind destination, first operation must be LandingpadOp
|
||||
if (!isa<LandingpadOp>(unwindDest->front()))
|
||||
return op.emitError("first operation in unwind destination should be a "
|
||||
"llvm.landingpad operation");
|
||||
return emitError("first operation in unwind destination should be a "
|
||||
"llvm.landingpad operation");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -880,20 +874,20 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
|
|||
/// Verifying/Printing/Parsing for LLVM::LandingpadOp.
|
||||
///===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(LandingpadOp op) {
|
||||
LogicalResult LandingpadOp::verify() {
|
||||
Value value;
|
||||
if (LLVMFuncOp func = op->getParentOfType<LLVMFuncOp>()) {
|
||||
if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
|
||||
if (!func.getPersonality().hasValue())
|
||||
return op.emitError(
|
||||
return emitError(
|
||||
"llvm.landingpad needs to be in a function with a personality");
|
||||
}
|
||||
|
||||
if (!op.getCleanup() && op.getOperands().empty())
|
||||
return op.emitError("landingpad instruction expects at least one clause or "
|
||||
"cleanup attribute");
|
||||
if (!getCleanup() && getOperands().empty())
|
||||
return emitError("landingpad instruction expects at least one clause or "
|
||||
"cleanup attribute");
|
||||
|
||||
for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
|
||||
value = op.getOperand(idx);
|
||||
for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
|
||||
value = getOperand(idx);
|
||||
bool isFilter = value.getType().isa<LLVMArrayType>();
|
||||
if (isFilter) {
|
||||
// FIXME: Verify filter clauses when arrays are appropriately handled
|
||||
|
@ -903,8 +897,7 @@ static LogicalResult verify(LandingpadOp op) {
|
|||
if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
|
||||
if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
|
||||
continue;
|
||||
return op.emitError("constant clauses expected")
|
||||
.attachNote(bcOp.getLoc())
|
||||
return emitError("constant clauses expected").attachNote(bcOp.getLoc())
|
||||
<< "global addresses expected as operand to "
|
||||
"bitcast used in clauses for landingpad";
|
||||
}
|
||||
|
@ -913,7 +906,7 @@ static LogicalResult verify(LandingpadOp op) {
|
|||
continue;
|
||||
if (value.getDefiningOp<AddressOfOp>())
|
||||
continue;
|
||||
return op.emitError("clause #")
|
||||
return emitError("clause #")
|
||||
<< idx << " is not a known constant - null, addressof, bitcast";
|
||||
}
|
||||
}
|
||||
|
@ -970,9 +963,9 @@ static ParseResult parseLandingpadOp(OpAsmParser &parser,
|
|||
// Verifying/Printing/parsing for LLVM::CallOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(CallOp &op) {
|
||||
if (op.getNumResults() > 1)
|
||||
return op.emitOpError("must have 0 or 1 result");
|
||||
LogicalResult CallOp::verify() {
|
||||
if (getNumResults() > 1)
|
||||
return emitOpError("must have 0 or 1 result");
|
||||
|
||||
// Type for the callee, we'll get it differently depending if it is a direct
|
||||
// or indirect call.
|
||||
|
@ -981,75 +974,73 @@ static LogicalResult verify(CallOp &op) {
|
|||
bool isIndirect = false;
|
||||
|
||||
// If this is an indirect call, the callee attribute is missing.
|
||||
FlatSymbolRefAttr calleeName = op.getCalleeAttr();
|
||||
FlatSymbolRefAttr calleeName = getCalleeAttr();
|
||||
if (!calleeName) {
|
||||
isIndirect = true;
|
||||
if (!op.getNumOperands())
|
||||
return op.emitOpError(
|
||||
if (!getNumOperands())
|
||||
return emitOpError(
|
||||
"must have either a `callee` attribute or at least an operand");
|
||||
auto ptrType = op.getOperand(0).getType().dyn_cast<LLVMPointerType>();
|
||||
auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>();
|
||||
if (!ptrType)
|
||||
return op.emitOpError("indirect call expects a pointer as callee: ")
|
||||
return emitOpError("indirect call expects a pointer as callee: ")
|
||||
<< ptrType;
|
||||
fnType = ptrType.getElementType();
|
||||
} else {
|
||||
Operation *callee =
|
||||
SymbolTable::lookupNearestSymbolFrom(op, calleeName.getAttr());
|
||||
SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr());
|
||||
if (!callee)
|
||||
return op.emitOpError()
|
||||
return emitOpError()
|
||||
<< "'" << calleeName.getValue()
|
||||
<< "' does not reference a symbol in the current scope";
|
||||
auto fn = dyn_cast<LLVMFuncOp>(callee);
|
||||
if (!fn)
|
||||
return op.emitOpError() << "'" << calleeName.getValue()
|
||||
<< "' does not reference a valid LLVM function";
|
||||
return emitOpError() << "'" << calleeName.getValue()
|
||||
<< "' does not reference a valid LLVM function";
|
||||
|
||||
fnType = fn.getType();
|
||||
}
|
||||
|
||||
LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
|
||||
if (!funcType)
|
||||
return op.emitOpError("callee does not have a functional type: ") << fnType;
|
||||
return emitOpError("callee does not have a functional type: ") << fnType;
|
||||
|
||||
// Verify that the operand and result types match the callee.
|
||||
|
||||
if (!funcType.isVarArg() &&
|
||||
funcType.getNumParams() != (op.getNumOperands() - isIndirect))
|
||||
return op.emitOpError()
|
||||
<< "incorrect number of operands ("
|
||||
<< (op.getNumOperands() - isIndirect)
|
||||
<< ") for callee (expecting: " << funcType.getNumParams() << ")";
|
||||
funcType.getNumParams() != (getNumOperands() - isIndirect))
|
||||
return emitOpError() << "incorrect number of operands ("
|
||||
<< (getNumOperands() - isIndirect)
|
||||
<< ") for callee (expecting: "
|
||||
<< funcType.getNumParams() << ")";
|
||||
|
||||
if (funcType.getNumParams() > (op.getNumOperands() - isIndirect))
|
||||
return op.emitOpError() << "incorrect number of operands ("
|
||||
<< (op.getNumOperands() - isIndirect)
|
||||
<< ") for varargs callee (expecting at least: "
|
||||
<< funcType.getNumParams() << ")";
|
||||
if (funcType.getNumParams() > (getNumOperands() - isIndirect))
|
||||
return emitOpError() << "incorrect number of operands ("
|
||||
<< (getNumOperands() - isIndirect)
|
||||
<< ") for varargs callee (expecting at least: "
|
||||
<< funcType.getNumParams() << ")";
|
||||
|
||||
for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
|
||||
if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i))
|
||||
return op.emitOpError() << "operand type mismatch for operand " << i
|
||||
<< ": " << op.getOperand(i + isIndirect).getType()
|
||||
<< " != " << funcType.getParamType(i);
|
||||
if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
|
||||
return emitOpError() << "operand type mismatch for operand " << i << ": "
|
||||
<< getOperand(i + isIndirect).getType()
|
||||
<< " != " << funcType.getParamType(i);
|
||||
|
||||
if (op.getNumResults() == 0 &&
|
||||
if (getNumResults() == 0 &&
|
||||
!funcType.getReturnType().isa<LLVM::LLVMVoidType>())
|
||||
return op.emitOpError() << "expected function call to produce a value";
|
||||
return emitOpError() << "expected function call to produce a value";
|
||||
|
||||
if (op.getNumResults() != 0 &&
|
||||
if (getNumResults() != 0 &&
|
||||
funcType.getReturnType().isa<LLVM::LLVMVoidType>())
|
||||
return op.emitOpError()
|
||||
return emitOpError()
|
||||
<< "calling function with void result must not produce values";
|
||||
|
||||
if (op.getNumResults() > 1)
|
||||
return op.emitOpError()
|
||||
if (getNumResults() > 1)
|
||||
return emitOpError()
|
||||
<< "expected LLVM function call to produce 0 or 1 result";
|
||||
|
||||
if (op.getNumResults() &&
|
||||
op.getResult(0).getType() != funcType.getReturnType())
|
||||
return op.emitOpError()
|
||||
<< "result type mismatch: " << op.getResult(0).getType()
|
||||
<< " != " << funcType.getReturnType();
|
||||
if (getNumResults() && getResult(0).getType() != funcType.getReturnType())
|
||||
return emitOpError() << "result type mismatch: " << getResult(0).getType()
|
||||
<< " != " << funcType.getReturnType();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -1200,17 +1191,17 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(ExtractElementOp op) {
|
||||
Type vectorType = op.getVector().getType();
|
||||
LogicalResult ExtractElementOp::verify() {
|
||||
Type vectorType = getVector().getType();
|
||||
if (!LLVM::isCompatibleVectorType(vectorType))
|
||||
return op->emitOpError("expected LLVM dialect-compatible vector type for "
|
||||
"operand #1, got")
|
||||
return emitOpError("expected LLVM dialect-compatible vector type for "
|
||||
"operand #1, got")
|
||||
<< vectorType;
|
||||
Type valueType = LLVM::getVectorElementType(vectorType);
|
||||
if (valueType != op.getRes().getType())
|
||||
return op.emitOpError() << "Type mismatch: extracting from " << vectorType
|
||||
<< " should produce " << valueType
|
||||
<< " but this op returns " << op.getRes().getType();
|
||||
if (valueType != getRes().getType())
|
||||
return emitOpError() << "Type mismatch: extracting from " << vectorType
|
||||
<< " should produce " << valueType
|
||||
<< " but this op returns " << getRes().getType();
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1367,17 +1358,17 @@ OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
}
|
||||
|
||||
static LogicalResult verify(ExtractValueOp op) {
|
||||
Type valueType = getInsertExtractValueElementType(op.getContainer().getType(),
|
||||
op.getPositionAttr(), op);
|
||||
LogicalResult ExtractValueOp::verify() {
|
||||
Type valueType = getInsertExtractValueElementType(getContainer().getType(),
|
||||
getPositionAttr(), *this);
|
||||
if (!valueType)
|
||||
return failure();
|
||||
|
||||
if (op.getRes().getType() != valueType)
|
||||
return op.emitOpError()
|
||||
<< "Type mismatch: extracting from " << op.getContainer().getType()
|
||||
<< " should produce " << valueType << " but this op returns "
|
||||
<< op.getRes().getType();
|
||||
if (getRes().getType() != valueType)
|
||||
return emitOpError() << "Type mismatch: extracting from "
|
||||
<< getContainer().getType() << " should produce "
|
||||
<< valueType << " but this op returns "
|
||||
<< getRes().getType();
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -1423,14 +1414,15 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(InsertElementOp op) {
|
||||
Type valueType = LLVM::getVectorElementType(op.getVector().getType());
|
||||
if (valueType != op.getValue().getType())
|
||||
return op.emitOpError()
|
||||
<< "Type mismatch: cannot insert " << op.getValue().getType()
|
||||
<< " into " << op.getVector().getType();
|
||||
LogicalResult InsertElementOp::verify() {
|
||||
Type valueType = LLVM::getVectorElementType(getVector().getType());
|
||||
if (valueType != getValue().getType())
|
||||
return emitOpError() << "Type mismatch: cannot insert "
|
||||
<< getValue().getType() << " into "
|
||||
<< getVector().getType();
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Printing/parsing for LLVM::InsertValueOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1473,16 +1465,16 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(InsertValueOp op) {
|
||||
Type valueType = getInsertExtractValueElementType(op.getContainer().getType(),
|
||||
op.getPositionAttr(), op);
|
||||
LogicalResult InsertValueOp::verify() {
|
||||
Type valueType = getInsertExtractValueElementType(getContainer().getType(),
|
||||
getPositionAttr(), *this);
|
||||
if (!valueType)
|
||||
return failure();
|
||||
|
||||
if (op.getValue().getType() != valueType)
|
||||
return op.emitOpError()
|
||||
<< "Type mismatch: cannot insert " << op.getValue().getType()
|
||||
<< " into " << op.getContainer().getType();
|
||||
if (getValue().getType() != valueType)
|
||||
return emitOpError() << "Type mismatch: cannot insert "
|
||||
<< getValue().getType() << " into "
|
||||
<< getContainer().getType();
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -1519,28 +1511,28 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(ReturnOp op) {
|
||||
if (op->getNumOperands() > 1)
|
||||
return op->emitOpError("expected at most 1 operand");
|
||||
LogicalResult ReturnOp::verify() {
|
||||
if (getNumOperands() > 1)
|
||||
return emitOpError("expected at most 1 operand");
|
||||
|
||||
if (auto parent = op->getParentOfType<LLVMFuncOp>()) {
|
||||
if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) {
|
||||
Type expectedType = parent.getType().getReturnType();
|
||||
if (expectedType.isa<LLVMVoidType>()) {
|
||||
if (op->getNumOperands() == 0)
|
||||
if (getNumOperands() == 0)
|
||||
return success();
|
||||
InFlightDiagnostic diag = op->emitOpError("expected no operands");
|
||||
InFlightDiagnostic diag = emitOpError("expected no operands");
|
||||
diag.attachNote(parent->getLoc()) << "when returning from function";
|
||||
return diag;
|
||||
}
|
||||
if (op->getNumOperands() == 0) {
|
||||
if (getNumOperands() == 0) {
|
||||
if (expectedType.isa<LLVMVoidType>())
|
||||
return success();
|
||||
InFlightDiagnostic diag = op->emitOpError("expected 1 operand");
|
||||
InFlightDiagnostic diag = emitOpError("expected 1 operand");
|
||||
diag.attachNote(parent->getLoc()) << "when returning from function";
|
||||
return diag;
|
||||
}
|
||||
if (expectedType != op->getOperand(0).getType()) {
|
||||
InFlightDiagnostic diag = op->emitOpError("mismatching result types");
|
||||
if (expectedType != getOperand(0).getType()) {
|
||||
InFlightDiagnostic diag = emitOpError("mismatching result types");
|
||||
diag.attachNote(parent->getLoc()) << "when returning from function";
|
||||
return diag;
|
||||
}
|
||||
|
@ -1548,6 +1540,17 @@ static LogicalResult verify(ReturnOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResumeOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ResumeOp::verify() {
|
||||
if (!getValue().getDefiningOp<LandingpadOp>())
|
||||
return emitOpError("expects landingpad value as operand");
|
||||
// No check for personality of function - landingpad op verifies it.
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Verifier for LLVM::AddressOfOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -1572,22 +1575,22 @@ LLVMFuncOp AddressOfOp::getFunction() {
|
|||
getGlobalName());
|
||||
}
|
||||
|
||||
static LogicalResult verify(AddressOfOp op) {
|
||||
auto global = op.getGlobal();
|
||||
auto function = op.getFunction();
|
||||
LogicalResult AddressOfOp::verify() {
|
||||
auto global = getGlobal();
|
||||
auto function = getFunction();
|
||||
if (!global && !function)
|
||||
return op.emitOpError(
|
||||
return emitOpError(
|
||||
"must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
|
||||
|
||||
if (global &&
|
||||
LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()) !=
|
||||
op.getResult().getType())
|
||||
return op.emitOpError(
|
||||
getResult().getType())
|
||||
return emitOpError(
|
||||
"the type must be a pointer to the type of the referenced global");
|
||||
|
||||
if (function && LLVM::LLVMPointerType::get(function.getType()) !=
|
||||
op.getResult().getType())
|
||||
return op.emitOpError(
|
||||
if (function &&
|
||||
LLVM::LLVMPointerType::get(function.getType()) != getResult().getType())
|
||||
return emitOpError(
|
||||
"the type must be a pointer to the type of the referenced function");
|
||||
|
||||
return success();
|
||||
|
@ -1791,60 +1794,60 @@ static bool isZeroAttribute(Attribute value) {
|
|||
return false;
|
||||
}
|
||||
|
||||
static LogicalResult verify(GlobalOp op) {
|
||||
if (!LLVMPointerType::isValidElementType(op.getType()))
|
||||
return op.emitOpError(
|
||||
LogicalResult GlobalOp::verify() {
|
||||
if (!LLVMPointerType::isValidElementType(getType()))
|
||||
return emitOpError(
|
||||
"expects type to be a valid element type for an LLVM pointer");
|
||||
if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp()))
|
||||
return op.emitOpError("must appear at the module level");
|
||||
if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
|
||||
return emitOpError("must appear at the module level");
|
||||
|
||||
if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
|
||||
auto type = op.getType().dyn_cast<LLVMArrayType>();
|
||||
if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) {
|
||||
auto type = getType().dyn_cast<LLVMArrayType>();
|
||||
IntegerType elementType =
|
||||
type ? type.getElementType().dyn_cast<IntegerType>() : nullptr;
|
||||
if (!elementType || elementType.getWidth() != 8 ||
|
||||
type.getNumElements() != strAttr.getValue().size())
|
||||
return op.emitOpError(
|
||||
return emitOpError(
|
||||
"requires an i8 array type of the length equal to that of the string "
|
||||
"attribute");
|
||||
}
|
||||
|
||||
if (Block *b = op.getInitializerBlock()) {
|
||||
if (Block *b = getInitializerBlock()) {
|
||||
ReturnOp ret = cast<ReturnOp>(b->getTerminator());
|
||||
if (ret.operand_type_begin() == ret.operand_type_end())
|
||||
return op.emitOpError("initializer region cannot return void");
|
||||
if (*ret.operand_type_begin() != op.getType())
|
||||
return op.emitOpError("initializer region type ")
|
||||
return emitOpError("initializer region cannot return void");
|
||||
if (*ret.operand_type_begin() != getType())
|
||||
return emitOpError("initializer region type ")
|
||||
<< *ret.operand_type_begin() << " does not match global type "
|
||||
<< op.getType();
|
||||
<< getType();
|
||||
|
||||
if (op.getValueOrNull())
|
||||
return op.emitOpError("cannot have both initializer value and region");
|
||||
if (getValueOrNull())
|
||||
return emitOpError("cannot have both initializer value and region");
|
||||
}
|
||||
|
||||
if (op.getLinkage() == Linkage::Common) {
|
||||
if (Attribute value = op.getValueOrNull()) {
|
||||
if (getLinkage() == Linkage::Common) {
|
||||
if (Attribute value = getValueOrNull()) {
|
||||
if (!isZeroAttribute(value)) {
|
||||
return op.emitOpError()
|
||||
return emitOpError()
|
||||
<< "expected zero value for '"
|
||||
<< stringifyLinkage(Linkage::Common) << "' linkage";
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if (op.getLinkage() == Linkage::Appending) {
|
||||
if (!op.getType().isa<LLVMArrayType>()) {
|
||||
return op.emitOpError()
|
||||
<< "expected array type for '"
|
||||
<< stringifyLinkage(Linkage::Appending) << "' linkage";
|
||||
if (getLinkage() == Linkage::Appending) {
|
||||
if (!getType().isa<LLVMArrayType>()) {
|
||||
return emitOpError() << "expected array type for '"
|
||||
<< stringifyLinkage(Linkage::Appending)
|
||||
<< "' linkage";
|
||||
}
|
||||
}
|
||||
|
||||
Optional<uint64_t> alignAttr = op.getAlignment();
|
||||
Optional<uint64_t> alignAttr = getAlignment();
|
||||
if (alignAttr.hasValue()) {
|
||||
uint64_t value = alignAttr.getValue();
|
||||
if (!llvm::isPowerOf2_64(value))
|
||||
return op->emitError() << "alignment attribute is not a power of 2";
|
||||
return emitError() << "alignment attribute is not a power of 2";
|
||||
}
|
||||
|
||||
return success();
|
||||
|
@ -1864,9 +1867,9 @@ GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(GlobalCtorsOp op) {
|
||||
if (op.getCtors().size() != op.getPriorities().size())
|
||||
return op.emitError(
|
||||
LogicalResult GlobalCtorsOp::verify() {
|
||||
if (getCtors().size() != getPriorities().size())
|
||||
return emitError(
|
||||
"mismatch between the number of ctors and the number of priorities");
|
||||
return success();
|
||||
}
|
||||
|
@ -1885,9 +1888,9 @@ GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(GlobalDtorsOp op) {
|
||||
if (op.getDtors().size() != op.getPriorities().size())
|
||||
return op.emitError(
|
||||
LogicalResult GlobalDtorsOp::verify() {
|
||||
if (getDtors().size() != getPriorities().size())
|
||||
return emitError(
|
||||
"mismatch between the number of dtors and the number of priorities");
|
||||
return success();
|
||||
}
|
||||
|
@ -1940,6 +1943,14 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult ShuffleVectorOp::verify() {
|
||||
Type type1 = getV1().getType();
|
||||
Type type2 = getV2().getType();
|
||||
if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2))
|
||||
return emitOpError("expected matching LLVM IR Dialect element types");
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Implementations for LLVM::LLVMFuncOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2117,42 +2128,43 @@ LogicalResult LLVMFuncOp::verifyType() {
|
|||
// - external functions have 'external' or 'extern_weak' linkage;
|
||||
// - vararg is (currently) only supported for external functions;
|
||||
// - entry block arguments are of LLVM types and match the function signature.
|
||||
static LogicalResult verify(LLVMFuncOp op) {
|
||||
if (op.getLinkage() == LLVM::Linkage::Common)
|
||||
return op.emitOpError()
|
||||
<< "functions cannot have '"
|
||||
<< stringifyLinkage(LLVM::Linkage::Common) << "' linkage";
|
||||
LogicalResult LLVMFuncOp::verify() {
|
||||
if (getLinkage() == LLVM::Linkage::Common)
|
||||
return emitOpError() << "functions cannot have '"
|
||||
<< stringifyLinkage(LLVM::Linkage::Common)
|
||||
<< "' linkage";
|
||||
|
||||
// Check to see if this function has a void return with a result attribute to
|
||||
// it. It isn't clear what semantics we would assign to that.
|
||||
if (op.getType().getReturnType().isa<LLVMVoidType>() &&
|
||||
!op.getResultAttrs(0).empty()) {
|
||||
return op.emitOpError()
|
||||
if (getType().getReturnType().isa<LLVMVoidType>() &&
|
||||
!getResultAttrs(0).empty()) {
|
||||
return emitOpError()
|
||||
<< "cannot attach result attributes to functions with a void return";
|
||||
}
|
||||
|
||||
if (op.isExternal()) {
|
||||
if (op.getLinkage() != LLVM::Linkage::External &&
|
||||
op.getLinkage() != LLVM::Linkage::ExternWeak)
|
||||
return op.emitOpError()
|
||||
<< "external functions must have '"
|
||||
<< stringifyLinkage(LLVM::Linkage::External) << "' or '"
|
||||
<< stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage";
|
||||
if (isExternal()) {
|
||||
if (getLinkage() != LLVM::Linkage::External &&
|
||||
getLinkage() != LLVM::Linkage::ExternWeak)
|
||||
return emitOpError() << "external functions must have '"
|
||||
<< stringifyLinkage(LLVM::Linkage::External)
|
||||
<< "' or '"
|
||||
<< stringifyLinkage(LLVM::Linkage::ExternWeak)
|
||||
<< "' linkage";
|
||||
return success();
|
||||
}
|
||||
|
||||
if (op.isVarArg())
|
||||
return op.emitOpError("only external functions can be variadic");
|
||||
if (isVarArg())
|
||||
return emitOpError("only external functions can be variadic");
|
||||
|
||||
unsigned numArguments = op.getType().getNumParams();
|
||||
Block &entryBlock = op.front();
|
||||
unsigned numArguments = getType().getNumParams();
|
||||
Block &entryBlock = front();
|
||||
for (unsigned i = 0; i < numArguments; ++i) {
|
||||
Type argType = entryBlock.getArgument(i).getType();
|
||||
if (!isCompatibleType(argType))
|
||||
return op.emitOpError("entry block argument #")
|
||||
return emitOpError("entry block argument #")
|
||||
<< i << " is not of LLVM type";
|
||||
if (op.getType().getParamType(i) != argType)
|
||||
return op.emitOpError("the type of entry block argument #")
|
||||
if (getType().getParamType(i) != argType)
|
||||
return emitOpError("the type of entry block argument #")
|
||||
<< i << " does not match the function signature";
|
||||
}
|
||||
|
||||
|
@ -2163,42 +2175,42 @@ static LogicalResult verify(LLVMFuncOp op) {
|
|||
// Verification for LLVM::ConstantOp.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(LLVM::ConstantOp op) {
|
||||
if (StringAttr sAttr = op.getValue().dyn_cast<StringAttr>()) {
|
||||
auto arrayType = op.getType().dyn_cast<LLVMArrayType>();
|
||||
LogicalResult LLVM::ConstantOp::verify() {
|
||||
if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) {
|
||||
auto arrayType = getType().dyn_cast<LLVMArrayType>();
|
||||
if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
|
||||
!arrayType.getElementType().isInteger(8)) {
|
||||
return op->emitOpError()
|
||||
<< "expected array type of " << sAttr.getValue().size()
|
||||
<< " i8 elements for the string constant";
|
||||
return emitOpError() << "expected array type of "
|
||||
<< sAttr.getValue().size()
|
||||
<< " i8 elements for the string constant";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
if (auto structType = op.getType().dyn_cast<LLVMStructType>()) {
|
||||
if (auto structType = getType().dyn_cast<LLVMStructType>()) {
|
||||
if (structType.getBody().size() != 2 ||
|
||||
structType.getBody()[0] != structType.getBody()[1]) {
|
||||
return op.emitError() << "expected struct type with two elements of the "
|
||||
"same type, the type of a complex constant";
|
||||
return emitError() << "expected struct type with two elements of the "
|
||||
"same type, the type of a complex constant";
|
||||
}
|
||||
|
||||
auto arrayAttr = op.getValue().dyn_cast<ArrayAttr>();
|
||||
auto arrayAttr = getValue().dyn_cast<ArrayAttr>();
|
||||
if (!arrayAttr || arrayAttr.size() != 2 ||
|
||||
arrayAttr[0].getType() != arrayAttr[1].getType()) {
|
||||
return op.emitOpError() << "expected array attribute with two elements, "
|
||||
"representing a complex constant";
|
||||
return emitOpError() << "expected array attribute with two elements, "
|
||||
"representing a complex constant";
|
||||
}
|
||||
|
||||
Type elementType = structType.getBody()[0];
|
||||
if (!elementType
|
||||
.isa<IntegerType, Float16Type, Float32Type, Float64Type>()) {
|
||||
return op.emitError()
|
||||
return emitError()
|
||||
<< "expected struct element types to be floating point type or "
|
||||
"integer type";
|
||||
}
|
||||
return success();
|
||||
}
|
||||
if (!op.getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
|
||||
return op.emitOpError()
|
||||
if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
|
||||
return emitOpError()
|
||||
<< "only supports integer, float, string or elements attributes";
|
||||
return success();
|
||||
}
|
||||
|
@ -2294,42 +2306,40 @@ static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(AtomicRMWOp op) {
|
||||
auto ptrType = op.getPtr().getType().cast<LLVM::LLVMPointerType>();
|
||||
auto valType = op.getVal().getType();
|
||||
LogicalResult AtomicRMWOp::verify() {
|
||||
auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
|
||||
auto valType = getVal().getType();
|
||||
if (valType != ptrType.getElementType())
|
||||
return op.emitOpError("expected LLVM IR element type for operand #0 to "
|
||||
"match type for operand #1");
|
||||
auto resType = op.getRes().getType();
|
||||
return emitOpError("expected LLVM IR element type for operand #0 to "
|
||||
"match type for operand #1");
|
||||
auto resType = getRes().getType();
|
||||
if (resType != valType)
|
||||
return op.emitOpError(
|
||||
return emitOpError(
|
||||
"expected LLVM IR result type to match type for operand #1");
|
||||
if (op.getBinOp() == AtomicBinOp::fadd ||
|
||||
op.getBinOp() == AtomicBinOp::fsub) {
|
||||
if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) {
|
||||
if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
|
||||
return op.emitOpError("expected LLVM IR floating point type");
|
||||
} else if (op.getBinOp() == AtomicBinOp::xchg) {
|
||||
return emitOpError("expected LLVM IR floating point type");
|
||||
} else if (getBinOp() == AtomicBinOp::xchg) {
|
||||
auto intType = valType.dyn_cast<IntegerType>();
|
||||
unsigned intBitWidth = intType ? intType.getWidth() : 0;
|
||||
if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
|
||||
intBitWidth != 64 && !valType.isa<BFloat16Type>() &&
|
||||
!valType.isa<Float16Type>() && !valType.isa<Float32Type>() &&
|
||||
!valType.isa<Float64Type>())
|
||||
return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
|
||||
return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
|
||||
} else {
|
||||
auto intType = valType.dyn_cast<IntegerType>();
|
||||
unsigned intBitWidth = intType ? intType.getWidth() : 0;
|
||||
if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
|
||||
intBitWidth != 64)
|
||||
return op.emitOpError("expected LLVM IR integer type");
|
||||
return emitOpError("expected LLVM IR integer type");
|
||||
}
|
||||
|
||||
if (static_cast<unsigned>(op.getOrdering()) <
|
||||
if (static_cast<unsigned>(getOrdering()) <
|
||||
static_cast<unsigned>(AtomicOrdering::monotonic))
|
||||
return op.emitOpError()
|
||||
<< "expected at least '"
|
||||
<< stringifyAtomicOrdering(AtomicOrdering::monotonic)
|
||||
<< "' ordering";
|
||||
return emitOpError() << "expected at least '"
|
||||
<< stringifyAtomicOrdering(AtomicOrdering::monotonic)
|
||||
<< "' ordering";
|
||||
|
||||
return success();
|
||||
}
|
||||
|
@ -2375,28 +2385,28 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(AtomicCmpXchgOp op) {
|
||||
auto ptrType = op.getPtr().getType().cast<LLVM::LLVMPointerType>();
|
||||
LogicalResult AtomicCmpXchgOp::verify() {
|
||||
auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
|
||||
if (!ptrType)
|
||||
return op.emitOpError("expected LLVM IR pointer type for operand #0");
|
||||
auto cmpType = op.getCmp().getType();
|
||||
auto valType = op.getVal().getType();
|
||||
return emitOpError("expected LLVM IR pointer type for operand #0");
|
||||
auto cmpType = getCmp().getType();
|
||||
auto valType = getVal().getType();
|
||||
if (cmpType != ptrType.getElementType() || cmpType != valType)
|
||||
return op.emitOpError("expected LLVM IR element type for operand #0 to "
|
||||
"match type for all other operands");
|
||||
return emitOpError("expected LLVM IR element type for operand #0 to "
|
||||
"match type for all other operands");
|
||||
auto intType = valType.dyn_cast<IntegerType>();
|
||||
unsigned intBitWidth = intType ? intType.getWidth() : 0;
|
||||
if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
|
||||
intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
|
||||
!valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() &&
|
||||
!valType.isa<Float32Type>() && !valType.isa<Float64Type>())
|
||||
return op.emitOpError("unexpected LLVM IR type");
|
||||
if (op.getSuccessOrdering() < AtomicOrdering::monotonic ||
|
||||
op.getFailureOrdering() < AtomicOrdering::monotonic)
|
||||
return op.emitOpError("ordering must be at least 'monotonic'");
|
||||
if (op.getFailureOrdering() == AtomicOrdering::release ||
|
||||
op.getFailureOrdering() == AtomicOrdering::acq_rel)
|
||||
return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
|
||||
return emitOpError("unexpected LLVM IR type");
|
||||
if (getSuccessOrdering() < AtomicOrdering::monotonic ||
|
||||
getFailureOrdering() < AtomicOrdering::monotonic)
|
||||
return emitOpError("ordering must be at least 'monotonic'");
|
||||
if (getFailureOrdering() == AtomicOrdering::release ||
|
||||
getFailureOrdering() == AtomicOrdering::acq_rel)
|
||||
return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -2432,12 +2442,12 @@ static void printFenceOp(OpAsmPrinter &p, FenceOp &op) {
|
|||
p << stringifyAtomicOrdering(op.getOrdering());
|
||||
}
|
||||
|
||||
static LogicalResult verify(FenceOp &op) {
|
||||
if (op.getOrdering() == AtomicOrdering::not_atomic ||
|
||||
op.getOrdering() == AtomicOrdering::unordered ||
|
||||
op.getOrdering() == AtomicOrdering::monotonic)
|
||||
return op.emitOpError("can be given only acquire, release, acq_rel, "
|
||||
"and seq_cst orderings");
|
||||
LogicalResult FenceOp::verify() {
|
||||
if (getOrdering() == AtomicOrdering::not_atomic ||
|
||||
getOrdering() == AtomicOrdering::unordered ||
|
||||
getOrdering() == AtomicOrdering::monotonic)
|
||||
return emitOpError("can be given only acquire, release, acq_rel, "
|
||||
"and seq_cst orderings");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
|
|
@ -62,8 +62,14 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
|
|||
parser.getNameLoc(), result.operands));
|
||||
}
|
||||
|
||||
static LogicalResult verify(MmaOp op) {
|
||||
MLIRContext *context = op.getContext();
|
||||
LogicalResult CpAsyncOp::verify() {
|
||||
if (size() != 4 && size() != 8 && size() != 16)
|
||||
return emitError("expected byte size to be either 4, 8 or 16.");
|
||||
return success();
|
||||
}
|
||||
|
||||
LogicalResult MmaOp::verify() {
|
||||
MLIRContext *context = getContext();
|
||||
auto f16Ty = Float16Type::get(context);
|
||||
auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
|
||||
auto f32Ty = Float32Type::get(context);
|
||||
|
@ -72,44 +78,55 @@ static LogicalResult verify(MmaOp op) {
|
|||
auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
|
||||
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
|
||||
|
||||
SmallVector<Type, 12> operandTypes(op.getOperandTypes().begin(),
|
||||
op.getOperandTypes().end());
|
||||
auto operandTypes = getOperandTypes();
|
||||
if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
|
||||
operandTypes != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty}) {
|
||||
return op.emitOpError(
|
||||
"expected operands to be 4 <halfx2>s followed by either "
|
||||
"4 <halfx2>s or 8 floats");
|
||||
operandTypes != ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty}) {
|
||||
return emitOpError("expected operands to be 4 <halfx2>s followed by either "
|
||||
"4 <halfx2>s or 8 floats");
|
||||
}
|
||||
if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) {
|
||||
return op.emitOpError("expected result type to be a struct of either 4 "
|
||||
"<halfx2>s or 8 floats");
|
||||
if (getType() != f32x8StructTy && getType() != f16x2x4StructTy) {
|
||||
return emitOpError("expected result type to be a struct of either 4 "
|
||||
"<halfx2>s or 8 floats");
|
||||
}
|
||||
|
||||
auto alayout = op->getAttrOfType<StringAttr>("alayout");
|
||||
auto blayout = op->getAttrOfType<StringAttr>("blayout");
|
||||
auto alayout = (*this)->getAttrOfType<StringAttr>("alayout");
|
||||
auto blayout = (*this)->getAttrOfType<StringAttr>("blayout");
|
||||
|
||||
if (!(alayout && blayout) ||
|
||||
!(alayout.getValue() == "row" || alayout.getValue() == "col") ||
|
||||
!(blayout.getValue() == "row" || blayout.getValue() == "col")) {
|
||||
return op.emitOpError(
|
||||
"alayout and blayout attributes must be set to either "
|
||||
"\"row\" or \"col\"");
|
||||
return emitOpError("alayout and blayout attributes must be set to either "
|
||||
"\"row\" or \"col\"");
|
||||
}
|
||||
|
||||
if (operandTypes == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty} &&
|
||||
op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
|
||||
if (operandTypes == ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
|
||||
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
|
||||
f32Ty} &&
|
||||
getType() == f32x8StructTy && alayout.getValue() == "row" &&
|
||||
blayout.getValue() == "col") {
|
||||
return success();
|
||||
}
|
||||
return op.emitOpError("unimplemented mma.sync variant");
|
||||
return emitOpError("unimplemented mma.sync variant");
|
||||
}
|
||||
|
||||
std::pair<mlir::Type, unsigned>
|
||||
inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) {
|
||||
LogicalResult ShflOp::verify() {
|
||||
if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
|
||||
return success();
|
||||
auto type = getType().dyn_cast<LLVM::LLVMStructType>();
|
||||
auto elementType = (type && type.getBody().size() == 2)
|
||||
? type.getBody()[1].dyn_cast<IntegerType>()
|
||||
: nullptr;
|
||||
if (!elementType || elementType.getWidth() != 1)
|
||||
return emitError("expected return type to be a two-element struct with "
|
||||
"i1 as the second element");
|
||||
return success();
|
||||
}
|
||||
|
||||
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
|
||||
NVVM::MMAFrag frag,
|
||||
MLIRContext *context) {
|
||||
unsigned numberElements = 0;
|
||||
Type elementType;
|
||||
OpBuilder builder(context);
|
||||
|
@ -131,76 +148,72 @@ inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) {
|
|||
return std::make_pair(elementType, numberElements);
|
||||
}
|
||||
|
||||
static LogicalResult verify(NVVM::WMMALoadOp op) {
|
||||
LogicalResult NVVM::WMMALoadOp::verify() {
|
||||
unsigned addressSpace =
|
||||
op.ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
|
||||
ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
|
||||
if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
|
||||
return op.emitOpError("expected source pointer in memory "
|
||||
"space 0, 1, 3");
|
||||
return emitOpError("expected source pointer in memory "
|
||||
"space 0, 1, 3");
|
||||
|
||||
if (NVVM::WMMALoadOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(),
|
||||
op.eltype(), op.frag()) == 0)
|
||||
return op.emitOpError() << "invalid attribute combination";
|
||||
if (NVVM::WMMALoadOp::getIntrinsicID(m(), n(), k(), layout(), eltype(),
|
||||
frag()) == 0)
|
||||
return emitOpError() << "invalid attribute combination";
|
||||
std::pair<Type, unsigned> typeInfo =
|
||||
inferMMAType(op.eltype(), op.frag(), op.getContext());
|
||||
inferMMAType(eltype(), frag(), getContext());
|
||||
Type dstType = LLVM::LLVMStructType::getLiteral(
|
||||
op.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
|
||||
if (op.getType() != dstType)
|
||||
return op.emitOpError("expected destination type is a structure of ")
|
||||
getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
|
||||
if (getType() != dstType)
|
||||
return emitOpError("expected destination type is a structure of ")
|
||||
<< typeInfo.second << " elements of type " << typeInfo.first;
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(NVVM::WMMAStoreOp op) {
|
||||
LogicalResult NVVM::WMMAStoreOp::verify() {
|
||||
unsigned addressSpace =
|
||||
op.ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
|
||||
ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
|
||||
if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
|
||||
return op.emitOpError("expected operands to be a source pointer in memory "
|
||||
"space 0, 1, 3");
|
||||
return emitOpError("expected operands to be a source pointer in memory "
|
||||
"space 0, 1, 3");
|
||||
|
||||
if (NVVM::WMMAStoreOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(),
|
||||
op.eltype()) == 0)
|
||||
return op.emitOpError() << "invalid attribute combination";
|
||||
if (NVVM::WMMAStoreOp::getIntrinsicID(m(), n(), k(), layout(), eltype()) == 0)
|
||||
return emitOpError() << "invalid attribute combination";
|
||||
std::pair<Type, unsigned> typeInfo =
|
||||
inferMMAType(op.eltype(), NVVM::MMAFrag::c, op.getContext());
|
||||
if (op.args().size() != typeInfo.second)
|
||||
return op.emitOpError()
|
||||
<< "expected " << typeInfo.second << " data operands";
|
||||
if (llvm::any_of(op.args(), [&typeInfo](Value operands) {
|
||||
inferMMAType(eltype(), NVVM::MMAFrag::c, getContext());
|
||||
if (args().size() != typeInfo.second)
|
||||
return emitOpError() << "expected " << typeInfo.second << " data operands";
|
||||
if (llvm::any_of(args(), [&typeInfo](Value operands) {
|
||||
return operands.getType() != typeInfo.first;
|
||||
}))
|
||||
return op.emitOpError()
|
||||
<< "expected data operands of type " << typeInfo.first;
|
||||
return emitOpError() << "expected data operands of type " << typeInfo.first;
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(NVVM::WMMAMmaOp op) {
|
||||
if (NVVM::WMMAMmaOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layoutA(),
|
||||
op.layoutB(), op.eltypeA(),
|
||||
op.eltypeB()) == 0)
|
||||
return op.emitOpError() << "invalid attribute combination";
|
||||
LogicalResult NVVM::WMMAMmaOp::verify() {
|
||||
if (NVVM::WMMAMmaOp::getIntrinsicID(m(), n(), k(), layoutA(), layoutB(),
|
||||
eltypeA(), eltypeB()) == 0)
|
||||
return emitOpError() << "invalid attribute combination";
|
||||
std::pair<Type, unsigned> typeInfoA =
|
||||
inferMMAType(op.eltypeA(), NVVM::MMAFrag::a, op.getContext());
|
||||
inferMMAType(eltypeA(), NVVM::MMAFrag::a, getContext());
|
||||
std::pair<Type, unsigned> typeInfoB =
|
||||
inferMMAType(op.eltypeA(), NVVM::MMAFrag::b, op.getContext());
|
||||
inferMMAType(eltypeA(), NVVM::MMAFrag::b, getContext());
|
||||
std::pair<Type, unsigned> typeInfoC =
|
||||
inferMMAType(op.eltypeB(), NVVM::MMAFrag::c, op.getContext());
|
||||
inferMMAType(eltypeB(), NVVM::MMAFrag::c, getContext());
|
||||
SmallVector<Type, 32> arguments;
|
||||
arguments.append(typeInfoA.second, typeInfoA.first);
|
||||
arguments.append(typeInfoB.second, typeInfoB.first);
|
||||
arguments.append(typeInfoC.second, typeInfoC.first);
|
||||
unsigned numArgs = arguments.size();
|
||||
if (op.args().size() != numArgs)
|
||||
return op.emitOpError() << "expected " << numArgs << " arguments";
|
||||
if (args().size() != numArgs)
|
||||
return emitOpError() << "expected " << numArgs << " arguments";
|
||||
for (unsigned i = 0; i < numArgs; i++) {
|
||||
if (op.args()[i].getType() != arguments[i])
|
||||
return op.emitOpError()
|
||||
<< "expected argument " << i << " to be of type " << arguments[i];
|
||||
if (args()[i].getType() != arguments[i])
|
||||
return emitOpError() << "expected argument " << i << " to be of type "
|
||||
<< arguments[i];
|
||||
}
|
||||
Type dstType = LLVM::LLVMStructType::getLiteral(
|
||||
op.getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
|
||||
if (op.getType() != dstType)
|
||||
return op.emitOpError("expected destination type is a structure of ")
|
||||
getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
|
||||
if (getType() != dstType)
|
||||
return emitOpError("expected destination type is a structure of ")
|
||||
<< typeInfoC.second << " elements of type " << typeInfoC.first;
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -28,17 +28,15 @@ void x86vector::X86VectorDialect::initialize() {
|
|||
>();
|
||||
}
|
||||
|
||||
static LogicalResult verify(x86vector::MaskCompressOp op) {
|
||||
if (op.src() && op.constant_src())
|
||||
return emitError(op.getLoc(), "cannot use both src and constant_src");
|
||||
LogicalResult x86vector::MaskCompressOp::verify() {
|
||||
if (src() && constant_src())
|
||||
return emitError("cannot use both src and constant_src");
|
||||
|
||||
if (op.src() && (op.src().getType() != op.dst().getType()))
|
||||
return emitError(op.getLoc(),
|
||||
"failed to verify that src and dst have same type");
|
||||
if (src() && (src().getType() != dst().getType()))
|
||||
return emitError("failed to verify that src and dst have same type");
|
||||
|
||||
if (op.constant_src() && (op.constant_src()->getType() != op.dst().getType()))
|
||||
if (constant_src() && (constant_src()->getType() != dst().getType()))
|
||||
return emitError(
|
||||
op.getLoc(),
|
||||
"failed to verify that constant_src and dst have same type");
|
||||
|
||||
return success();
|
||||
|
|
Loading…
Reference in New Issue