[mlir][NFC] Update AMX/LLVM/NVVM/X86 vector operations to use `hasVerifier` instead of `verifier`

The verifier field is deprecated, and slated for removal.

Differential Revision: https://reviews.llvm.org/D118819
This commit is contained in:
River Riddle 2022-02-02 10:16:28 -08:00
parent f7d85f010f
commit 38abdddf6f
10 changed files with 404 additions and 416 deletions

View File

@ -91,7 +91,6 @@ def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> {
%0 = amx.tile_zero : vector<16x16xbf16>
```
}];
let verifier = [{ return ::verify(*this); }];
let results = (outs
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res);
let extraClassDeclaration = [{
@ -100,6 +99,7 @@ def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> {
}
}];
let assemblyFormat = "attr-dict `:` type($res)";
let hasVerifier = 1;
}
//
@ -120,7 +120,6 @@ def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> {
%0 = amx.tile_load %arg0[%c0, %c0] : memref<?x?xi8> into vector<16x64xi8>
```
}];
let verifier = [{ return ::verify(*this); }];
let arguments = (ins Arg<AnyMemRef, "load base", [MemRead]>:$base,
Variadic<Index>:$indices);
let results = (outs
@ -135,6 +134,7 @@ def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> {
}];
let assemblyFormat = "$base `[` $indices `]` attr-dict `:` "
"type($base) `into` type($res)";
let hasVerifier = 1;
}
def TileStoreOp : AMX_Op<"tile_store"> {
@ -151,7 +151,6 @@ def TileStoreOp : AMX_Op<"tile_store"> {
amx.tile_store %arg1[%c0, %c0], %0 : memref<?x?xi8>, vector<16x64xi8>
```
}];
let verifier = [{ return ::verify(*this); }];
let arguments = (ins Arg<AnyMemRef, "store base", [MemWrite]>:$base,
Variadic<Index>:$indices,
VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val);
@ -165,6 +164,7 @@ def TileStoreOp : AMX_Op<"tile_store"> {
}];
let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` "
"type($base) `,` type($val)";
let hasVerifier = 1;
}
//
@ -186,7 +186,6 @@ def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"]
: vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32>
```
}];
let verifier = [{ return ::verify(*this); }];
let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs,
VectorOfRankAndType<[2], [F32, BF16]>:$rhs,
VectorOfRankAndType<[2], [F32, BF16]>:$acc);
@ -204,6 +203,7 @@ def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"]
}];
let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` "
"type($lhs) `,` type($rhs) `,` type($acc) ";
let hasVerifier = 1;
}
def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> {
@ -224,7 +224,6 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
: vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32>
```
}];
let verifier = [{ return ::verify(*this); }];
let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs,
VectorOfRankAndType<[2], [I32, I8]>:$rhs,
VectorOfRankAndType<[2], [I32, I8]>:$acc,
@ -245,6 +244,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]
}];
let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` "
"type($lhs) `,` type($rhs) `,` type($acc) ";
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//

View File

@ -351,9 +351,7 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> {
constexpr static int kDynamicIndex = std::numeric_limits<int32_t>::min();
}];
let hasFolder = 1;
let verifier = [{
return ::verify(*this);
}];
let hasVerifier = 1;
}
def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
@ -386,7 +384,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes {
CArg<"bool", "false">:$isNonTemporal)>];
let parser = [{ return parseLoadOp(parser, result); }];
let printer = [{ printLoadOp(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let hasVerifier = 1;
}
def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
@ -410,7 +408,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes {
];
let parser = [{ return parseStoreOp(parser, result); }];
let printer = [{ printStoreOp(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let hasVerifier = 1;
}
// Casts.
@ -494,18 +492,18 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [
build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps,
unwindOps, normal, unwind);
}]>];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseInvokeOp(parser, result); }];
let printer = [{ printInvokeOp(p, *this); }];
let hasVerifier = 1;
}
def LLVM_LandingpadOp : LLVM_Op<"landingpad"> {
let arguments = (ins UnitAttr:$cleanup, Variadic<LLVM_Type>);
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseLandingpadOp(parser, result); }];
let printer = [{ printLandingpadOp(p, *this); }];
let hasVerifier = 1;
}
def LLVM_CallOp : LLVM_Op<"call",
@ -562,9 +560,9 @@ def LLVM_CallOp : LLVM_Op<"call",
build($_builder, $_state, results,
StringAttr::get($_builder.getContext(), callee), operands);
}]>];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseCallOp(parser, result); }];
let printer = [{ printCallOp(p, *this); }];
let hasVerifier = 1;
}
def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position);
@ -575,9 +573,9 @@ def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> {
let builders = [
OpBuilder<(ins "Value":$vector, "Value":$position,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseExtractElementOp(parser, result); }];
let printer = [{ printExtractElementOp(p, *this); }];
let hasVerifier = 1;
}
def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
let arguments = (ins LLVM_AnyAggregate:$container, ArrayAttr:$position);
@ -586,10 +584,10 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> {
$res = builder.CreateExtractValue($container, extractPosition($position));
}];
let builders = [LLVM_OneResultOpBuilder];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseExtractValueOp(parser, result); }];
let printer = [{ printExtractValueOp(p, *this); }];
let hasFolder = 1;
let hasVerifier = 1;
}
def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value,
@ -599,9 +597,9 @@ def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> {
$res = builder.CreateInsertElement($vector, $value, $position);
}];
let builders = [LLVM_OneResultOpBuilder];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseInsertElementOp(parser, result); }];
let printer = [{ printInsertElementOp(p, *this); }];
let hasVerifier = 1;
}
def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> {
let arguments = (ins LLVM_AnyAggregate:$container, LLVM_PrimitiveType:$value,
@ -616,9 +614,9 @@ def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> {
[{
build($_builder, $_state, container.getType(), container, value, position);
}]>];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return parseInsertValueOp(parser, result); }];
let printer = [{ printInsertValueOp(p, *this); }];
let hasVerifier = 1;
}
def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> {
let arguments = (ins LLVM_AnyVector:$v1, LLVM_AnyVector:$v2, ArrayAttr:$mask);
@ -631,16 +629,9 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> {
let builders = [
OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayAttr":$mask,
CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
let verifier = [{
auto type1 = getV1().getType();
auto type2 = getV2().getType();
if (::mlir::LLVM::getVectorElementType(type1) !=
::mlir::LLVM::getVectorElementType(type2))
return emitOpError("expected matching LLVM IR Dialect element types");
return success();
}];
let parser = [{ return parseShuffleVectorOp(parser, result); }];
let printer = [{ printShuffleVectorOp(p, *this); }];
let hasVerifier = 1;
}
// Misc operations.
@ -718,27 +709,15 @@ def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [NoSideEffect]> {
builder.CreateRetVoid();
}];
let verifier = [{
if (getNumOperands() > 1)
return emitOpError("expects at most 1 operand");
return success();
}];
let parser = [{ return parseReturnOp(parser, result); }];
let printer = [{ printReturnOp(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let hasVerifier = 1;
}
def LLVM_ResumeOp : LLVM_TerminatorOp<"resume", []> {
let arguments = (ins LLVM_Type:$value);
string llvmBuilder = [{ builder.CreateResume($value); }];
let verifier = [{
if (!isa_and_nonnull<LandingpadOp>(getValue().getDefiningOp()))
return emitOpError("expects landingpad value as operand");
// No check for personality of function - landingpad op verifies it.
return success();
}];
let assemblyFormat = "$value attr-dict `:` type($value)";
let hasVerifier = 1;
}
def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> {
string llvmBuilder = [{ builder.CreateUnreachable(); }];
@ -761,7 +740,6 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
VariadicSuccessor<AnySuccessor>:$caseDestinations
);
let verifier = [{ return ::verify(*this); }];
let assemblyFormat = [{
$value `:` type($value) `,`
$defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)?
@ -769,6 +747,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch",
$caseOperands, type($caseOperands)) `]`
attr-dict
}];
let hasVerifier = 1;
let builders = [
OpBuilder<(ins "Value":$value,
@ -924,7 +903,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof", [NoSideEffect]> {
}];
let assemblyFormat = "$global_name attr-dict `:` type($res)";
let verifier = "return ::verify(*this);";
let hasVerifier = 1;
}
def LLVM_MetadataOp : LLVM_Op<"metadata", [
@ -1175,7 +1154,7 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global",
let printer = "printGlobalOp(p, *this);";
let parser = "return parseGlobalOp(parser, result);";
let verifier = "return ::verify(*this);";
let hasVerifier = 1;
}
def LLVM_GlobalCtorsOp : LLVM_Op<"mlir.global_ctors", [
@ -1205,8 +1184,8 @@ def LLVM_GlobalCtorsOp : LLVM_Op<"mlir.global_ctors", [
```
}];
let verifier = [{ return ::verify(*this); }];
let assemblyFormat = "attr-dict";
let hasVerifier = 1;
}
def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [
@ -1234,8 +1213,8 @@ def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [
```
}];
let verifier = [{ return ::verify(*this); }];
let assemblyFormat = "attr-dict";
let hasVerifier = 1;
}
def LLVM_LLVMFuncOp : LLVM_Op<"func", [
@ -1310,9 +1289,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [
LogicalResult verifyType();
}];
let verifier = [{ return ::verify(*this); }];
let printer = [{ printLLVMFuncOp(p, *this); }];
let parser = [{ return parseLLVMFuncOp(parser, result); }];
let hasVerifier = 1;
}
def LLVM_NullOp
@ -1402,8 +1381,8 @@ def LLVM_ConstantOp
let results = (outs LLVM_Type:$res);
let builders = [LLVM_OneResultOpBuilder];
let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)";
let verifier = [{ return ::verify(*this); }];
let hasFolder = 1;
let hasVerifier = 1;
}
// Operations that correspond to LLVM intrinsics. With MLIR operation set being
@ -1848,7 +1827,7 @@ def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> {
}];
let parser = [{ return parseAtomicRMWOp(parser, result); }];
let printer = [{ printAtomicRMWOp(p, *this); }];
let verifier = "return ::verify(*this);";
let hasVerifier = 1;
}
def LLVM_AtomicCmpXchgType : AnyTypeOf<[AnyInteger, LLVM_AnyPointer]>;
@ -1878,7 +1857,7 @@ def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg"> {
}];
let parser = [{ return parseAtomicCmpXchgOp(parser, result); }];
let printer = [{ printAtomicCmpXchgOp(p, *this); }];
let verifier = "return ::verify(*this);";
let hasVerifier = 1;
}
def LLVM_AssumeOp : LLVM_Op<"intr.assume", []> {
@ -1901,7 +1880,7 @@ def LLVM_FenceOp : LLVM_Op<"fence"> {
}];
let parser = [{ return parseFenceOp(parser, result); }];
let printer = [{ printFenceOp(p, *this); }];
let verifier = "return ::verify(*this);";
let hasVerifier = 1;
}
def AsmATT : LLVM_EnumAttrCase<

View File

@ -22,12 +22,16 @@
#include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc"
namespace mlir {
namespace NVVM {
/// Return the element type and number of elements associated with a wmma matrix
/// of given chracteristics. This matches the logic in IntrinsicsNVVM.td
/// WMMA_REGS structure.
std::pair<mlir::Type, unsigned> inferMMAType(mlir::NVVM::MMATypes type,
mlir::NVVM::MMAFrag frag,
mlir::MLIRContext *context);
} // namespace NVVM
} // namespace mlir
///// Ops /////
#define GET_ATTRDEF_CLASSES

View File

@ -131,22 +131,11 @@ def NVVM_ShflOp :
$res = createIntrinsicCall(builder,
intId, {$dst, $val, $offset, $mask_and_clamp});
}];
let verifier = [{
if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
return success();
auto type = getType().dyn_cast<LLVM::LLVMStructType>();
auto elementType = (type && type.getBody().size() == 2)
? type.getBody()[1].dyn_cast<IntegerType>()
: nullptr;
if (!elementType || elementType.getWidth() != 1)
return emitError("expected return type to be a two-element struct with "
"i1 as the second element");
return success();
}];
let assemblyFormat = [{
$kind $dst `,` $val `,` $offset `,` $mask_and_clamp attr-dict
`:` type($val) `->` type($res)
}];
let hasVerifier = 1;
}
def NVVM_VoteBallotOp :
@ -183,12 +172,8 @@ def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">,
}
createIntrinsicCall(builder, id, {$dst, $src});
}];
let verifier = [{
if (size() != 4 && size() != 8 && size() != 16)
return emitError("expected byte size to be either 4, 8 or 16.");
return success();
}];
let assemblyFormat = "$dst `,` $src `,` $size attr-dict";
let hasVerifier = 1;
}
def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> {
@ -220,7 +205,7 @@ def NVVM_MmaOp :
builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args);
}];
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
let verifier = [{ return ::verify(*this); }];
let hasVerifier = 1;
}
/// Helpers to instantiate different version of wmma intrinsics.
@ -538,7 +523,7 @@ def NVVM_WMMALoadOp: NVVM_Op<"wmma.load">,
}];
let assemblyFormat = "$ptr `,` $stride attr-dict `:` functional-type($ptr, $res)";
let verifier = [{ return ::verify(*this); }];
let hasVerifier = 1;
}
def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">,
@ -593,7 +578,7 @@ def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">,
}];
let assemblyFormat = "$ptr `,` $stride `,` $args attr-dict `:` type($ptr) `,` type($args)";
let verifier = [{ return ::verify(*this); }];
let hasVerifier = 1;
}
// Base class for all the variants of WMMA mmaOps that may be defined.
@ -647,7 +632,7 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">,
}];
let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)";
let verifier = [{ return ::verify(*this); }];
let hasVerifier = 1;
}
#endif // NVVMIR_OPS

View File

@ -76,7 +76,6 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect,
with their respective bit set in writemask `k`) to `dst`, and pass through the
remaining elements from `src`.
}];
let verifier = [{ return ::verify(*this); }];
let arguments = (ins VectorOfLengthAndType<[16, 8],
[I1]>:$k,
VectorOfLengthAndType<[16, 8],
@ -88,6 +87,7 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect,
[F32, I32, F64, I64]>:$dst);
let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict"
" `:` type($dst) (`,` type($src)^)?";
let hasVerifier = 1;
}
def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [

View File

@ -358,22 +358,19 @@ struct WmmaElementwiseOpToNVVMLowering
} // namespace
namespace mlir {
/// Return the LLVMStructureType corresponding to the MMAMatrixType `type`.
LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) {
LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) {
NVVM::MMAFrag frag = convertOperand(type.getOperand());
NVVM::MMATypes eltType = getElementType(type);
std::pair<Type, unsigned> typeInfo =
inferMMAType(eltType, frag, type.getContext());
NVVM::inferMMAType(eltType, frag, type.getContext());
return LLVM::LLVMStructType::getLiteral(
type.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
}
void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter,
RewritePatternSet &patterns) {
void mlir::populateGpuWMMAToNVVMConversionPatterns(
LLVMTypeConverter &converter, RewritePatternSet &patterns) {
patterns.insert<WmmaLoadOpToNVVMLowering, WmmaMmaOpToNVVMLowering,
WmmaStoreOpToNVVMLowering, WmmaConstantOpToNVVMLowering,
WmmaElementwiseOpToNVVMLowering>(converter);
}
} // namespace mlir

View File

@ -52,53 +52,55 @@ static LogicalResult verifyMultShape(Operation *op, VectorType atp,
return success();
}
static LogicalResult verify(amx::TileZeroOp op) {
return verifyTileSize(op, op.getVectorType());
LogicalResult amx::TileZeroOp::verify() {
return verifyTileSize(*this, getVectorType());
}
static LogicalResult verify(amx::TileLoadOp op) {
unsigned rank = op.getMemRefType().getRank();
if (llvm::size(op.indices()) != rank)
return op.emitOpError("requires ") << rank << " indices";
return verifyTileSize(op, op.getVectorType());
LogicalResult amx::TileLoadOp::verify() {
unsigned rank = getMemRefType().getRank();
if (indices().size() != rank)
return emitOpError("requires ") << rank << " indices";
return verifyTileSize(*this, getVectorType());
}
static LogicalResult verify(amx::TileStoreOp op) {
unsigned rank = op.getMemRefType().getRank();
if (llvm::size(op.indices()) != rank)
return op.emitOpError("requires ") << rank << " indices";
return verifyTileSize(op, op.getVectorType());
LogicalResult amx::TileStoreOp::verify() {
unsigned rank = getMemRefType().getRank();
if (indices().size() != rank)
return emitOpError("requires ") << rank << " indices";
return verifyTileSize(*this, getVectorType());
}
static LogicalResult verify(amx::TileMulFOp op) {
VectorType aType = op.getLhsVectorType();
VectorType bType = op.getRhsVectorType();
VectorType cType = op.getVectorType();
if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
failed(verifyTileSize(op, cType)) ||
failed(verifyMultShape(op, aType, bType, cType, 1)))
LogicalResult amx::TileMulFOp::verify() {
VectorType aType = getLhsVectorType();
VectorType bType = getRhsVectorType();
VectorType cType = getVectorType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
failed(verifyMultShape(*this, aType, bType, cType, 1)))
return failure();
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
if (!ta.isBF16() || !tb.isBF16() || !tc.isF32())
return op.emitOpError("unsupported type combination");
return emitOpError("unsupported type combination");
return success();
}
static LogicalResult verify(amx::TileMulIOp op) {
VectorType aType = op.getLhsVectorType();
VectorType bType = op.getRhsVectorType();
VectorType cType = op.getVectorType();
if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) ||
failed(verifyTileSize(op, cType)) ||
failed(verifyMultShape(op, aType, bType, cType, 2)))
LogicalResult amx::TileMulIOp::verify() {
VectorType aType = getLhsVectorType();
VectorType bType = getRhsVectorType();
VectorType cType = getVectorType();
if (failed(verifyTileSize(*this, aType)) ||
failed(verifyTileSize(*this, bType)) ||
failed(verifyTileSize(*this, cType)) ||
failed(verifyMultShape(*this, aType, bType, cType, 2)))
return failure();
Type ta = aType.getElementType();
Type tb = bType.getElementType();
Type tc = cType.getElementType();
if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32))
return op.emitOpError("unsupported type combination");
return emitOpError("unsupported type combination");
return success();
}

View File

@ -334,18 +334,17 @@ static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType,
p.printNewline();
}
static LogicalResult verify(SwitchOp op) {
if ((!op.getCaseValues() && !op.getCaseDestinations().empty()) ||
(op.getCaseValues() &&
op.getCaseValues()->size() !=
static_cast<int64_t>(op.getCaseDestinations().size())))
return op.emitOpError("expects number of case values to match number of "
"case destinations");
if (op.getBranchWeights() &&
op.getBranchWeights()->size() != op.getNumSuccessors())
return op.emitError("expects number of branch weights to match number of "
"successors: ")
<< op.getBranchWeights()->size() << " vs " << op.getNumSuccessors();
LogicalResult SwitchOp::verify() {
if ((!getCaseValues() && !getCaseDestinations().empty()) ||
(getCaseValues() &&
getCaseValues()->size() !=
static_cast<int64_t>(getCaseDestinations().size())))
return emitOpError("expects number of case values to match number of "
"case destinations");
if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors())
return emitError("expects number of branch weights to match number of "
"successors: ")
<< getBranchWeights()->size() << " vs " << getNumSuccessors();
return success();
}
@ -518,11 +517,11 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp,
});
}
LogicalResult verify(LLVM::GEPOp gepOp) {
LogicalResult LLVM::GEPOp::verify() {
SmallVector<unsigned> indices;
SmallVector<unsigned> structSizes;
findKnownStructIndices(gepOp.getBase().getType(), indices, &structSizes);
DenseIntElementsAttr structIndices = gepOp.getStructIndices();
findKnownStructIndices(getBase().getType(), indices, &structSizes);
DenseIntElementsAttr structIndices = getStructIndices();
for (unsigned i : llvm::seq<unsigned>(0, indices.size())) {
unsigned index = indices[i];
// GEP may not be indexing as deep as some structs nested in the type.
@ -531,11 +530,11 @@ LogicalResult verify(LLVM::GEPOp gepOp) {
int32_t staticIndex = structIndices.getValues<int32_t>()[index];
if (staticIndex == LLVM::GEPOp::kDynamicIndex)
return gepOp.emitOpError() << "expected index " << index
<< " indexing a struct to be constant";
return emitOpError() << "expected index " << index
<< " indexing a struct to be constant";
if (staticIndex < 0 || static_cast<unsigned>(staticIndex) >= structSizes[i])
return gepOp.emitOpError()
<< "index " << index << " indexing a struct is out of bounds";
return emitOpError() << "index " << index
<< " indexing a struct is out of bounds";
}
return success();
}
@ -613,9 +612,7 @@ static LogicalResult verifyMemoryOpMetadata(Operation *op) {
return success();
}
static LogicalResult verify(LoadOp op) {
return verifyMemoryOpMetadata(op.getOperation());
}
LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); }
void LoadOp::build(OpBuilder &builder, OperationState &result, Type t,
Value addr, unsigned alignment, bool isVolatile,
@ -675,9 +672,7 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
// Builder, printer and parser for LLVM::StoreOp.
//===----------------------------------------------------------------------===//
static LogicalResult verify(StoreOp op) {
return verifyMemoryOpMetadata(op.getOperation());
}
LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); }
void StoreOp::build(OpBuilder &builder, OperationState &result, Value value,
Value addr, unsigned alignment, bool isVolatile,
@ -739,19 +734,18 @@ InvokeOp::getMutableSuccessorOperands(unsigned index) {
: getUnwindDestOperandsMutable();
}
static LogicalResult verify(InvokeOp op) {
if (op.getNumResults() > 1)
return op.emitOpError("must have 0 or 1 result");
LogicalResult InvokeOp::verify() {
if (getNumResults() > 1)
return emitOpError("must have 0 or 1 result");
Block *unwindDest = op.getUnwindDest();
Block *unwindDest = getUnwindDest();
if (unwindDest->empty())
return op.emitError(
"must have at least one operation in unwind destination");
return emitError("must have at least one operation in unwind destination");
// In unwind destination, first operation must be LandingpadOp
if (!isa<LandingpadOp>(unwindDest->front()))
return op.emitError("first operation in unwind destination should be a "
"llvm.landingpad operation");
return emitError("first operation in unwind destination should be a "
"llvm.landingpad operation");
return success();
}
@ -880,20 +874,20 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) {
/// Verifying/Printing/Parsing for LLVM::LandingpadOp.
///===----------------------------------------------------------------------===//
static LogicalResult verify(LandingpadOp op) {
LogicalResult LandingpadOp::verify() {
Value value;
if (LLVMFuncOp func = op->getParentOfType<LLVMFuncOp>()) {
if (LLVMFuncOp func = (*this)->getParentOfType<LLVMFuncOp>()) {
if (!func.getPersonality().hasValue())
return op.emitError(
return emitError(
"llvm.landingpad needs to be in a function with a personality");
}
if (!op.getCleanup() && op.getOperands().empty())
return op.emitError("landingpad instruction expects at least one clause or "
"cleanup attribute");
if (!getCleanup() && getOperands().empty())
return emitError("landingpad instruction expects at least one clause or "
"cleanup attribute");
for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) {
value = op.getOperand(idx);
for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) {
value = getOperand(idx);
bool isFilter = value.getType().isa<LLVMArrayType>();
if (isFilter) {
// FIXME: Verify filter clauses when arrays are appropriately handled
@ -903,8 +897,7 @@ static LogicalResult verify(LandingpadOp op) {
if (auto bcOp = value.getDefiningOp<BitcastOp>()) {
if (auto addrOp = bcOp.getArg().getDefiningOp<AddressOfOp>())
continue;
return op.emitError("constant clauses expected")
.attachNote(bcOp.getLoc())
return emitError("constant clauses expected").attachNote(bcOp.getLoc())
<< "global addresses expected as operand to "
"bitcast used in clauses for landingpad";
}
@ -913,7 +906,7 @@ static LogicalResult verify(LandingpadOp op) {
continue;
if (value.getDefiningOp<AddressOfOp>())
continue;
return op.emitError("clause #")
return emitError("clause #")
<< idx << " is not a known constant - null, addressof, bitcast";
}
}
@ -970,9 +963,9 @@ static ParseResult parseLandingpadOp(OpAsmParser &parser,
// Verifying/Printing/parsing for LLVM::CallOp.
//===----------------------------------------------------------------------===//
static LogicalResult verify(CallOp &op) {
if (op.getNumResults() > 1)
return op.emitOpError("must have 0 or 1 result");
LogicalResult CallOp::verify() {
if (getNumResults() > 1)
return emitOpError("must have 0 or 1 result");
// Type for the callee, we'll get it differently depending if it is a direct
// or indirect call.
@ -981,75 +974,73 @@ static LogicalResult verify(CallOp &op) {
bool isIndirect = false;
// If this is an indirect call, the callee attribute is missing.
FlatSymbolRefAttr calleeName = op.getCalleeAttr();
FlatSymbolRefAttr calleeName = getCalleeAttr();
if (!calleeName) {
isIndirect = true;
if (!op.getNumOperands())
return op.emitOpError(
if (!getNumOperands())
return emitOpError(
"must have either a `callee` attribute or at least an operand");
auto ptrType = op.getOperand(0).getType().dyn_cast<LLVMPointerType>();
auto ptrType = getOperand(0).getType().dyn_cast<LLVMPointerType>();
if (!ptrType)
return op.emitOpError("indirect call expects a pointer as callee: ")
return emitOpError("indirect call expects a pointer as callee: ")
<< ptrType;
fnType = ptrType.getElementType();
} else {
Operation *callee =
SymbolTable::lookupNearestSymbolFrom(op, calleeName.getAttr());
SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr());
if (!callee)
return op.emitOpError()
return emitOpError()
<< "'" << calleeName.getValue()
<< "' does not reference a symbol in the current scope";
auto fn = dyn_cast<LLVMFuncOp>(callee);
if (!fn)
return op.emitOpError() << "'" << calleeName.getValue()
<< "' does not reference a valid LLVM function";
return emitOpError() << "'" << calleeName.getValue()
<< "' does not reference a valid LLVM function";
fnType = fn.getType();
}
LLVMFunctionType funcType = fnType.dyn_cast<LLVMFunctionType>();
if (!funcType)
return op.emitOpError("callee does not have a functional type: ") << fnType;
return emitOpError("callee does not have a functional type: ") << fnType;
// Verify that the operand and result types match the callee.
if (!funcType.isVarArg() &&
funcType.getNumParams() != (op.getNumOperands() - isIndirect))
return op.emitOpError()
<< "incorrect number of operands ("
<< (op.getNumOperands() - isIndirect)
<< ") for callee (expecting: " << funcType.getNumParams() << ")";
funcType.getNumParams() != (getNumOperands() - isIndirect))
return emitOpError() << "incorrect number of operands ("
<< (getNumOperands() - isIndirect)
<< ") for callee (expecting: "
<< funcType.getNumParams() << ")";
if (funcType.getNumParams() > (op.getNumOperands() - isIndirect))
return op.emitOpError() << "incorrect number of operands ("
<< (op.getNumOperands() - isIndirect)
<< ") for varargs callee (expecting at least: "
<< funcType.getNumParams() << ")";
if (funcType.getNumParams() > (getNumOperands() - isIndirect))
return emitOpError() << "incorrect number of operands ("
<< (getNumOperands() - isIndirect)
<< ") for varargs callee (expecting at least: "
<< funcType.getNumParams() << ")";
for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i)
if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i))
return op.emitOpError() << "operand type mismatch for operand " << i
<< ": " << op.getOperand(i + isIndirect).getType()
<< " != " << funcType.getParamType(i);
if (getOperand(i + isIndirect).getType() != funcType.getParamType(i))
return emitOpError() << "operand type mismatch for operand " << i << ": "
<< getOperand(i + isIndirect).getType()
<< " != " << funcType.getParamType(i);
if (op.getNumResults() == 0 &&
if (getNumResults() == 0 &&
!funcType.getReturnType().isa<LLVM::LLVMVoidType>())
return op.emitOpError() << "expected function call to produce a value";
return emitOpError() << "expected function call to produce a value";
if (op.getNumResults() != 0 &&
if (getNumResults() != 0 &&
funcType.getReturnType().isa<LLVM::LLVMVoidType>())
return op.emitOpError()
return emitOpError()
<< "calling function with void result must not produce values";
if (op.getNumResults() > 1)
return op.emitOpError()
if (getNumResults() > 1)
return emitOpError()
<< "expected LLVM function call to produce 0 or 1 result";
if (op.getNumResults() &&
op.getResult(0).getType() != funcType.getReturnType())
return op.emitOpError()
<< "result type mismatch: " << op.getResult(0).getType()
<< " != " << funcType.getReturnType();
if (getNumResults() && getResult(0).getType() != funcType.getReturnType())
return emitOpError() << "result type mismatch: " << getResult(0).getType()
<< " != " << funcType.getReturnType();
return success();
}
@ -1200,17 +1191,17 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser,
return success();
}
static LogicalResult verify(ExtractElementOp op) {
Type vectorType = op.getVector().getType();
LogicalResult ExtractElementOp::verify() {
Type vectorType = getVector().getType();
if (!LLVM::isCompatibleVectorType(vectorType))
return op->emitOpError("expected LLVM dialect-compatible vector type for "
"operand #1, got")
return emitOpError("expected LLVM dialect-compatible vector type for "
"operand #1, got")
<< vectorType;
Type valueType = LLVM::getVectorElementType(vectorType);
if (valueType != op.getRes().getType())
return op.emitOpError() << "Type mismatch: extracting from " << vectorType
<< " should produce " << valueType
<< " but this op returns " << op.getRes().getType();
if (valueType != getRes().getType())
return emitOpError() << "Type mismatch: extracting from " << vectorType
<< " should produce " << valueType
<< " but this op returns " << getRes().getType();
return success();
}
@ -1367,17 +1358,17 @@ OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef<Attribute> operands) {
return {};
}
static LogicalResult verify(ExtractValueOp op) {
Type valueType = getInsertExtractValueElementType(op.getContainer().getType(),
op.getPositionAttr(), op);
LogicalResult ExtractValueOp::verify() {
Type valueType = getInsertExtractValueElementType(getContainer().getType(),
getPositionAttr(), *this);
if (!valueType)
return failure();
if (op.getRes().getType() != valueType)
return op.emitOpError()
<< "Type mismatch: extracting from " << op.getContainer().getType()
<< " should produce " << valueType << " but this op returns "
<< op.getRes().getType();
if (getRes().getType() != valueType)
return emitOpError() << "Type mismatch: extracting from "
<< getContainer().getType() << " should produce "
<< valueType << " but this op returns "
<< getRes().getType();
return success();
}
@ -1423,14 +1414,15 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser,
return success();
}
static LogicalResult verify(InsertElementOp op) {
Type valueType = LLVM::getVectorElementType(op.getVector().getType());
if (valueType != op.getValue().getType())
return op.emitOpError()
<< "Type mismatch: cannot insert " << op.getValue().getType()
<< " into " << op.getVector().getType();
LogicalResult InsertElementOp::verify() {
Type valueType = LLVM::getVectorElementType(getVector().getType());
if (valueType != getValue().getType())
return emitOpError() << "Type mismatch: cannot insert "
<< getValue().getType() << " into "
<< getVector().getType();
return success();
}
//===----------------------------------------------------------------------===//
// Printing/parsing for LLVM::InsertValueOp.
//===----------------------------------------------------------------------===//
@ -1473,16 +1465,16 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser,
return success();
}
static LogicalResult verify(InsertValueOp op) {
Type valueType = getInsertExtractValueElementType(op.getContainer().getType(),
op.getPositionAttr(), op);
LogicalResult InsertValueOp::verify() {
Type valueType = getInsertExtractValueElementType(getContainer().getType(),
getPositionAttr(), *this);
if (!valueType)
return failure();
if (op.getValue().getType() != valueType)
return op.emitOpError()
<< "Type mismatch: cannot insert " << op.getValue().getType()
<< " into " << op.getContainer().getType();
if (getValue().getType() != valueType)
return emitOpError() << "Type mismatch: cannot insert "
<< getValue().getType() << " into "
<< getContainer().getType();
return success();
}
@ -1519,28 +1511,28 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) {
return success();
}
static LogicalResult verify(ReturnOp op) {
if (op->getNumOperands() > 1)
return op->emitOpError("expected at most 1 operand");
LogicalResult ReturnOp::verify() {
if (getNumOperands() > 1)
return emitOpError("expected at most 1 operand");
if (auto parent = op->getParentOfType<LLVMFuncOp>()) {
if (auto parent = (*this)->getParentOfType<LLVMFuncOp>()) {
Type expectedType = parent.getType().getReturnType();
if (expectedType.isa<LLVMVoidType>()) {
if (op->getNumOperands() == 0)
if (getNumOperands() == 0)
return success();
InFlightDiagnostic diag = op->emitOpError("expected no operands");
InFlightDiagnostic diag = emitOpError("expected no operands");
diag.attachNote(parent->getLoc()) << "when returning from function";
return diag;
}
if (op->getNumOperands() == 0) {
if (getNumOperands() == 0) {
if (expectedType.isa<LLVMVoidType>())
return success();
InFlightDiagnostic diag = op->emitOpError("expected 1 operand");
InFlightDiagnostic diag = emitOpError("expected 1 operand");
diag.attachNote(parent->getLoc()) << "when returning from function";
return diag;
}
if (expectedType != op->getOperand(0).getType()) {
InFlightDiagnostic diag = op->emitOpError("mismatching result types");
if (expectedType != getOperand(0).getType()) {
InFlightDiagnostic diag = emitOpError("mismatching result types");
diag.attachNote(parent->getLoc()) << "when returning from function";
return diag;
}
@ -1548,6 +1540,17 @@ static LogicalResult verify(ReturnOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// ResumeOp
//===----------------------------------------------------------------------===//
LogicalResult ResumeOp::verify() {
if (!getValue().getDefiningOp<LandingpadOp>())
return emitOpError("expects landingpad value as operand");
// No check for personality of function - landingpad op verifies it.
return success();
}
//===----------------------------------------------------------------------===//
// Verifier for LLVM::AddressOfOp.
//===----------------------------------------------------------------------===//
@ -1572,22 +1575,22 @@ LLVMFuncOp AddressOfOp::getFunction() {
getGlobalName());
}
static LogicalResult verify(AddressOfOp op) {
auto global = op.getGlobal();
auto function = op.getFunction();
LogicalResult AddressOfOp::verify() {
auto global = getGlobal();
auto function = getFunction();
if (!global && !function)
return op.emitOpError(
return emitOpError(
"must reference a global defined by 'llvm.mlir.global' or 'llvm.func'");
if (global &&
LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()) !=
op.getResult().getType())
return op.emitOpError(
getResult().getType())
return emitOpError(
"the type must be a pointer to the type of the referenced global");
if (function && LLVM::LLVMPointerType::get(function.getType()) !=
op.getResult().getType())
return op.emitOpError(
if (function &&
LLVM::LLVMPointerType::get(function.getType()) != getResult().getType())
return emitOpError(
"the type must be a pointer to the type of the referenced function");
return success();
@ -1791,60 +1794,60 @@ static bool isZeroAttribute(Attribute value) {
return false;
}
static LogicalResult verify(GlobalOp op) {
if (!LLVMPointerType::isValidElementType(op.getType()))
return op.emitOpError(
LogicalResult GlobalOp::verify() {
if (!LLVMPointerType::isValidElementType(getType()))
return emitOpError(
"expects type to be a valid element type for an LLVM pointer");
if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp()))
return op.emitOpError("must appear at the module level");
if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp()))
return emitOpError("must appear at the module level");
if (auto strAttr = op.getValueOrNull().dyn_cast_or_null<StringAttr>()) {
auto type = op.getType().dyn_cast<LLVMArrayType>();
if (auto strAttr = getValueOrNull().dyn_cast_or_null<StringAttr>()) {
auto type = getType().dyn_cast<LLVMArrayType>();
IntegerType elementType =
type ? type.getElementType().dyn_cast<IntegerType>() : nullptr;
if (!elementType || elementType.getWidth() != 8 ||
type.getNumElements() != strAttr.getValue().size())
return op.emitOpError(
return emitOpError(
"requires an i8 array type of the length equal to that of the string "
"attribute");
}
if (Block *b = op.getInitializerBlock()) {
if (Block *b = getInitializerBlock()) {
ReturnOp ret = cast<ReturnOp>(b->getTerminator());
if (ret.operand_type_begin() == ret.operand_type_end())
return op.emitOpError("initializer region cannot return void");
if (*ret.operand_type_begin() != op.getType())
return op.emitOpError("initializer region type ")
return emitOpError("initializer region cannot return void");
if (*ret.operand_type_begin() != getType())
return emitOpError("initializer region type ")
<< *ret.operand_type_begin() << " does not match global type "
<< op.getType();
<< getType();
if (op.getValueOrNull())
return op.emitOpError("cannot have both initializer value and region");
if (getValueOrNull())
return emitOpError("cannot have both initializer value and region");
}
if (op.getLinkage() == Linkage::Common) {
if (Attribute value = op.getValueOrNull()) {
if (getLinkage() == Linkage::Common) {
if (Attribute value = getValueOrNull()) {
if (!isZeroAttribute(value)) {
return op.emitOpError()
return emitOpError()
<< "expected zero value for '"
<< stringifyLinkage(Linkage::Common) << "' linkage";
}
}
}
if (op.getLinkage() == Linkage::Appending) {
if (!op.getType().isa<LLVMArrayType>()) {
return op.emitOpError()
<< "expected array type for '"
<< stringifyLinkage(Linkage::Appending) << "' linkage";
if (getLinkage() == Linkage::Appending) {
if (!getType().isa<LLVMArrayType>()) {
return emitOpError() << "expected array type for '"
<< stringifyLinkage(Linkage::Appending)
<< "' linkage";
}
}
Optional<uint64_t> alignAttr = op.getAlignment();
Optional<uint64_t> alignAttr = getAlignment();
if (alignAttr.hasValue()) {
uint64_t value = alignAttr.getValue();
if (!llvm::isPowerOf2_64(value))
return op->emitError() << "alignment attribute is not a power of 2";
return emitError() << "alignment attribute is not a power of 2";
}
return success();
@ -1864,9 +1867,9 @@ GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
static LogicalResult verify(GlobalCtorsOp op) {
if (op.getCtors().size() != op.getPriorities().size())
return op.emitError(
LogicalResult GlobalCtorsOp::verify() {
if (getCtors().size() != getPriorities().size())
return emitError(
"mismatch between the number of ctors and the number of priorities");
return success();
}
@ -1885,9 +1888,9 @@ GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
static LogicalResult verify(GlobalDtorsOp op) {
if (op.getDtors().size() != op.getPriorities().size())
return op.emitError(
LogicalResult GlobalDtorsOp::verify() {
if (getDtors().size() != getPriorities().size())
return emitError(
"mismatch between the number of dtors and the number of priorities");
return success();
}
@ -1940,6 +1943,14 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser,
return success();
}
LogicalResult ShuffleVectorOp::verify() {
Type type1 = getV1().getType();
Type type2 = getV2().getType();
if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2))
return emitOpError("expected matching LLVM IR Dialect element types");
return success();
}
//===----------------------------------------------------------------------===//
// Implementations for LLVM::LLVMFuncOp.
//===----------------------------------------------------------------------===//
@ -2117,42 +2128,43 @@ LogicalResult LLVMFuncOp::verifyType() {
// - external functions have 'external' or 'extern_weak' linkage;
// - vararg is (currently) only supported for external functions;
// - entry block arguments are of LLVM types and match the function signature.
static LogicalResult verify(LLVMFuncOp op) {
if (op.getLinkage() == LLVM::Linkage::Common)
return op.emitOpError()
<< "functions cannot have '"
<< stringifyLinkage(LLVM::Linkage::Common) << "' linkage";
LogicalResult LLVMFuncOp::verify() {
if (getLinkage() == LLVM::Linkage::Common)
return emitOpError() << "functions cannot have '"
<< stringifyLinkage(LLVM::Linkage::Common)
<< "' linkage";
// Check to see if this function has a void return with a result attribute to
// it. It isn't clear what semantics we would assign to that.
if (op.getType().getReturnType().isa<LLVMVoidType>() &&
!op.getResultAttrs(0).empty()) {
return op.emitOpError()
if (getType().getReturnType().isa<LLVMVoidType>() &&
!getResultAttrs(0).empty()) {
return emitOpError()
<< "cannot attach result attributes to functions with a void return";
}
if (op.isExternal()) {
if (op.getLinkage() != LLVM::Linkage::External &&
op.getLinkage() != LLVM::Linkage::ExternWeak)
return op.emitOpError()
<< "external functions must have '"
<< stringifyLinkage(LLVM::Linkage::External) << "' or '"
<< stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage";
if (isExternal()) {
if (getLinkage() != LLVM::Linkage::External &&
getLinkage() != LLVM::Linkage::ExternWeak)
return emitOpError() << "external functions must have '"
<< stringifyLinkage(LLVM::Linkage::External)
<< "' or '"
<< stringifyLinkage(LLVM::Linkage::ExternWeak)
<< "' linkage";
return success();
}
if (op.isVarArg())
return op.emitOpError("only external functions can be variadic");
if (isVarArg())
return emitOpError("only external functions can be variadic");
unsigned numArguments = op.getType().getNumParams();
Block &entryBlock = op.front();
unsigned numArguments = getType().getNumParams();
Block &entryBlock = front();
for (unsigned i = 0; i < numArguments; ++i) {
Type argType = entryBlock.getArgument(i).getType();
if (!isCompatibleType(argType))
return op.emitOpError("entry block argument #")
return emitOpError("entry block argument #")
<< i << " is not of LLVM type";
if (op.getType().getParamType(i) != argType)
return op.emitOpError("the type of entry block argument #")
if (getType().getParamType(i) != argType)
return emitOpError("the type of entry block argument #")
<< i << " does not match the function signature";
}
@ -2163,42 +2175,42 @@ static LogicalResult verify(LLVMFuncOp op) {
// Verification for LLVM::ConstantOp.
//===----------------------------------------------------------------------===//
static LogicalResult verify(LLVM::ConstantOp op) {
if (StringAttr sAttr = op.getValue().dyn_cast<StringAttr>()) {
auto arrayType = op.getType().dyn_cast<LLVMArrayType>();
LogicalResult LLVM::ConstantOp::verify() {
if (StringAttr sAttr = getValue().dyn_cast<StringAttr>()) {
auto arrayType = getType().dyn_cast<LLVMArrayType>();
if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() ||
!arrayType.getElementType().isInteger(8)) {
return op->emitOpError()
<< "expected array type of " << sAttr.getValue().size()
<< " i8 elements for the string constant";
return emitOpError() << "expected array type of "
<< sAttr.getValue().size()
<< " i8 elements for the string constant";
}
return success();
}
if (auto structType = op.getType().dyn_cast<LLVMStructType>()) {
if (auto structType = getType().dyn_cast<LLVMStructType>()) {
if (structType.getBody().size() != 2 ||
structType.getBody()[0] != structType.getBody()[1]) {
return op.emitError() << "expected struct type with two elements of the "
"same type, the type of a complex constant";
return emitError() << "expected struct type with two elements of the "
"same type, the type of a complex constant";
}
auto arrayAttr = op.getValue().dyn_cast<ArrayAttr>();
auto arrayAttr = getValue().dyn_cast<ArrayAttr>();
if (!arrayAttr || arrayAttr.size() != 2 ||
arrayAttr[0].getType() != arrayAttr[1].getType()) {
return op.emitOpError() << "expected array attribute with two elements, "
"representing a complex constant";
return emitOpError() << "expected array attribute with two elements, "
"representing a complex constant";
}
Type elementType = structType.getBody()[0];
if (!elementType
.isa<IntegerType, Float16Type, Float32Type, Float64Type>()) {
return op.emitError()
return emitError()
<< "expected struct element types to be floating point type or "
"integer type";
}
return success();
}
if (!op.getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
return op.emitOpError()
if (!getValue().isa<IntegerAttr, ArrayAttr, FloatAttr, ElementsAttr>())
return emitOpError()
<< "only supports integer, float, string or elements attributes";
return success();
}
@ -2294,42 +2306,40 @@ static ParseResult parseAtomicRMWOp(OpAsmParser &parser,
return success();
}
static LogicalResult verify(AtomicRMWOp op) {
auto ptrType = op.getPtr().getType().cast<LLVM::LLVMPointerType>();
auto valType = op.getVal().getType();
LogicalResult AtomicRMWOp::verify() {
auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
auto valType = getVal().getType();
if (valType != ptrType.getElementType())
return op.emitOpError("expected LLVM IR element type for operand #0 to "
"match type for operand #1");
auto resType = op.getRes().getType();
return emitOpError("expected LLVM IR element type for operand #0 to "
"match type for operand #1");
auto resType = getRes().getType();
if (resType != valType)
return op.emitOpError(
return emitOpError(
"expected LLVM IR result type to match type for operand #1");
if (op.getBinOp() == AtomicBinOp::fadd ||
op.getBinOp() == AtomicBinOp::fsub) {
if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) {
if (!mlir::LLVM::isCompatibleFloatingPointType(valType))
return op.emitOpError("expected LLVM IR floating point type");
} else if (op.getBinOp() == AtomicBinOp::xchg) {
return emitOpError("expected LLVM IR floating point type");
} else if (getBinOp() == AtomicBinOp::xchg) {
auto intType = valType.dyn_cast<IntegerType>();
unsigned intBitWidth = intType ? intType.getWidth() : 0;
if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
intBitWidth != 64 && !valType.isa<BFloat16Type>() &&
!valType.isa<Float16Type>() && !valType.isa<Float32Type>() &&
!valType.isa<Float64Type>())
return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
return emitOpError("unexpected LLVM IR type for 'xchg' bin_op");
} else {
auto intType = valType.dyn_cast<IntegerType>();
unsigned intBitWidth = intType ? intType.getWidth() : 0;
if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 &&
intBitWidth != 64)
return op.emitOpError("expected LLVM IR integer type");
return emitOpError("expected LLVM IR integer type");
}
if (static_cast<unsigned>(op.getOrdering()) <
if (static_cast<unsigned>(getOrdering()) <
static_cast<unsigned>(AtomicOrdering::monotonic))
return op.emitOpError()
<< "expected at least '"
<< stringifyAtomicOrdering(AtomicOrdering::monotonic)
<< "' ordering";
return emitOpError() << "expected at least '"
<< stringifyAtomicOrdering(AtomicOrdering::monotonic)
<< "' ordering";
return success();
}
@ -2375,28 +2385,28 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser,
return success();
}
static LogicalResult verify(AtomicCmpXchgOp op) {
auto ptrType = op.getPtr().getType().cast<LLVM::LLVMPointerType>();
LogicalResult AtomicCmpXchgOp::verify() {
auto ptrType = getPtr().getType().cast<LLVM::LLVMPointerType>();
if (!ptrType)
return op.emitOpError("expected LLVM IR pointer type for operand #0");
auto cmpType = op.getCmp().getType();
auto valType = op.getVal().getType();
return emitOpError("expected LLVM IR pointer type for operand #0");
auto cmpType = getCmp().getType();
auto valType = getVal().getType();
if (cmpType != ptrType.getElementType() || cmpType != valType)
return op.emitOpError("expected LLVM IR element type for operand #0 to "
"match type for all other operands");
return emitOpError("expected LLVM IR element type for operand #0 to "
"match type for all other operands");
auto intType = valType.dyn_cast<IntegerType>();
unsigned intBitWidth = intType ? intType.getWidth() : 0;
if (!valType.isa<LLVMPointerType>() && intBitWidth != 8 &&
intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 &&
!valType.isa<BFloat16Type>() && !valType.isa<Float16Type>() &&
!valType.isa<Float32Type>() && !valType.isa<Float64Type>())
return op.emitOpError("unexpected LLVM IR type");
if (op.getSuccessOrdering() < AtomicOrdering::monotonic ||
op.getFailureOrdering() < AtomicOrdering::monotonic)
return op.emitOpError("ordering must be at least 'monotonic'");
if (op.getFailureOrdering() == AtomicOrdering::release ||
op.getFailureOrdering() == AtomicOrdering::acq_rel)
return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
return emitOpError("unexpected LLVM IR type");
if (getSuccessOrdering() < AtomicOrdering::monotonic ||
getFailureOrdering() < AtomicOrdering::monotonic)
return emitOpError("ordering must be at least 'monotonic'");
if (getFailureOrdering() == AtomicOrdering::release ||
getFailureOrdering() == AtomicOrdering::acq_rel)
return emitOpError("failure ordering cannot be 'release' or 'acq_rel'");
return success();
}
@ -2432,12 +2442,12 @@ static void printFenceOp(OpAsmPrinter &p, FenceOp &op) {
p << stringifyAtomicOrdering(op.getOrdering());
}
static LogicalResult verify(FenceOp &op) {
if (op.getOrdering() == AtomicOrdering::not_atomic ||
op.getOrdering() == AtomicOrdering::unordered ||
op.getOrdering() == AtomicOrdering::monotonic)
return op.emitOpError("can be given only acquire, release, acq_rel, "
"and seq_cst orderings");
LogicalResult FenceOp::verify() {
if (getOrdering() == AtomicOrdering::not_atomic ||
getOrdering() == AtomicOrdering::unordered ||
getOrdering() == AtomicOrdering::monotonic)
return emitOpError("can be given only acquire, release, acq_rel, "
"and seq_cst orderings");
return success();
}

View File

@ -62,8 +62,14 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser,
parser.getNameLoc(), result.operands));
}
static LogicalResult verify(MmaOp op) {
MLIRContext *context = op.getContext();
LogicalResult CpAsyncOp::verify() {
if (size() != 4 && size() != 8 && size() != 16)
return emitError("expected byte size to be either 4, 8 or 16.");
return success();
}
LogicalResult MmaOp::verify() {
MLIRContext *context = getContext();
auto f16Ty = Float16Type::get(context);
auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2);
auto f32Ty = Float32Type::get(context);
@ -72,44 +78,55 @@ static LogicalResult verify(MmaOp op) {
auto f32x8StructTy = LLVM::LLVMStructType::getLiteral(
context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty});
SmallVector<Type, 12> operandTypes(op.getOperandTypes().begin(),
op.getOperandTypes().end());
auto operandTypes = getOperandTypes();
if (operandTypes != SmallVector<Type, 8>(8, f16x2Ty) &&
operandTypes != SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty, f32Ty, f32Ty}) {
return op.emitOpError(
"expected operands to be 4 <halfx2>s followed by either "
"4 <halfx2>s or 8 floats");
operandTypes != ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty}) {
return emitOpError("expected operands to be 4 <halfx2>s followed by either "
"4 <halfx2>s or 8 floats");
}
if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) {
return op.emitOpError("expected result type to be a struct of either 4 "
"<halfx2>s or 8 floats");
if (getType() != f32x8StructTy && getType() != f16x2x4StructTy) {
return emitOpError("expected result type to be a struct of either 4 "
"<halfx2>s or 8 floats");
}
auto alayout = op->getAttrOfType<StringAttr>("alayout");
auto blayout = op->getAttrOfType<StringAttr>("blayout");
auto alayout = (*this)->getAttrOfType<StringAttr>("alayout");
auto blayout = (*this)->getAttrOfType<StringAttr>("blayout");
if (!(alayout && blayout) ||
!(alayout.getValue() == "row" || alayout.getValue() == "col") ||
!(blayout.getValue() == "row" || blayout.getValue() == "col")) {
return op.emitOpError(
"alayout and blayout attributes must be set to either "
"\"row\" or \"col\"");
return emitOpError("alayout and blayout attributes must be set to either "
"\"row\" or \"col\"");
}
if (operandTypes == SmallVector<Type, 12>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty,
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty, f32Ty, f32Ty} &&
op.getType() == f32x8StructTy && alayout.getValue() == "row" &&
if (operandTypes == ArrayRef<Type>{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty,
f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty,
f32Ty} &&
getType() == f32x8StructTy && alayout.getValue() == "row" &&
blayout.getValue() == "col") {
return success();
}
return op.emitOpError("unimplemented mma.sync variant");
return emitOpError("unimplemented mma.sync variant");
}
std::pair<mlir::Type, unsigned>
inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) {
LogicalResult ShflOp::verify() {
if (!(*this)->getAttrOfType<UnitAttr>("return_value_and_is_valid"))
return success();
auto type = getType().dyn_cast<LLVM::LLVMStructType>();
auto elementType = (type && type.getBody().size() == 2)
? type.getBody()[1].dyn_cast<IntegerType>()
: nullptr;
if (!elementType || elementType.getWidth() != 1)
return emitError("expected return type to be a two-element struct with "
"i1 as the second element");
return success();
}
std::pair<mlir::Type, unsigned> NVVM::inferMMAType(NVVM::MMATypes type,
NVVM::MMAFrag frag,
MLIRContext *context) {
unsigned numberElements = 0;
Type elementType;
OpBuilder builder(context);
@ -131,76 +148,72 @@ inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) {
return std::make_pair(elementType, numberElements);
}
static LogicalResult verify(NVVM::WMMALoadOp op) {
LogicalResult NVVM::WMMALoadOp::verify() {
unsigned addressSpace =
op.ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
return op.emitOpError("expected source pointer in memory "
"space 0, 1, 3");
return emitOpError("expected source pointer in memory "
"space 0, 1, 3");
if (NVVM::WMMALoadOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(),
op.eltype(), op.frag()) == 0)
return op.emitOpError() << "invalid attribute combination";
if (NVVM::WMMALoadOp::getIntrinsicID(m(), n(), k(), layout(), eltype(),
frag()) == 0)
return emitOpError() << "invalid attribute combination";
std::pair<Type, unsigned> typeInfo =
inferMMAType(op.eltype(), op.frag(), op.getContext());
inferMMAType(eltype(), frag(), getContext());
Type dstType = LLVM::LLVMStructType::getLiteral(
op.getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
if (op.getType() != dstType)
return op.emitOpError("expected destination type is a structure of ")
getContext(), SmallVector<Type, 8>(typeInfo.second, typeInfo.first));
if (getType() != dstType)
return emitOpError("expected destination type is a structure of ")
<< typeInfo.second << " elements of type " << typeInfo.first;
return success();
}
static LogicalResult verify(NVVM::WMMAStoreOp op) {
LogicalResult NVVM::WMMAStoreOp::verify() {
unsigned addressSpace =
op.ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
ptr().getType().cast<LLVM::LLVMPointerType>().getAddressSpace();
if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3)
return op.emitOpError("expected operands to be a source pointer in memory "
"space 0, 1, 3");
return emitOpError("expected operands to be a source pointer in memory "
"space 0, 1, 3");
if (NVVM::WMMAStoreOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(),
op.eltype()) == 0)
return op.emitOpError() << "invalid attribute combination";
if (NVVM::WMMAStoreOp::getIntrinsicID(m(), n(), k(), layout(), eltype()) == 0)
return emitOpError() << "invalid attribute combination";
std::pair<Type, unsigned> typeInfo =
inferMMAType(op.eltype(), NVVM::MMAFrag::c, op.getContext());
if (op.args().size() != typeInfo.second)
return op.emitOpError()
<< "expected " << typeInfo.second << " data operands";
if (llvm::any_of(op.args(), [&typeInfo](Value operands) {
inferMMAType(eltype(), NVVM::MMAFrag::c, getContext());
if (args().size() != typeInfo.second)
return emitOpError() << "expected " << typeInfo.second << " data operands";
if (llvm::any_of(args(), [&typeInfo](Value operands) {
return operands.getType() != typeInfo.first;
}))
return op.emitOpError()
<< "expected data operands of type " << typeInfo.first;
return emitOpError() << "expected data operands of type " << typeInfo.first;
return success();
}
static LogicalResult verify(NVVM::WMMAMmaOp op) {
if (NVVM::WMMAMmaOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layoutA(),
op.layoutB(), op.eltypeA(),
op.eltypeB()) == 0)
return op.emitOpError() << "invalid attribute combination";
LogicalResult NVVM::WMMAMmaOp::verify() {
if (NVVM::WMMAMmaOp::getIntrinsicID(m(), n(), k(), layoutA(), layoutB(),
eltypeA(), eltypeB()) == 0)
return emitOpError() << "invalid attribute combination";
std::pair<Type, unsigned> typeInfoA =
inferMMAType(op.eltypeA(), NVVM::MMAFrag::a, op.getContext());
inferMMAType(eltypeA(), NVVM::MMAFrag::a, getContext());
std::pair<Type, unsigned> typeInfoB =
inferMMAType(op.eltypeA(), NVVM::MMAFrag::b, op.getContext());
inferMMAType(eltypeA(), NVVM::MMAFrag::b, getContext());
std::pair<Type, unsigned> typeInfoC =
inferMMAType(op.eltypeB(), NVVM::MMAFrag::c, op.getContext());
inferMMAType(eltypeB(), NVVM::MMAFrag::c, getContext());
SmallVector<Type, 32> arguments;
arguments.append(typeInfoA.second, typeInfoA.first);
arguments.append(typeInfoB.second, typeInfoB.first);
arguments.append(typeInfoC.second, typeInfoC.first);
unsigned numArgs = arguments.size();
if (op.args().size() != numArgs)
return op.emitOpError() << "expected " << numArgs << " arguments";
if (args().size() != numArgs)
return emitOpError() << "expected " << numArgs << " arguments";
for (unsigned i = 0; i < numArgs; i++) {
if (op.args()[i].getType() != arguments[i])
return op.emitOpError()
<< "expected argument " << i << " to be of type " << arguments[i];
if (args()[i].getType() != arguments[i])
return emitOpError() << "expected argument " << i << " to be of type "
<< arguments[i];
}
Type dstType = LLVM::LLVMStructType::getLiteral(
op.getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
if (op.getType() != dstType)
return op.emitOpError("expected destination type is a structure of ")
getContext(), SmallVector<Type, 8>(typeInfoC.second, typeInfoC.first));
if (getType() != dstType)
return emitOpError("expected destination type is a structure of ")
<< typeInfoC.second << " elements of type " << typeInfoC.first;
return success();
}

View File

@ -28,17 +28,15 @@ void x86vector::X86VectorDialect::initialize() {
>();
}
static LogicalResult verify(x86vector::MaskCompressOp op) {
if (op.src() && op.constant_src())
return emitError(op.getLoc(), "cannot use both src and constant_src");
LogicalResult x86vector::MaskCompressOp::verify() {
if (src() && constant_src())
return emitError("cannot use both src and constant_src");
if (op.src() && (op.src().getType() != op.dst().getType()))
return emitError(op.getLoc(),
"failed to verify that src and dst have same type");
if (src() && (src().getType() != dst().getType()))
return emitError("failed to verify that src and dst have same type");
if (op.constant_src() && (op.constant_src()->getType() != op.dst().getType()))
if (constant_src() && (constant_src()->getType() != dst().getType()))
return emitError(
op.getLoc(),
"failed to verify that constant_src and dst have same type");
return success();