From 38abdddf6f660c6d71d1c018ee1f2a1b46808f68 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 2 Feb 2022 10:16:28 -0800 Subject: [PATCH] [mlir][NFC] Update AMX/LLVM/NVVM/X86 vector operations to use `hasVerifier` instead of `verifier` The verifier field is deprecated, and slated for removal. Differential Revision: https://reviews.llvm.org/D118819 --- mlir/include/mlir/Dialect/AMX/AMX.td | 10 +- mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td | 67 +-- .../include/mlir/Dialect/LLVMIR/NVVMDialect.h | 4 + mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td | 27 +- .../mlir/Dialect/X86Vector/X86Vector.td | 2 +- .../Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp | 11 +- mlir/lib/Dialect/AMX/IR/AMXDialect.cpp | 58 ++- mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp | 482 +++++++++--------- mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp | 145 +++--- .../Dialect/X86Vector/IR/X86VectorDialect.cpp | 14 +- 10 files changed, 404 insertions(+), 416 deletions(-) diff --git a/mlir/include/mlir/Dialect/AMX/AMX.td b/mlir/include/mlir/Dialect/AMX/AMX.td index 0d23cde90949..16e2e14504a8 100644 --- a/mlir/include/mlir/Dialect/AMX/AMX.td +++ b/mlir/include/mlir/Dialect/AMX/AMX.td @@ -91,7 +91,6 @@ def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> { %0 = amx.tile_zero : vector<16x16xbf16> ``` }]; - let verifier = [{ return ::verify(*this); }]; let results = (outs VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$res); let extraClassDeclaration = [{ @@ -100,6 +99,7 @@ def TileZeroOp : AMX_Op<"tile_zero", [NoSideEffect]> { } }]; let assemblyFormat = "attr-dict `:` type($res)"; + let hasVerifier = 1; } // @@ -120,7 +120,6 @@ def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> { %0 = amx.tile_load %arg0[%c0, %c0] : memref into vector<16x64xi8> ``` }]; - let verifier = [{ return ::verify(*this); }]; let arguments = (ins Arg:$base, Variadic:$indices); let results = (outs @@ -135,6 +134,7 @@ def TileLoadOp : AMX_Op<"tile_load", [NoSideEffect]> { }]; let assemblyFormat = "$base `[` $indices `]` attr-dict `:` " "type($base) `into` type($res)"; + let hasVerifier = 1; } def TileStoreOp : AMX_Op<"tile_store"> { @@ -151,7 +151,6 @@ def TileStoreOp : AMX_Op<"tile_store"> { amx.tile_store %arg1[%c0, %c0], %0 : memref, vector<16x64xi8> ``` }]; - let verifier = [{ return ::verify(*this); }]; let arguments = (ins Arg:$base, Variadic:$indices, VectorOfRankAndType<[2], [F32, BF16, I32, I8]>:$val); @@ -165,6 +164,7 @@ def TileStoreOp : AMX_Op<"tile_store"> { }]; let assemblyFormat = "$base `[` $indices `]` `,` $val attr-dict `:` " "type($base) `,` type($val)"; + let hasVerifier = 1; } // @@ -186,7 +186,6 @@ def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"] : vector<16x32xbf16>, vector<16x32xbf16>, vector<16x16xf32> ``` }]; - let verifier = [{ return ::verify(*this); }]; let arguments = (ins VectorOfRankAndType<[2], [F32, BF16]>:$lhs, VectorOfRankAndType<[2], [F32, BF16]>:$rhs, VectorOfRankAndType<[2], [F32, BF16]>:$acc); @@ -204,6 +203,7 @@ def TileMulFOp : AMX_Op<"tile_mulf", [NoSideEffect, AllTypesMatch<["acc", "res"] }]; let assemblyFormat = "$lhs `,` $rhs `,` $acc attr-dict `:` " "type($lhs) `,` type($rhs) `,` type($acc) "; + let hasVerifier = 1; } def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"]>]> { @@ -224,7 +224,6 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"] : vector<16x64xi8>, vector<16x64xi8>, vector<16x16xi32> ``` }]; - let verifier = [{ return ::verify(*this); }]; let arguments = (ins VectorOfRankAndType<[2], [I32, I8]>:$lhs, VectorOfRankAndType<[2], [I32, I8]>:$rhs, VectorOfRankAndType<[2], [I32, I8]>:$acc, @@ -245,6 +244,7 @@ def TileMulIOp : AMX_Op<"tile_muli", [NoSideEffect, AllTypesMatch<["acc", "res"] }]; let assemblyFormat = "$lhs (`zext` $isZextLhs^)? `,` $rhs (`zext` $isZextRhs^)? `,` $acc attr-dict `:` " "type($lhs) `,` type($rhs) `,` type($acc) "; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td index b4cf3adc7bb8..66d28c713bcf 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOps.td @@ -351,9 +351,7 @@ def LLVM_GEPOp : LLVM_Op<"getelementptr", [NoSideEffect]> { constexpr static int kDynamicIndex = std::numeric_limits::min(); }]; let hasFolder = 1; - let verifier = [{ - return ::verify(*this); - }]; + let hasVerifier = 1; } def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes { @@ -386,7 +384,7 @@ def LLVM_LoadOp : LLVM_Op<"load">, MemoryOpWithAlignmentAndAttributes { CArg<"bool", "false">:$isNonTemporal)>]; let parser = [{ return parseLoadOp(parser, result); }]; let printer = [{ printLoadOp(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes { @@ -410,7 +408,7 @@ def LLVM_StoreOp : LLVM_Op<"store">, MemoryOpWithAlignmentAndAttributes { ]; let parser = [{ return parseStoreOp(parser, result); }]; let printer = [{ printStoreOp(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } // Casts. @@ -494,18 +492,18 @@ def LLVM_InvokeOp : LLVM_Op<"invoke", [ build($_builder, $_state, tys, /*callee=*/FlatSymbolRefAttr(), ops, normalOps, unwindOps, normal, unwind); }]>]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseInvokeOp(parser, result); }]; let printer = [{ printInvokeOp(p, *this); }]; + let hasVerifier = 1; } def LLVM_LandingpadOp : LLVM_Op<"landingpad"> { let arguments = (ins UnitAttr:$cleanup, Variadic); let results = (outs LLVM_Type:$res); let builders = [LLVM_OneResultOpBuilder]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseLandingpadOp(parser, result); }]; let printer = [{ printLandingpadOp(p, *this); }]; + let hasVerifier = 1; } def LLVM_CallOp : LLVM_Op<"call", @@ -562,9 +560,9 @@ def LLVM_CallOp : LLVM_Op<"call", build($_builder, $_state, results, StringAttr::get($_builder.getContext(), callee), operands); }]>]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseCallOp(parser, result); }]; let printer = [{ printCallOp(p, *this); }]; + let hasVerifier = 1; } def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> { let arguments = (ins LLVM_AnyVector:$vector, AnyInteger:$position); @@ -575,9 +573,9 @@ def LLVM_ExtractElementOp : LLVM_Op<"extractelement", [NoSideEffect]> { let builders = [ OpBuilder<(ins "Value":$vector, "Value":$position, CArg<"ArrayRef", "{}">:$attrs)>]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseExtractElementOp(parser, result); }]; let printer = [{ printExtractElementOp(p, *this); }]; + let hasVerifier = 1; } def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> { let arguments = (ins LLVM_AnyAggregate:$container, ArrayAttr:$position); @@ -586,10 +584,10 @@ def LLVM_ExtractValueOp : LLVM_Op<"extractvalue", [NoSideEffect]> { $res = builder.CreateExtractValue($container, extractPosition($position)); }]; let builders = [LLVM_OneResultOpBuilder]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseExtractValueOp(parser, result); }]; let printer = [{ printExtractValueOp(p, *this); }]; let hasFolder = 1; + let hasVerifier = 1; } def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> { let arguments = (ins LLVM_AnyVector:$vector, LLVM_PrimitiveType:$value, @@ -599,9 +597,9 @@ def LLVM_InsertElementOp : LLVM_Op<"insertelement", [NoSideEffect]> { $res = builder.CreateInsertElement($vector, $value, $position); }]; let builders = [LLVM_OneResultOpBuilder]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseInsertElementOp(parser, result); }]; let printer = [{ printInsertElementOp(p, *this); }]; + let hasVerifier = 1; } def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> { let arguments = (ins LLVM_AnyAggregate:$container, LLVM_PrimitiveType:$value, @@ -616,9 +614,9 @@ def LLVM_InsertValueOp : LLVM_Op<"insertvalue", [NoSideEffect]> { [{ build($_builder, $_state, container.getType(), container, value, position); }]>]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return parseInsertValueOp(parser, result); }]; let printer = [{ printInsertValueOp(p, *this); }]; + let hasVerifier = 1; } def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> { let arguments = (ins LLVM_AnyVector:$v1, LLVM_AnyVector:$v2, ArrayAttr:$mask); @@ -631,16 +629,9 @@ def LLVM_ShuffleVectorOp : LLVM_Op<"shufflevector", [NoSideEffect]> { let builders = [ OpBuilder<(ins "Value":$v1, "Value":$v2, "ArrayAttr":$mask, CArg<"ArrayRef", "{}">:$attrs)>]; - let verifier = [{ - auto type1 = getV1().getType(); - auto type2 = getV2().getType(); - if (::mlir::LLVM::getVectorElementType(type1) != - ::mlir::LLVM::getVectorElementType(type2)) - return emitOpError("expected matching LLVM IR Dialect element types"); - return success(); - }]; let parser = [{ return parseShuffleVectorOp(parser, result); }]; let printer = [{ printShuffleVectorOp(p, *this); }]; + let hasVerifier = 1; } // Misc operations. @@ -718,27 +709,15 @@ def LLVM_ReturnOp : LLVM_TerminatorOp<"return", [NoSideEffect]> { builder.CreateRetVoid(); }]; - let verifier = [{ - if (getNumOperands() > 1) - return emitOpError("expects at most 1 operand"); - return success(); - }]; - let parser = [{ return parseReturnOp(parser, result); }]; let printer = [{ printReturnOp(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def LLVM_ResumeOp : LLVM_TerminatorOp<"resume", []> { let arguments = (ins LLVM_Type:$value); string llvmBuilder = [{ builder.CreateResume($value); }]; - let verifier = [{ - if (!isa_and_nonnull(getValue().getDefiningOp())) - return emitOpError("expects landingpad value as operand"); - // No check for personality of function - landingpad op verifies it. - return success(); - }]; - let assemblyFormat = "$value attr-dict `:` type($value)"; + let hasVerifier = 1; } def LLVM_UnreachableOp : LLVM_TerminatorOp<"unreachable", []> { string llvmBuilder = [{ builder.CreateUnreachable(); }]; @@ -761,7 +740,6 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", VariadicSuccessor:$caseDestinations ); - let verifier = [{ return ::verify(*this); }]; let assemblyFormat = [{ $value `:` type($value) `,` $defaultDestination (`(` $defaultOperands^ `:` type($defaultOperands) `)`)? @@ -769,6 +747,7 @@ def LLVM_SwitchOp : LLVM_TerminatorOp<"switch", $caseOperands, type($caseOperands)) `]` attr-dict }]; + let hasVerifier = 1; let builders = [ OpBuilder<(ins "Value":$value, @@ -924,7 +903,7 @@ def LLVM_AddressOfOp : LLVM_Op<"mlir.addressof", [NoSideEffect]> { }]; let assemblyFormat = "$global_name attr-dict `:` type($res)"; - let verifier = "return ::verify(*this);"; + let hasVerifier = 1; } def LLVM_MetadataOp : LLVM_Op<"metadata", [ @@ -1175,7 +1154,7 @@ def LLVM_GlobalOp : LLVM_Op<"mlir.global", let printer = "printGlobalOp(p, *this);"; let parser = "return parseGlobalOp(parser, result);"; - let verifier = "return ::verify(*this);"; + let hasVerifier = 1; } def LLVM_GlobalCtorsOp : LLVM_Op<"mlir.global_ctors", [ @@ -1205,8 +1184,8 @@ def LLVM_GlobalCtorsOp : LLVM_Op<"mlir.global_ctors", [ ``` }]; - let verifier = [{ return ::verify(*this); }]; let assemblyFormat = "attr-dict"; + let hasVerifier = 1; } def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [ @@ -1234,8 +1213,8 @@ def LLVM_GlobalDtorsOp : LLVM_Op<"mlir.global_dtors", [ ``` }]; - let verifier = [{ return ::verify(*this); }]; let assemblyFormat = "attr-dict"; + let hasVerifier = 1; } def LLVM_LLVMFuncOp : LLVM_Op<"func", [ @@ -1310,9 +1289,9 @@ def LLVM_LLVMFuncOp : LLVM_Op<"func", [ LogicalResult verifyType(); }]; - let verifier = [{ return ::verify(*this); }]; let printer = [{ printLLVMFuncOp(p, *this); }]; let parser = [{ return parseLLVMFuncOp(parser, result); }]; + let hasVerifier = 1; } def LLVM_NullOp @@ -1402,8 +1381,8 @@ def LLVM_ConstantOp let results = (outs LLVM_Type:$res); let builders = [LLVM_OneResultOpBuilder]; let assemblyFormat = "`(` $value `)` attr-dict `:` type($res)"; - let verifier = [{ return ::verify(*this); }]; let hasFolder = 1; + let hasVerifier = 1; } // Operations that correspond to LLVM intrinsics. With MLIR operation set being @@ -1848,7 +1827,7 @@ def LLVM_AtomicRMWOp : LLVM_Op<"atomicrmw"> { }]; let parser = [{ return parseAtomicRMWOp(parser, result); }]; let printer = [{ printAtomicRMWOp(p, *this); }]; - let verifier = "return ::verify(*this);"; + let hasVerifier = 1; } def LLVM_AtomicCmpXchgType : AnyTypeOf<[AnyInteger, LLVM_AnyPointer]>; @@ -1878,7 +1857,7 @@ def LLVM_AtomicCmpXchgOp : LLVM_Op<"cmpxchg"> { }]; let parser = [{ return parseAtomicCmpXchgOp(parser, result); }]; let printer = [{ printAtomicCmpXchgOp(p, *this); }]; - let verifier = "return ::verify(*this);"; + let hasVerifier = 1; } def LLVM_AssumeOp : LLVM_Op<"intr.assume", []> { @@ -1901,7 +1880,7 @@ def LLVM_FenceOp : LLVM_Op<"fence"> { }]; let parser = [{ return parseFenceOp(parser, result); }]; let printer = [{ printFenceOp(p, *this); }]; - let verifier = "return ::verify(*this);"; + let hasVerifier = 1; } def AsmATT : LLVM_EnumAttrCase< diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h index 3fd7c5bc0660..de942f6fb4d3 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMDialect.h @@ -22,12 +22,16 @@ #include "mlir/Dialect/LLVMIR/NVVMOpsEnums.h.inc" +namespace mlir { +namespace NVVM { /// Return the element type and number of elements associated with a wmma matrix /// of given chracteristics. This matches the logic in IntrinsicsNVVM.td /// WMMA_REGS structure. std::pair inferMMAType(mlir::NVVM::MMATypes type, mlir::NVVM::MMAFrag frag, mlir::MLIRContext *context); +} // namespace NVVM +} // namespace mlir ///// Ops ///// #define GET_ATTRDEF_CLASSES diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td index d26a0b2c6f30..4a55ddd96cb7 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td @@ -131,22 +131,11 @@ def NVVM_ShflOp : $res = createIntrinsicCall(builder, intId, {$dst, $val, $offset, $mask_and_clamp}); }]; - let verifier = [{ - if (!(*this)->getAttrOfType("return_value_and_is_valid")) - return success(); - auto type = getType().dyn_cast(); - auto elementType = (type && type.getBody().size() == 2) - ? type.getBody()[1].dyn_cast() - : nullptr; - if (!elementType || elementType.getWidth() != 1) - return emitError("expected return type to be a two-element struct with " - "i1 as the second element"); - return success(); - }]; let assemblyFormat = [{ $kind $dst `,` $val `,` $offset `,` $mask_and_clamp attr-dict `:` type($val) `->` type($res) }]; + let hasVerifier = 1; } def NVVM_VoteBallotOp : @@ -183,12 +172,8 @@ def NVVM_CpAsyncOp : NVVM_Op<"cp.async.shared.global">, } createIntrinsicCall(builder, id, {$dst, $src}); }]; - let verifier = [{ - if (size() != 4 && size() != 8 && size() != 16) - return emitError("expected byte size to be either 4, 8 or 16."); - return success(); - }]; let assemblyFormat = "$dst `,` $src `,` $size attr-dict"; + let hasVerifier = 1; } def NVVM_CpAsyncCommitGroupOp : NVVM_Op<"cp.async.commit.group"> { @@ -220,7 +205,7 @@ def NVVM_MmaOp : builder, llvm::Intrinsic::nvvm_mma_m8n8k4_row_col_f32_f32, $args); }]; let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)"; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } /// Helpers to instantiate different version of wmma intrinsics. @@ -538,7 +523,7 @@ def NVVM_WMMALoadOp: NVVM_Op<"wmma.load">, }]; let assemblyFormat = "$ptr `,` $stride attr-dict `:` functional-type($ptr, $res)"; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">, @@ -593,7 +578,7 @@ def NVVM_WMMAStoreOp : NVVM_Op<"wmma.store">, }]; let assemblyFormat = "$ptr `,` $stride `,` $args attr-dict `:` type($ptr) `,` type($args)"; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } // Base class for all the variants of WMMA mmaOps that may be defined. @@ -647,7 +632,7 @@ def NVVM_WMMAMmaOp : NVVM_Op<"wmma.mma">, }]; let assemblyFormat = "$args attr-dict `:` functional-type($args, $res)"; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; } #endif // NVVMIR_OPS diff --git a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td index bda3440aa974..03fa89ef899a 100644 --- a/mlir/include/mlir/Dialect/X86Vector/X86Vector.td +++ b/mlir/include/mlir/Dialect/X86Vector/X86Vector.td @@ -76,7 +76,6 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect, with their respective bit set in writemask `k`) to `dst`, and pass through the remaining elements from `src`. }]; - let verifier = [{ return ::verify(*this); }]; let arguments = (ins VectorOfLengthAndType<[16, 8], [I1]>:$k, VectorOfLengthAndType<[16, 8], @@ -88,6 +87,7 @@ def MaskCompressOp : AVX512_Op<"mask.compress", [NoSideEffect, [F32, I32, F64, I64]>:$dst); let assemblyFormat = "$k `,` $a (`,` $src^)? attr-dict" " `:` type($dst) (`,` type($src)^)?"; + let hasVerifier = 1; } def MaskCompressIntrOp : AVX512_IntrOverloadedOp<"mask.compress", [ diff --git a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp index f52e589ad9df..4a3ba46233ff 100644 --- a/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/WmmaOpsToNvvm.cpp @@ -358,22 +358,19 @@ struct WmmaElementwiseOpToNVVMLowering } // namespace -namespace mlir { - /// Return the LLVMStructureType corresponding to the MMAMatrixType `type`. -LLVM::LLVMStructType convertMMAToLLVMType(gpu::MMAMatrixType type) { +LLVM::LLVMStructType mlir::convertMMAToLLVMType(gpu::MMAMatrixType type) { NVVM::MMAFrag frag = convertOperand(type.getOperand()); NVVM::MMATypes eltType = getElementType(type); std::pair typeInfo = - inferMMAType(eltType, frag, type.getContext()); + NVVM::inferMMAType(eltType, frag, type.getContext()); return LLVM::LLVMStructType::getLiteral( type.getContext(), SmallVector(typeInfo.second, typeInfo.first)); } -void populateGpuWMMAToNVVMConversionPatterns(LLVMTypeConverter &converter, - RewritePatternSet &patterns) { +void mlir::populateGpuWMMAToNVVMConversionPatterns( + LLVMTypeConverter &converter, RewritePatternSet &patterns) { patterns.insert(converter); } -} // namespace mlir diff --git a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp index c5cf1f41d709..9ea96791cef4 100644 --- a/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp +++ b/mlir/lib/Dialect/AMX/IR/AMXDialect.cpp @@ -52,53 +52,55 @@ static LogicalResult verifyMultShape(Operation *op, VectorType atp, return success(); } -static LogicalResult verify(amx::TileZeroOp op) { - return verifyTileSize(op, op.getVectorType()); +LogicalResult amx::TileZeroOp::verify() { + return verifyTileSize(*this, getVectorType()); } -static LogicalResult verify(amx::TileLoadOp op) { - unsigned rank = op.getMemRefType().getRank(); - if (llvm::size(op.indices()) != rank) - return op.emitOpError("requires ") << rank << " indices"; - return verifyTileSize(op, op.getVectorType()); +LogicalResult amx::TileLoadOp::verify() { + unsigned rank = getMemRefType().getRank(); + if (indices().size() != rank) + return emitOpError("requires ") << rank << " indices"; + return verifyTileSize(*this, getVectorType()); } -static LogicalResult verify(amx::TileStoreOp op) { - unsigned rank = op.getMemRefType().getRank(); - if (llvm::size(op.indices()) != rank) - return op.emitOpError("requires ") << rank << " indices"; - return verifyTileSize(op, op.getVectorType()); +LogicalResult amx::TileStoreOp::verify() { + unsigned rank = getMemRefType().getRank(); + if (indices().size() != rank) + return emitOpError("requires ") << rank << " indices"; + return verifyTileSize(*this, getVectorType()); } -static LogicalResult verify(amx::TileMulFOp op) { - VectorType aType = op.getLhsVectorType(); - VectorType bType = op.getRhsVectorType(); - VectorType cType = op.getVectorType(); - if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) || - failed(verifyTileSize(op, cType)) || - failed(verifyMultShape(op, aType, bType, cType, 1))) +LogicalResult amx::TileMulFOp::verify() { + VectorType aType = getLhsVectorType(); + VectorType bType = getRhsVectorType(); + VectorType cType = getVectorType(); + if (failed(verifyTileSize(*this, aType)) || + failed(verifyTileSize(*this, bType)) || + failed(verifyTileSize(*this, cType)) || + failed(verifyMultShape(*this, aType, bType, cType, 1))) return failure(); Type ta = aType.getElementType(); Type tb = bType.getElementType(); Type tc = cType.getElementType(); if (!ta.isBF16() || !tb.isBF16() || !tc.isF32()) - return op.emitOpError("unsupported type combination"); + return emitOpError("unsupported type combination"); return success(); } -static LogicalResult verify(amx::TileMulIOp op) { - VectorType aType = op.getLhsVectorType(); - VectorType bType = op.getRhsVectorType(); - VectorType cType = op.getVectorType(); - if (failed(verifyTileSize(op, aType)) || failed(verifyTileSize(op, bType)) || - failed(verifyTileSize(op, cType)) || - failed(verifyMultShape(op, aType, bType, cType, 2))) +LogicalResult amx::TileMulIOp::verify() { + VectorType aType = getLhsVectorType(); + VectorType bType = getRhsVectorType(); + VectorType cType = getVectorType(); + if (failed(verifyTileSize(*this, aType)) || + failed(verifyTileSize(*this, bType)) || + failed(verifyTileSize(*this, cType)) || + failed(verifyMultShape(*this, aType, bType, cType, 2))) return failure(); Type ta = aType.getElementType(); Type tb = bType.getElementType(); Type tc = cType.getElementType(); if (!ta.isInteger(8) || !tb.isInteger(8) || !tc.isInteger(32)) - return op.emitOpError("unsupported type combination"); + return emitOpError("unsupported type combination"); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp index 30eec4369dd1..fb2cd2546d5f 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMDialect.cpp @@ -334,18 +334,17 @@ static void printSwitchOpCases(OpAsmPrinter &p, SwitchOp op, Type flagType, p.printNewline(); } -static LogicalResult verify(SwitchOp op) { - if ((!op.getCaseValues() && !op.getCaseDestinations().empty()) || - (op.getCaseValues() && - op.getCaseValues()->size() != - static_cast(op.getCaseDestinations().size()))) - return op.emitOpError("expects number of case values to match number of " - "case destinations"); - if (op.getBranchWeights() && - op.getBranchWeights()->size() != op.getNumSuccessors()) - return op.emitError("expects number of branch weights to match number of " - "successors: ") - << op.getBranchWeights()->size() << " vs " << op.getNumSuccessors(); +LogicalResult SwitchOp::verify() { + if ((!getCaseValues() && !getCaseDestinations().empty()) || + (getCaseValues() && + getCaseValues()->size() != + static_cast(getCaseDestinations().size()))) + return emitOpError("expects number of case values to match number of " + "case destinations"); + if (getBranchWeights() && getBranchWeights()->size() != getNumSuccessors()) + return emitError("expects number of branch weights to match number of " + "successors: ") + << getBranchWeights()->size() << " vs " << getNumSuccessors(); return success(); } @@ -518,11 +517,11 @@ static void printGEPIndices(OpAsmPrinter &printer, LLVM::GEPOp gepOp, }); } -LogicalResult verify(LLVM::GEPOp gepOp) { +LogicalResult LLVM::GEPOp::verify() { SmallVector indices; SmallVector structSizes; - findKnownStructIndices(gepOp.getBase().getType(), indices, &structSizes); - DenseIntElementsAttr structIndices = gepOp.getStructIndices(); + findKnownStructIndices(getBase().getType(), indices, &structSizes); + DenseIntElementsAttr structIndices = getStructIndices(); for (unsigned i : llvm::seq(0, indices.size())) { unsigned index = indices[i]; // GEP may not be indexing as deep as some structs nested in the type. @@ -531,11 +530,11 @@ LogicalResult verify(LLVM::GEPOp gepOp) { int32_t staticIndex = structIndices.getValues()[index]; if (staticIndex == LLVM::GEPOp::kDynamicIndex) - return gepOp.emitOpError() << "expected index " << index - << " indexing a struct to be constant"; + return emitOpError() << "expected index " << index + << " indexing a struct to be constant"; if (staticIndex < 0 || static_cast(staticIndex) >= structSizes[i]) - return gepOp.emitOpError() - << "index " << index << " indexing a struct is out of bounds"; + return emitOpError() << "index " << index + << " indexing a struct is out of bounds"; } return success(); } @@ -613,9 +612,7 @@ static LogicalResult verifyMemoryOpMetadata(Operation *op) { return success(); } -static LogicalResult verify(LoadOp op) { - return verifyMemoryOpMetadata(op.getOperation()); -} +LogicalResult LoadOp::verify() { return verifyMemoryOpMetadata(*this); } void LoadOp::build(OpBuilder &builder, OperationState &result, Type t, Value addr, unsigned alignment, bool isVolatile, @@ -675,9 +672,7 @@ static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) { // Builder, printer and parser for LLVM::StoreOp. //===----------------------------------------------------------------------===// -static LogicalResult verify(StoreOp op) { - return verifyMemoryOpMetadata(op.getOperation()); -} +LogicalResult StoreOp::verify() { return verifyMemoryOpMetadata(*this); } void StoreOp::build(OpBuilder &builder, OperationState &result, Value value, Value addr, unsigned alignment, bool isVolatile, @@ -739,19 +734,18 @@ InvokeOp::getMutableSuccessorOperands(unsigned index) { : getUnwindDestOperandsMutable(); } -static LogicalResult verify(InvokeOp op) { - if (op.getNumResults() > 1) - return op.emitOpError("must have 0 or 1 result"); +LogicalResult InvokeOp::verify() { + if (getNumResults() > 1) + return emitOpError("must have 0 or 1 result"); - Block *unwindDest = op.getUnwindDest(); + Block *unwindDest = getUnwindDest(); if (unwindDest->empty()) - return op.emitError( - "must have at least one operation in unwind destination"); + return emitError("must have at least one operation in unwind destination"); // In unwind destination, first operation must be LandingpadOp if (!isa(unwindDest->front())) - return op.emitError("first operation in unwind destination should be a " - "llvm.landingpad operation"); + return emitError("first operation in unwind destination should be a " + "llvm.landingpad operation"); return success(); } @@ -880,20 +874,20 @@ static ParseResult parseInvokeOp(OpAsmParser &parser, OperationState &result) { /// Verifying/Printing/Parsing for LLVM::LandingpadOp. ///===----------------------------------------------------------------------===// -static LogicalResult verify(LandingpadOp op) { +LogicalResult LandingpadOp::verify() { Value value; - if (LLVMFuncOp func = op->getParentOfType()) { + if (LLVMFuncOp func = (*this)->getParentOfType()) { if (!func.getPersonality().hasValue()) - return op.emitError( + return emitError( "llvm.landingpad needs to be in a function with a personality"); } - if (!op.getCleanup() && op.getOperands().empty()) - return op.emitError("landingpad instruction expects at least one clause or " - "cleanup attribute"); + if (!getCleanup() && getOperands().empty()) + return emitError("landingpad instruction expects at least one clause or " + "cleanup attribute"); - for (unsigned idx = 0, ie = op.getNumOperands(); idx < ie; idx++) { - value = op.getOperand(idx); + for (unsigned idx = 0, ie = getNumOperands(); idx < ie; idx++) { + value = getOperand(idx); bool isFilter = value.getType().isa(); if (isFilter) { // FIXME: Verify filter clauses when arrays are appropriately handled @@ -903,8 +897,7 @@ static LogicalResult verify(LandingpadOp op) { if (auto bcOp = value.getDefiningOp()) { if (auto addrOp = bcOp.getArg().getDefiningOp()) continue; - return op.emitError("constant clauses expected") - .attachNote(bcOp.getLoc()) + return emitError("constant clauses expected").attachNote(bcOp.getLoc()) << "global addresses expected as operand to " "bitcast used in clauses for landingpad"; } @@ -913,7 +906,7 @@ static LogicalResult verify(LandingpadOp op) { continue; if (value.getDefiningOp()) continue; - return op.emitError("clause #") + return emitError("clause #") << idx << " is not a known constant - null, addressof, bitcast"; } } @@ -970,9 +963,9 @@ static ParseResult parseLandingpadOp(OpAsmParser &parser, // Verifying/Printing/parsing for LLVM::CallOp. //===----------------------------------------------------------------------===// -static LogicalResult verify(CallOp &op) { - if (op.getNumResults() > 1) - return op.emitOpError("must have 0 or 1 result"); +LogicalResult CallOp::verify() { + if (getNumResults() > 1) + return emitOpError("must have 0 or 1 result"); // Type for the callee, we'll get it differently depending if it is a direct // or indirect call. @@ -981,75 +974,73 @@ static LogicalResult verify(CallOp &op) { bool isIndirect = false; // If this is an indirect call, the callee attribute is missing. - FlatSymbolRefAttr calleeName = op.getCalleeAttr(); + FlatSymbolRefAttr calleeName = getCalleeAttr(); if (!calleeName) { isIndirect = true; - if (!op.getNumOperands()) - return op.emitOpError( + if (!getNumOperands()) + return emitOpError( "must have either a `callee` attribute or at least an operand"); - auto ptrType = op.getOperand(0).getType().dyn_cast(); + auto ptrType = getOperand(0).getType().dyn_cast(); if (!ptrType) - return op.emitOpError("indirect call expects a pointer as callee: ") + return emitOpError("indirect call expects a pointer as callee: ") << ptrType; fnType = ptrType.getElementType(); } else { Operation *callee = - SymbolTable::lookupNearestSymbolFrom(op, calleeName.getAttr()); + SymbolTable::lookupNearestSymbolFrom(*this, calleeName.getAttr()); if (!callee) - return op.emitOpError() + return emitOpError() << "'" << calleeName.getValue() << "' does not reference a symbol in the current scope"; auto fn = dyn_cast(callee); if (!fn) - return op.emitOpError() << "'" << calleeName.getValue() - << "' does not reference a valid LLVM function"; + return emitOpError() << "'" << calleeName.getValue() + << "' does not reference a valid LLVM function"; fnType = fn.getType(); } LLVMFunctionType funcType = fnType.dyn_cast(); if (!funcType) - return op.emitOpError("callee does not have a functional type: ") << fnType; + return emitOpError("callee does not have a functional type: ") << fnType; // Verify that the operand and result types match the callee. if (!funcType.isVarArg() && - funcType.getNumParams() != (op.getNumOperands() - isIndirect)) - return op.emitOpError() - << "incorrect number of operands (" - << (op.getNumOperands() - isIndirect) - << ") for callee (expecting: " << funcType.getNumParams() << ")"; + funcType.getNumParams() != (getNumOperands() - isIndirect)) + return emitOpError() << "incorrect number of operands (" + << (getNumOperands() - isIndirect) + << ") for callee (expecting: " + << funcType.getNumParams() << ")"; - if (funcType.getNumParams() > (op.getNumOperands() - isIndirect)) - return op.emitOpError() << "incorrect number of operands (" - << (op.getNumOperands() - isIndirect) - << ") for varargs callee (expecting at least: " - << funcType.getNumParams() << ")"; + if (funcType.getNumParams() > (getNumOperands() - isIndirect)) + return emitOpError() << "incorrect number of operands (" + << (getNumOperands() - isIndirect) + << ") for varargs callee (expecting at least: " + << funcType.getNumParams() << ")"; for (unsigned i = 0, e = funcType.getNumParams(); i != e; ++i) - if (op.getOperand(i + isIndirect).getType() != funcType.getParamType(i)) - return op.emitOpError() << "operand type mismatch for operand " << i - << ": " << op.getOperand(i + isIndirect).getType() - << " != " << funcType.getParamType(i); + if (getOperand(i + isIndirect).getType() != funcType.getParamType(i)) + return emitOpError() << "operand type mismatch for operand " << i << ": " + << getOperand(i + isIndirect).getType() + << " != " << funcType.getParamType(i); - if (op.getNumResults() == 0 && + if (getNumResults() == 0 && !funcType.getReturnType().isa()) - return op.emitOpError() << "expected function call to produce a value"; + return emitOpError() << "expected function call to produce a value"; - if (op.getNumResults() != 0 && + if (getNumResults() != 0 && funcType.getReturnType().isa()) - return op.emitOpError() + return emitOpError() << "calling function with void result must not produce values"; - if (op.getNumResults() > 1) - return op.emitOpError() + if (getNumResults() > 1) + return emitOpError() << "expected LLVM function call to produce 0 or 1 result"; - if (op.getNumResults() && - op.getResult(0).getType() != funcType.getReturnType()) - return op.emitOpError() - << "result type mismatch: " << op.getResult(0).getType() - << " != " << funcType.getReturnType(); + if (getNumResults() && getResult(0).getType() != funcType.getReturnType()) + return emitOpError() << "result type mismatch: " << getResult(0).getType() + << " != " << funcType.getReturnType(); return success(); } @@ -1200,17 +1191,17 @@ static ParseResult parseExtractElementOp(OpAsmParser &parser, return success(); } -static LogicalResult verify(ExtractElementOp op) { - Type vectorType = op.getVector().getType(); +LogicalResult ExtractElementOp::verify() { + Type vectorType = getVector().getType(); if (!LLVM::isCompatibleVectorType(vectorType)) - return op->emitOpError("expected LLVM dialect-compatible vector type for " - "operand #1, got") + return emitOpError("expected LLVM dialect-compatible vector type for " + "operand #1, got") << vectorType; Type valueType = LLVM::getVectorElementType(vectorType); - if (valueType != op.getRes().getType()) - return op.emitOpError() << "Type mismatch: extracting from " << vectorType - << " should produce " << valueType - << " but this op returns " << op.getRes().getType(); + if (valueType != getRes().getType()) + return emitOpError() << "Type mismatch: extracting from " << vectorType + << " should produce " << valueType + << " but this op returns " << getRes().getType(); return success(); } @@ -1367,17 +1358,17 @@ OpFoldResult LLVM::ExtractValueOp::fold(ArrayRef operands) { return {}; } -static LogicalResult verify(ExtractValueOp op) { - Type valueType = getInsertExtractValueElementType(op.getContainer().getType(), - op.getPositionAttr(), op); +LogicalResult ExtractValueOp::verify() { + Type valueType = getInsertExtractValueElementType(getContainer().getType(), + getPositionAttr(), *this); if (!valueType) return failure(); - if (op.getRes().getType() != valueType) - return op.emitOpError() - << "Type mismatch: extracting from " << op.getContainer().getType() - << " should produce " << valueType << " but this op returns " - << op.getRes().getType(); + if (getRes().getType() != valueType) + return emitOpError() << "Type mismatch: extracting from " + << getContainer().getType() << " should produce " + << valueType << " but this op returns " + << getRes().getType(); return success(); } @@ -1423,14 +1414,15 @@ static ParseResult parseInsertElementOp(OpAsmParser &parser, return success(); } -static LogicalResult verify(InsertElementOp op) { - Type valueType = LLVM::getVectorElementType(op.getVector().getType()); - if (valueType != op.getValue().getType()) - return op.emitOpError() - << "Type mismatch: cannot insert " << op.getValue().getType() - << " into " << op.getVector().getType(); +LogicalResult InsertElementOp::verify() { + Type valueType = LLVM::getVectorElementType(getVector().getType()); + if (valueType != getValue().getType()) + return emitOpError() << "Type mismatch: cannot insert " + << getValue().getType() << " into " + << getVector().getType(); return success(); } + //===----------------------------------------------------------------------===// // Printing/parsing for LLVM::InsertValueOp. //===----------------------------------------------------------------------===// @@ -1473,16 +1465,16 @@ static ParseResult parseInsertValueOp(OpAsmParser &parser, return success(); } -static LogicalResult verify(InsertValueOp op) { - Type valueType = getInsertExtractValueElementType(op.getContainer().getType(), - op.getPositionAttr(), op); +LogicalResult InsertValueOp::verify() { + Type valueType = getInsertExtractValueElementType(getContainer().getType(), + getPositionAttr(), *this); if (!valueType) return failure(); - if (op.getValue().getType() != valueType) - return op.emitOpError() - << "Type mismatch: cannot insert " << op.getValue().getType() - << " into " << op.getContainer().getType(); + if (getValue().getType() != valueType) + return emitOpError() << "Type mismatch: cannot insert " + << getValue().getType() << " into " + << getContainer().getType(); return success(); } @@ -1519,28 +1511,28 @@ static ParseResult parseReturnOp(OpAsmParser &parser, OperationState &result) { return success(); } -static LogicalResult verify(ReturnOp op) { - if (op->getNumOperands() > 1) - return op->emitOpError("expected at most 1 operand"); +LogicalResult ReturnOp::verify() { + if (getNumOperands() > 1) + return emitOpError("expected at most 1 operand"); - if (auto parent = op->getParentOfType()) { + if (auto parent = (*this)->getParentOfType()) { Type expectedType = parent.getType().getReturnType(); if (expectedType.isa()) { - if (op->getNumOperands() == 0) + if (getNumOperands() == 0) return success(); - InFlightDiagnostic diag = op->emitOpError("expected no operands"); + InFlightDiagnostic diag = emitOpError("expected no operands"); diag.attachNote(parent->getLoc()) << "when returning from function"; return diag; } - if (op->getNumOperands() == 0) { + if (getNumOperands() == 0) { if (expectedType.isa()) return success(); - InFlightDiagnostic diag = op->emitOpError("expected 1 operand"); + InFlightDiagnostic diag = emitOpError("expected 1 operand"); diag.attachNote(parent->getLoc()) << "when returning from function"; return diag; } - if (expectedType != op->getOperand(0).getType()) { - InFlightDiagnostic diag = op->emitOpError("mismatching result types"); + if (expectedType != getOperand(0).getType()) { + InFlightDiagnostic diag = emitOpError("mismatching result types"); diag.attachNote(parent->getLoc()) << "when returning from function"; return diag; } @@ -1548,6 +1540,17 @@ static LogicalResult verify(ReturnOp op) { return success(); } +//===----------------------------------------------------------------------===// +// ResumeOp +//===----------------------------------------------------------------------===// + +LogicalResult ResumeOp::verify() { + if (!getValue().getDefiningOp()) + return emitOpError("expects landingpad value as operand"); + // No check for personality of function - landingpad op verifies it. + return success(); +} + //===----------------------------------------------------------------------===// // Verifier for LLVM::AddressOfOp. //===----------------------------------------------------------------------===// @@ -1572,22 +1575,22 @@ LLVMFuncOp AddressOfOp::getFunction() { getGlobalName()); } -static LogicalResult verify(AddressOfOp op) { - auto global = op.getGlobal(); - auto function = op.getFunction(); +LogicalResult AddressOfOp::verify() { + auto global = getGlobal(); + auto function = getFunction(); if (!global && !function) - return op.emitOpError( + return emitOpError( "must reference a global defined by 'llvm.mlir.global' or 'llvm.func'"); if (global && LLVM::LLVMPointerType::get(global.getType(), global.getAddrSpace()) != - op.getResult().getType()) - return op.emitOpError( + getResult().getType()) + return emitOpError( "the type must be a pointer to the type of the referenced global"); - if (function && LLVM::LLVMPointerType::get(function.getType()) != - op.getResult().getType()) - return op.emitOpError( + if (function && + LLVM::LLVMPointerType::get(function.getType()) != getResult().getType()) + return emitOpError( "the type must be a pointer to the type of the referenced function"); return success(); @@ -1791,60 +1794,60 @@ static bool isZeroAttribute(Attribute value) { return false; } -static LogicalResult verify(GlobalOp op) { - if (!LLVMPointerType::isValidElementType(op.getType())) - return op.emitOpError( +LogicalResult GlobalOp::verify() { + if (!LLVMPointerType::isValidElementType(getType())) + return emitOpError( "expects type to be a valid element type for an LLVM pointer"); - if (op->getParentOp() && !satisfiesLLVMModule(op->getParentOp())) - return op.emitOpError("must appear at the module level"); + if ((*this)->getParentOp() && !satisfiesLLVMModule((*this)->getParentOp())) + return emitOpError("must appear at the module level"); - if (auto strAttr = op.getValueOrNull().dyn_cast_or_null()) { - auto type = op.getType().dyn_cast(); + if (auto strAttr = getValueOrNull().dyn_cast_or_null()) { + auto type = getType().dyn_cast(); IntegerType elementType = type ? type.getElementType().dyn_cast() : nullptr; if (!elementType || elementType.getWidth() != 8 || type.getNumElements() != strAttr.getValue().size()) - return op.emitOpError( + return emitOpError( "requires an i8 array type of the length equal to that of the string " "attribute"); } - if (Block *b = op.getInitializerBlock()) { + if (Block *b = getInitializerBlock()) { ReturnOp ret = cast(b->getTerminator()); if (ret.operand_type_begin() == ret.operand_type_end()) - return op.emitOpError("initializer region cannot return void"); - if (*ret.operand_type_begin() != op.getType()) - return op.emitOpError("initializer region type ") + return emitOpError("initializer region cannot return void"); + if (*ret.operand_type_begin() != getType()) + return emitOpError("initializer region type ") << *ret.operand_type_begin() << " does not match global type " - << op.getType(); + << getType(); - if (op.getValueOrNull()) - return op.emitOpError("cannot have both initializer value and region"); + if (getValueOrNull()) + return emitOpError("cannot have both initializer value and region"); } - if (op.getLinkage() == Linkage::Common) { - if (Attribute value = op.getValueOrNull()) { + if (getLinkage() == Linkage::Common) { + if (Attribute value = getValueOrNull()) { if (!isZeroAttribute(value)) { - return op.emitOpError() + return emitOpError() << "expected zero value for '" << stringifyLinkage(Linkage::Common) << "' linkage"; } } } - if (op.getLinkage() == Linkage::Appending) { - if (!op.getType().isa()) { - return op.emitOpError() - << "expected array type for '" - << stringifyLinkage(Linkage::Appending) << "' linkage"; + if (getLinkage() == Linkage::Appending) { + if (!getType().isa()) { + return emitOpError() << "expected array type for '" + << stringifyLinkage(Linkage::Appending) + << "' linkage"; } } - Optional alignAttr = op.getAlignment(); + Optional alignAttr = getAlignment(); if (alignAttr.hasValue()) { uint64_t value = alignAttr.getValue(); if (!llvm::isPowerOf2_64(value)) - return op->emitError() << "alignment attribute is not a power of 2"; + return emitError() << "alignment attribute is not a power of 2"; } return success(); @@ -1864,9 +1867,9 @@ GlobalCtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } -static LogicalResult verify(GlobalCtorsOp op) { - if (op.getCtors().size() != op.getPriorities().size()) - return op.emitError( +LogicalResult GlobalCtorsOp::verify() { + if (getCtors().size() != getPriorities().size()) + return emitError( "mismatch between the number of ctors and the number of priorities"); return success(); } @@ -1885,9 +1888,9 @@ GlobalDtorsOp::verifySymbolUses(SymbolTableCollection &symbolTable) { return success(); } -static LogicalResult verify(GlobalDtorsOp op) { - if (op.getDtors().size() != op.getPriorities().size()) - return op.emitError( +LogicalResult GlobalDtorsOp::verify() { + if (getDtors().size() != getPriorities().size()) + return emitError( "mismatch between the number of dtors and the number of priorities"); return success(); } @@ -1940,6 +1943,14 @@ static ParseResult parseShuffleVectorOp(OpAsmParser &parser, return success(); } +LogicalResult ShuffleVectorOp::verify() { + Type type1 = getV1().getType(); + Type type2 = getV2().getType(); + if (LLVM::getVectorElementType(type1) != LLVM::getVectorElementType(type2)) + return emitOpError("expected matching LLVM IR Dialect element types"); + return success(); +} + //===----------------------------------------------------------------------===// // Implementations for LLVM::LLVMFuncOp. //===----------------------------------------------------------------------===// @@ -2117,42 +2128,43 @@ LogicalResult LLVMFuncOp::verifyType() { // - external functions have 'external' or 'extern_weak' linkage; // - vararg is (currently) only supported for external functions; // - entry block arguments are of LLVM types and match the function signature. -static LogicalResult verify(LLVMFuncOp op) { - if (op.getLinkage() == LLVM::Linkage::Common) - return op.emitOpError() - << "functions cannot have '" - << stringifyLinkage(LLVM::Linkage::Common) << "' linkage"; +LogicalResult LLVMFuncOp::verify() { + if (getLinkage() == LLVM::Linkage::Common) + return emitOpError() << "functions cannot have '" + << stringifyLinkage(LLVM::Linkage::Common) + << "' linkage"; // Check to see if this function has a void return with a result attribute to // it. It isn't clear what semantics we would assign to that. - if (op.getType().getReturnType().isa() && - !op.getResultAttrs(0).empty()) { - return op.emitOpError() + if (getType().getReturnType().isa() && + !getResultAttrs(0).empty()) { + return emitOpError() << "cannot attach result attributes to functions with a void return"; } - if (op.isExternal()) { - if (op.getLinkage() != LLVM::Linkage::External && - op.getLinkage() != LLVM::Linkage::ExternWeak) - return op.emitOpError() - << "external functions must have '" - << stringifyLinkage(LLVM::Linkage::External) << "' or '" - << stringifyLinkage(LLVM::Linkage::ExternWeak) << "' linkage"; + if (isExternal()) { + if (getLinkage() != LLVM::Linkage::External && + getLinkage() != LLVM::Linkage::ExternWeak) + return emitOpError() << "external functions must have '" + << stringifyLinkage(LLVM::Linkage::External) + << "' or '" + << stringifyLinkage(LLVM::Linkage::ExternWeak) + << "' linkage"; return success(); } - if (op.isVarArg()) - return op.emitOpError("only external functions can be variadic"); + if (isVarArg()) + return emitOpError("only external functions can be variadic"); - unsigned numArguments = op.getType().getNumParams(); - Block &entryBlock = op.front(); + unsigned numArguments = getType().getNumParams(); + Block &entryBlock = front(); for (unsigned i = 0; i < numArguments; ++i) { Type argType = entryBlock.getArgument(i).getType(); if (!isCompatibleType(argType)) - return op.emitOpError("entry block argument #") + return emitOpError("entry block argument #") << i << " is not of LLVM type"; - if (op.getType().getParamType(i) != argType) - return op.emitOpError("the type of entry block argument #") + if (getType().getParamType(i) != argType) + return emitOpError("the type of entry block argument #") << i << " does not match the function signature"; } @@ -2163,42 +2175,42 @@ static LogicalResult verify(LLVMFuncOp op) { // Verification for LLVM::ConstantOp. //===----------------------------------------------------------------------===// -static LogicalResult verify(LLVM::ConstantOp op) { - if (StringAttr sAttr = op.getValue().dyn_cast()) { - auto arrayType = op.getType().dyn_cast(); +LogicalResult LLVM::ConstantOp::verify() { + if (StringAttr sAttr = getValue().dyn_cast()) { + auto arrayType = getType().dyn_cast(); if (!arrayType || arrayType.getNumElements() != sAttr.getValue().size() || !arrayType.getElementType().isInteger(8)) { - return op->emitOpError() - << "expected array type of " << sAttr.getValue().size() - << " i8 elements for the string constant"; + return emitOpError() << "expected array type of " + << sAttr.getValue().size() + << " i8 elements for the string constant"; } return success(); } - if (auto structType = op.getType().dyn_cast()) { + if (auto structType = getType().dyn_cast()) { if (structType.getBody().size() != 2 || structType.getBody()[0] != structType.getBody()[1]) { - return op.emitError() << "expected struct type with two elements of the " - "same type, the type of a complex constant"; + return emitError() << "expected struct type with two elements of the " + "same type, the type of a complex constant"; } - auto arrayAttr = op.getValue().dyn_cast(); + auto arrayAttr = getValue().dyn_cast(); if (!arrayAttr || arrayAttr.size() != 2 || arrayAttr[0].getType() != arrayAttr[1].getType()) { - return op.emitOpError() << "expected array attribute with two elements, " - "representing a complex constant"; + return emitOpError() << "expected array attribute with two elements, " + "representing a complex constant"; } Type elementType = structType.getBody()[0]; if (!elementType .isa()) { - return op.emitError() + return emitError() << "expected struct element types to be floating point type or " "integer type"; } return success(); } - if (!op.getValue().isa()) - return op.emitOpError() + if (!getValue().isa()) + return emitOpError() << "only supports integer, float, string or elements attributes"; return success(); } @@ -2294,42 +2306,40 @@ static ParseResult parseAtomicRMWOp(OpAsmParser &parser, return success(); } -static LogicalResult verify(AtomicRMWOp op) { - auto ptrType = op.getPtr().getType().cast(); - auto valType = op.getVal().getType(); +LogicalResult AtomicRMWOp::verify() { + auto ptrType = getPtr().getType().cast(); + auto valType = getVal().getType(); if (valType != ptrType.getElementType()) - return op.emitOpError("expected LLVM IR element type for operand #0 to " - "match type for operand #1"); - auto resType = op.getRes().getType(); + return emitOpError("expected LLVM IR element type for operand #0 to " + "match type for operand #1"); + auto resType = getRes().getType(); if (resType != valType) - return op.emitOpError( + return emitOpError( "expected LLVM IR result type to match type for operand #1"); - if (op.getBinOp() == AtomicBinOp::fadd || - op.getBinOp() == AtomicBinOp::fsub) { + if (getBinOp() == AtomicBinOp::fadd || getBinOp() == AtomicBinOp::fsub) { if (!mlir::LLVM::isCompatibleFloatingPointType(valType)) - return op.emitOpError("expected LLVM IR floating point type"); - } else if (op.getBinOp() == AtomicBinOp::xchg) { + return emitOpError("expected LLVM IR floating point type"); + } else if (getBinOp() == AtomicBinOp::xchg) { auto intType = valType.dyn_cast(); unsigned intBitWidth = intType ? intType.getWidth() : 0; if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && !valType.isa() && !valType.isa() && !valType.isa() && !valType.isa()) - return op.emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); + return emitOpError("unexpected LLVM IR type for 'xchg' bin_op"); } else { auto intType = valType.dyn_cast(); unsigned intBitWidth = intType ? intType.getWidth() : 0; if (intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64) - return op.emitOpError("expected LLVM IR integer type"); + return emitOpError("expected LLVM IR integer type"); } - if (static_cast(op.getOrdering()) < + if (static_cast(getOrdering()) < static_cast(AtomicOrdering::monotonic)) - return op.emitOpError() - << "expected at least '" - << stringifyAtomicOrdering(AtomicOrdering::monotonic) - << "' ordering"; + return emitOpError() << "expected at least '" + << stringifyAtomicOrdering(AtomicOrdering::monotonic) + << "' ordering"; return success(); } @@ -2375,28 +2385,28 @@ static ParseResult parseAtomicCmpXchgOp(OpAsmParser &parser, return success(); } -static LogicalResult verify(AtomicCmpXchgOp op) { - auto ptrType = op.getPtr().getType().cast(); +LogicalResult AtomicCmpXchgOp::verify() { + auto ptrType = getPtr().getType().cast(); if (!ptrType) - return op.emitOpError("expected LLVM IR pointer type for operand #0"); - auto cmpType = op.getCmp().getType(); - auto valType = op.getVal().getType(); + return emitOpError("expected LLVM IR pointer type for operand #0"); + auto cmpType = getCmp().getType(); + auto valType = getVal().getType(); if (cmpType != ptrType.getElementType() || cmpType != valType) - return op.emitOpError("expected LLVM IR element type for operand #0 to " - "match type for all other operands"); + return emitOpError("expected LLVM IR element type for operand #0 to " + "match type for all other operands"); auto intType = valType.dyn_cast(); unsigned intBitWidth = intType ? intType.getWidth() : 0; if (!valType.isa() && intBitWidth != 8 && intBitWidth != 16 && intBitWidth != 32 && intBitWidth != 64 && !valType.isa() && !valType.isa() && !valType.isa() && !valType.isa()) - return op.emitOpError("unexpected LLVM IR type"); - if (op.getSuccessOrdering() < AtomicOrdering::monotonic || - op.getFailureOrdering() < AtomicOrdering::monotonic) - return op.emitOpError("ordering must be at least 'monotonic'"); - if (op.getFailureOrdering() == AtomicOrdering::release || - op.getFailureOrdering() == AtomicOrdering::acq_rel) - return op.emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); + return emitOpError("unexpected LLVM IR type"); + if (getSuccessOrdering() < AtomicOrdering::monotonic || + getFailureOrdering() < AtomicOrdering::monotonic) + return emitOpError("ordering must be at least 'monotonic'"); + if (getFailureOrdering() == AtomicOrdering::release || + getFailureOrdering() == AtomicOrdering::acq_rel) + return emitOpError("failure ordering cannot be 'release' or 'acq_rel'"); return success(); } @@ -2432,12 +2442,12 @@ static void printFenceOp(OpAsmPrinter &p, FenceOp &op) { p << stringifyAtomicOrdering(op.getOrdering()); } -static LogicalResult verify(FenceOp &op) { - if (op.getOrdering() == AtomicOrdering::not_atomic || - op.getOrdering() == AtomicOrdering::unordered || - op.getOrdering() == AtomicOrdering::monotonic) - return op.emitOpError("can be given only acquire, release, acq_rel, " - "and seq_cst orderings"); +LogicalResult FenceOp::verify() { + if (getOrdering() == AtomicOrdering::not_atomic || + getOrdering() == AtomicOrdering::unordered || + getOrdering() == AtomicOrdering::monotonic) + return emitOpError("can be given only acquire, release, acq_rel, " + "and seq_cst orderings"); return success(); } diff --git a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp index 811a09aac173..5d5e8f401212 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/NVVMDialect.cpp @@ -62,8 +62,14 @@ static ParseResult parseNVVMVoteBallotOp(OpAsmParser &parser, parser.getNameLoc(), result.operands)); } -static LogicalResult verify(MmaOp op) { - MLIRContext *context = op.getContext(); +LogicalResult CpAsyncOp::verify() { + if (size() != 4 && size() != 8 && size() != 16) + return emitError("expected byte size to be either 4, 8 or 16."); + return success(); +} + +LogicalResult MmaOp::verify() { + MLIRContext *context = getContext(); auto f16Ty = Float16Type::get(context); auto f16x2Ty = LLVM::getFixedVectorType(f16Ty, 2); auto f32Ty = Float32Type::get(context); @@ -72,44 +78,55 @@ static LogicalResult verify(MmaOp op) { auto f32x8StructTy = LLVM::LLVMStructType::getLiteral( context, {f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty}); - SmallVector operandTypes(op.getOperandTypes().begin(), - op.getOperandTypes().end()); + auto operandTypes = getOperandTypes(); if (operandTypes != SmallVector(8, f16x2Ty) && - operandTypes != SmallVector{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, - f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, - f32Ty, f32Ty, f32Ty}) { - return op.emitOpError( - "expected operands to be 4 s followed by either " - "4 s or 8 floats"); + operandTypes != ArrayRef{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty, + f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty}) { + return emitOpError("expected operands to be 4 s followed by either " + "4 s or 8 floats"); } - if (op.getType() != f32x8StructTy && op.getType() != f16x2x4StructTy) { - return op.emitOpError("expected result type to be a struct of either 4 " - "s or 8 floats"); + if (getType() != f32x8StructTy && getType() != f16x2x4StructTy) { + return emitOpError("expected result type to be a struct of either 4 " + "s or 8 floats"); } - auto alayout = op->getAttrOfType("alayout"); - auto blayout = op->getAttrOfType("blayout"); + auto alayout = (*this)->getAttrOfType("alayout"); + auto blayout = (*this)->getAttrOfType("blayout"); if (!(alayout && blayout) || !(alayout.getValue() == "row" || alayout.getValue() == "col") || !(blayout.getValue() == "row" || blayout.getValue() == "col")) { - return op.emitOpError( - "alayout and blayout attributes must be set to either " - "\"row\" or \"col\""); + return emitOpError("alayout and blayout attributes must be set to either " + "\"row\" or \"col\""); } - if (operandTypes == SmallVector{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, - f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, - f32Ty, f32Ty, f32Ty} && - op.getType() == f32x8StructTy && alayout.getValue() == "row" && + if (operandTypes == ArrayRef{f16x2Ty, f16x2Ty, f16x2Ty, f16x2Ty, f32Ty, + f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, f32Ty, + f32Ty} && + getType() == f32x8StructTy && alayout.getValue() == "row" && blayout.getValue() == "col") { return success(); } - return op.emitOpError("unimplemented mma.sync variant"); + return emitOpError("unimplemented mma.sync variant"); } -std::pair -inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) { +LogicalResult ShflOp::verify() { + if (!(*this)->getAttrOfType("return_value_and_is_valid")) + return success(); + auto type = getType().dyn_cast(); + auto elementType = (type && type.getBody().size() == 2) + ? type.getBody()[1].dyn_cast() + : nullptr; + if (!elementType || elementType.getWidth() != 1) + return emitError("expected return type to be a two-element struct with " + "i1 as the second element"); + return success(); +} + +std::pair NVVM::inferMMAType(NVVM::MMATypes type, + NVVM::MMAFrag frag, + MLIRContext *context) { unsigned numberElements = 0; Type elementType; OpBuilder builder(context); @@ -131,76 +148,72 @@ inferMMAType(NVVM::MMATypes type, NVVM::MMAFrag frag, MLIRContext *context) { return std::make_pair(elementType, numberElements); } -static LogicalResult verify(NVVM::WMMALoadOp op) { +LogicalResult NVVM::WMMALoadOp::verify() { unsigned addressSpace = - op.ptr().getType().cast().getAddressSpace(); + ptr().getType().cast().getAddressSpace(); if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3) - return op.emitOpError("expected source pointer in memory " - "space 0, 1, 3"); + return emitOpError("expected source pointer in memory " + "space 0, 1, 3"); - if (NVVM::WMMALoadOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(), - op.eltype(), op.frag()) == 0) - return op.emitOpError() << "invalid attribute combination"; + if (NVVM::WMMALoadOp::getIntrinsicID(m(), n(), k(), layout(), eltype(), + frag()) == 0) + return emitOpError() << "invalid attribute combination"; std::pair typeInfo = - inferMMAType(op.eltype(), op.frag(), op.getContext()); + inferMMAType(eltype(), frag(), getContext()); Type dstType = LLVM::LLVMStructType::getLiteral( - op.getContext(), SmallVector(typeInfo.second, typeInfo.first)); - if (op.getType() != dstType) - return op.emitOpError("expected destination type is a structure of ") + getContext(), SmallVector(typeInfo.second, typeInfo.first)); + if (getType() != dstType) + return emitOpError("expected destination type is a structure of ") << typeInfo.second << " elements of type " << typeInfo.first; return success(); } -static LogicalResult verify(NVVM::WMMAStoreOp op) { +LogicalResult NVVM::WMMAStoreOp::verify() { unsigned addressSpace = - op.ptr().getType().cast().getAddressSpace(); + ptr().getType().cast().getAddressSpace(); if (addressSpace != 0 && addressSpace != 1 && addressSpace != 3) - return op.emitOpError("expected operands to be a source pointer in memory " - "space 0, 1, 3"); + return emitOpError("expected operands to be a source pointer in memory " + "space 0, 1, 3"); - if (NVVM::WMMAStoreOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layout(), - op.eltype()) == 0) - return op.emitOpError() << "invalid attribute combination"; + if (NVVM::WMMAStoreOp::getIntrinsicID(m(), n(), k(), layout(), eltype()) == 0) + return emitOpError() << "invalid attribute combination"; std::pair typeInfo = - inferMMAType(op.eltype(), NVVM::MMAFrag::c, op.getContext()); - if (op.args().size() != typeInfo.second) - return op.emitOpError() - << "expected " << typeInfo.second << " data operands"; - if (llvm::any_of(op.args(), [&typeInfo](Value operands) { + inferMMAType(eltype(), NVVM::MMAFrag::c, getContext()); + if (args().size() != typeInfo.second) + return emitOpError() << "expected " << typeInfo.second << " data operands"; + if (llvm::any_of(args(), [&typeInfo](Value operands) { return operands.getType() != typeInfo.first; })) - return op.emitOpError() - << "expected data operands of type " << typeInfo.first; + return emitOpError() << "expected data operands of type " << typeInfo.first; return success(); } -static LogicalResult verify(NVVM::WMMAMmaOp op) { - if (NVVM::WMMAMmaOp::getIntrinsicID(op.m(), op.n(), op.k(), op.layoutA(), - op.layoutB(), op.eltypeA(), - op.eltypeB()) == 0) - return op.emitOpError() << "invalid attribute combination"; +LogicalResult NVVM::WMMAMmaOp::verify() { + if (NVVM::WMMAMmaOp::getIntrinsicID(m(), n(), k(), layoutA(), layoutB(), + eltypeA(), eltypeB()) == 0) + return emitOpError() << "invalid attribute combination"; std::pair typeInfoA = - inferMMAType(op.eltypeA(), NVVM::MMAFrag::a, op.getContext()); + inferMMAType(eltypeA(), NVVM::MMAFrag::a, getContext()); std::pair typeInfoB = - inferMMAType(op.eltypeA(), NVVM::MMAFrag::b, op.getContext()); + inferMMAType(eltypeA(), NVVM::MMAFrag::b, getContext()); std::pair typeInfoC = - inferMMAType(op.eltypeB(), NVVM::MMAFrag::c, op.getContext()); + inferMMAType(eltypeB(), NVVM::MMAFrag::c, getContext()); SmallVector arguments; arguments.append(typeInfoA.second, typeInfoA.first); arguments.append(typeInfoB.second, typeInfoB.first); arguments.append(typeInfoC.second, typeInfoC.first); unsigned numArgs = arguments.size(); - if (op.args().size() != numArgs) - return op.emitOpError() << "expected " << numArgs << " arguments"; + if (args().size() != numArgs) + return emitOpError() << "expected " << numArgs << " arguments"; for (unsigned i = 0; i < numArgs; i++) { - if (op.args()[i].getType() != arguments[i]) - return op.emitOpError() - << "expected argument " << i << " to be of type " << arguments[i]; + if (args()[i].getType() != arguments[i]) + return emitOpError() << "expected argument " << i << " to be of type " + << arguments[i]; } Type dstType = LLVM::LLVMStructType::getLiteral( - op.getContext(), SmallVector(typeInfoC.second, typeInfoC.first)); - if (op.getType() != dstType) - return op.emitOpError("expected destination type is a structure of ") + getContext(), SmallVector(typeInfoC.second, typeInfoC.first)); + if (getType() != dstType) + return emitOpError("expected destination type is a structure of ") << typeInfoC.second << " elements of type " << typeInfoC.first; return success(); } diff --git a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp index ee7b0580cc48..7b70e53a6e9c 100644 --- a/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp +++ b/mlir/lib/Dialect/X86Vector/IR/X86VectorDialect.cpp @@ -28,17 +28,15 @@ void x86vector::X86VectorDialect::initialize() { >(); } -static LogicalResult verify(x86vector::MaskCompressOp op) { - if (op.src() && op.constant_src()) - return emitError(op.getLoc(), "cannot use both src and constant_src"); +LogicalResult x86vector::MaskCompressOp::verify() { + if (src() && constant_src()) + return emitError("cannot use both src and constant_src"); - if (op.src() && (op.src().getType() != op.dst().getType())) - return emitError(op.getLoc(), - "failed to verify that src and dst have same type"); + if (src() && (src().getType() != dst().getType())) + return emitError("failed to verify that src and dst have same type"); - if (op.constant_src() && (op.constant_src()->getType() != op.dst().getType())) + if (constant_src() && (constant_src()->getType() != dst().getType())) return emitError( - op.getLoc(), "failed to verify that constant_src and dst have same type"); return success();