From f7d85f010f0963b828875894dc56298e84ff9031 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Wed, 2 Feb 2022 10:06:30 -0800 Subject: [PATCH] [mlir][NFC] Update SPIRV operations to use `hasVerifier` instead of `verifier` The verifier field is deprecated, and slated for removal. Differential Revision: https://reviews.llvm.org/D118817 --- mlir/docs/Dialects/SPIR-V.md | 4 +- .../mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td | 5 - .../mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td | 4 - .../mlir/Dialect/SPIRV/IR/SPIRVBase.td | 8 +- .../mlir/Dialect/SPIRV/IR/SPIRVBitOps.td | 6 +- .../mlir/Dialect/SPIRV/IR/SPIRVCastOps.td | 15 - .../Dialect/SPIRV/IR/SPIRVCompositeOps.td | 4 +- .../Dialect/SPIRV/IR/SPIRVControlFlowOps.td | 2 +- .../SPIRV/IR/SPIRVCooperativeMatrixOps.td | 14 +- .../mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td | 12 +- .../mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td | 2 +- .../mlir/Dialect/SPIRV/IR/SPIRVImageOps.td | 7 +- .../mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td | 6 - .../mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td | 2 - .../mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td | 2 +- .../Dialect/SPIRV/IR/SPIRVNonUniformOps.td | 2 - .../mlir/Dialect/SPIRV/IR/SPIRVOCLOps.td | 4 +- .../Dialect/SPIRV/IR/SPIRVStructureOps.td | 6 +- mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp | 923 +++++++++++------- 19 files changed, 618 insertions(+), 410 deletions(-) diff --git a/mlir/docs/Dialects/SPIR-V.md b/mlir/docs/Dialects/SPIR-V.md index e84a50346528..623a6e5e0083 100644 --- a/mlir/docs/Dialects/SPIR-V.md +++ b/mlir/docs/Dialects/SPIR-V.md @@ -1309,7 +1309,7 @@ sometimes we need to manually write additional verification logic in [`SPIRVOps.cpp`][MlirSpirvOpsCpp] in a function with the following signature: ```c++ -static LogicalResult verify(spirv::Op op); +LogicalResult spirv::Op::verify(); ``` See any such function in [`SPIRVOps.cpp`][MlirSpirvOpsCpp] as an example. @@ -1318,7 +1318,7 @@ If no additional verification is needed, one needs to add the following to the op's Op Definition Spec: ``` -let verifier = [{ return success(); }]; +let hasVerifier = 0; ``` To suppress the requirement of the above C++ verification function. diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td index a1ab7834b7ed..d15de160ecbd 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVAtomicOps.td @@ -18,7 +18,6 @@ class SPV_AtomicUpdateOp traits = []> : SPV_Op { let parser = [{ return ::parseAtomicUpdateOp(parser, result, false); }]; let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }]; - let verifier = [{ return ::verifyAtomicUpdateOp(getOperation()); }]; let arguments = (ins SPV_AnyPtr:$pointer, @@ -35,7 +34,6 @@ class SPV_AtomicUpdateWithValueOp traits = []> : SPV_Op { let parser = [{ return ::parseAtomicUpdateOp(parser, result, true); }]; let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }]; - let verifier = [{ return ::verifyAtomicUpdateOp(getOperation()); }]; let arguments = (ins SPV_AnyPtr:$pointer, @@ -168,7 +166,6 @@ def SPV_AtomicCompareExchangeOp : SPV_Op<"AtomicCompareExchange", []> { let parser = [{ return ::parseAtomicCompareExchangeImpl(parser, result); }]; let printer = [{ return ::printAtomicCompareExchangeImpl(*this, p); }]; - let verifier = [{ return ::verifyAtomicCompareExchangeImpl(*this); }]; } // ----- @@ -221,7 +218,6 @@ def SPV_AtomicCompareExchangeWeakOp : SPV_Op<"AtomicCompareExchangeWeak", []> { let parser = [{ return ::parseAtomicCompareExchangeImpl(parser, result); }]; let printer = [{ return ::printAtomicCompareExchangeImpl(*this, p); }]; - let verifier = [{ return ::verifyAtomicCompareExchangeImpl(*this); }]; } // ----- @@ -338,7 +334,6 @@ def SPV_AtomicFAddEXTOp : SPV_Op<"AtomicFAddEXT", []> { let parser = [{ return ::parseAtomicUpdateOp(parser, result, true); }]; let printer = [{ return ::printAtomicUpdateOp(getOperation(), p); }]; - let verifier = [{ return ::verifyAtomicUpdateOp(getOperation()); }]; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td index 79e4ef55014c..e00d8fa0c842 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBarrierOps.td @@ -77,8 +77,6 @@ def SPV_ControlBarrierOp : SPV_Op<"ControlBarrier", []> { let results = (outs); - let verifier = [{ return verifyMemorySemantics(getOperation(), memory_semantics()); }]; - let autogenSerialization = 0; let assemblyFormat = [{ @@ -131,8 +129,6 @@ def SPV_MemoryBarrierOp : SPV_Op<"MemoryBarrier", []> { let results = (outs); - let verifier = [{ return verifyMemorySemantics(getOperation(), memory_semantics()); }]; - let autogenSerialization = 0; let assemblyFormat = "$memory_scope `,` $memory_semantics attr-dict"; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td index ca4a605f96d2..625e44681cae 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVBase.td @@ -4272,10 +4272,10 @@ class SPV_Op traits = []> : // * static ParseResult parse(OpAsmParser &parser, // OperationState &result) // * static void print(OpAsmPrinter &p, op) - // * static LogicalResult verify( op) + // * LogicalResult ::verify() let parser = [{ return ::parse$cppClass(parser, result); }]; let printer = [{ return ::print(*this, p); }]; - let verifier = [{ return ::verify(*this); }]; + let hasVerifier = 1; // Specifies whether this op has a direct corresponding SPIR-V binary // instruction opcode. The (de)serializer use this field to determine whether @@ -4323,7 +4323,7 @@ class SPV_UnaryOp traits = []> : SPV_ScalarOrVectorOf:$result ); - let verifier = [{ return success(); }]; + let hasVerifier = 0; let assemblyFormat = [{ operands attr-dict `:` type($base) `,` type($offset) `,` type($count) @@ -54,7 +54,7 @@ class SPV_ShiftOp traits = []> : [NoSideEffect, SameOperandsAndResultShape])> { let parser = [{ return ::parseShiftOp(parser, result); }]; let printer = [{ ::printShiftOp(this->getOperation(), p); }]; - let verifier = [{ return ::verifyShiftOp(this->getOperation()); }]; + let hasVerifier = 1; } // ----- @@ -162,7 +162,7 @@ def SPV_BitFieldInsertOp : SPV_Op<"BitFieldInsert", SPV_ScalarOrVectorOf:$result ); - let verifier = [{ return success(); }]; + let hasVerifier = 0; let assemblyFormat = [{ operands attr-dict `:` type($base) `,` type($offset) `,` type($count) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td index dc3c1189b3fd..ec2c73f03ac2 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCastOps.td @@ -32,7 +32,6 @@ class SPV_CastOpgetOperation(), p); }]; - let verifier = [{ return verifyCastOp(this->getOperation()); }]; } // ----- @@ -122,8 +121,6 @@ def SPV_ConvertFToSOp : SPV_CastOp<"ConvertFToS", SPV_Integer, SPV_Float, []> { %3 = spv.ConvertFToS %2 : vector<3xf32> to vector<3xi32> ``` }]; - - let verifier = [{ return verifyCastOp(this->getOperation(), false, true); }]; } // ----- @@ -157,8 +154,6 @@ def SPV_ConvertFToUOp : SPV_CastOp<"ConvertFToU", SPV_Integer, SPV_Float, []> { %3 = spv.ConvertFToU %2 : vector<3xf32> to vector<3xi32> ``` }]; - - let verifier = [{ return verifyCastOp(this->getOperation(), false, true); }]; } // ----- @@ -193,8 +188,6 @@ def SPV_ConvertSToFOp : SPV_CastOp<"ConvertSToF", %3 = spv.ConvertSToF %2 : vector<3xi32> to vector<3xf32> ``` }]; - - let verifier = [{ return verifyCastOp(this->getOperation(), false, true); }]; } // ----- @@ -229,8 +222,6 @@ def SPV_ConvertUToFOp : SPV_CastOp<"ConvertUToF", %3 = spv.ConvertUToF %2 : vector<3xi32> to vector<3xf32> ``` }]; - - let verifier = [{ return verifyCastOp(this->getOperation(), false, true); }]; } // ----- @@ -267,8 +258,6 @@ def SPV_FConvertOp : SPV_CastOp<"FConvert", %3 = spv.FConvertOp %2 : vector<3xf32> to vector<3xf64> ``` }]; - - let verifier = [{ return verifyCastOp(this->getOperation(), false); }]; } // ----- @@ -304,8 +293,6 @@ def SPV_SConvertOp : SPV_CastOp<"SConvert", %3 = spv.SConvertOp %2 : vector<3xi32> to vector<3xi64> ``` }]; - - let verifier = [{ return verifyCastOp(this->getOperation(), false); }]; } // ----- @@ -342,8 +329,6 @@ def SPV_UConvertOp : SPV_CastOp<"UConvert", %3 = spv.UConvertOp %2 : vector<3xi32> to vector<3xi64> ``` }]; - - let verifier = [{ return verifyCastOp(this->getOperation(), false); }]; } #endif // MLIR_DIALECT_SPIRV_IR_CAST_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td index a1daf12c4773..dd6ee93b4bae 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCompositeOps.td @@ -210,7 +210,7 @@ def SPV_VectorExtractDynamicOp : SPV_Op<"VectorExtractDynamic", [ SPV_Scalar:$result ); - let verifier = [{ return success(); }]; + let hasVerifier = 0; let assemblyFormat = [{ $vector `[` $index `]` attr-dict `:` type($vector) `,` type($index) @@ -274,7 +274,7 @@ def SPV_VectorInsertDynamicOp : SPV_Op<"VectorInsertDynamic", [ SPV_Vector:$result ); - let verifier = [{ return success(); }]; + let hasVerifier = 0; let assemblyFormat = [{ $component `,` $vector `[` $index `]` attr-dict `:` type($vector) `,` type($index) diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td index 0d8f6d120f01..b65843214caa 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVControlFlowOps.td @@ -51,7 +51,7 @@ def SPV_BranchOp : SPV_Op<"Branch", [ let successors = (successor AnySuccessor:$target); - let verifier = [{ return success(); }]; + let hasVerifier = 0; let builders = [ OpBuilder<(ins "Block *":$successor, CArg<"ValueRange", "{}">:$arguments), diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td index c64e8d4e9221..8a6e31f4e96c 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVCooperativeMatrixOps.td @@ -55,7 +55,7 @@ def SPV_CooperativeMatrixLengthNVOp : SPV_Op<"CooperativeMatrixLengthNV", let results = (outs SPV_Int32:$result ); - let verifier = [{ return success(); }]; + let hasVerifier = 0; } // ----- @@ -132,11 +132,6 @@ def SPV_CooperativeMatrixLoadNVOp : SPV_Op<"CooperativeMatrixLoadNV", []> { let results = (outs SPV_AnyCooperativeMatrix:$result ); - - let verifier = [{ - return verifyPointerAndCoopMatrixType(*this, pointer().getType(), - result().getType()); - }]; } // ----- @@ -211,8 +206,6 @@ def SPV_CooperativeMatrixMulAddNVOp : SPV_Op<"CooperativeMatrixMulAddNV", let results = (outs SPV_AnyCooperativeMatrix:$result ); - - let verifier = [{ return verifyCoopMatrixMulAdd(*this); }]; } // ----- @@ -274,11 +267,6 @@ def SPV_CooperativeMatrixStoreNVOp : SPV_Op<"CooperativeMatrixStoreNV", []> { ); let results = (outs); - - let verifier = [{ - return verifyPointerAndCoopMatrixType(*this, pointer().getType(), - object().getType()); - }]; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td index e3535e48c172..1532d3ea7133 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGLSLOps.td @@ -49,7 +49,7 @@ class SPV_GLSLUnaryOp { let assemblyFormat = [{ attr-dict $operand `:` type($operand) `->` type($result) }]; - - let verifier = [{ return ::verifyGLSLFrexpStructOp(*this); }]; } def SPV_GLSLLdexpOp : @@ -1187,8 +1185,6 @@ def SPV_GLSLLdexpOp : let assemblyFormat = [{ attr-dict $x `:` type($x) `,` $exp `:` type($exp) `->` type($y) }]; - - let verifier = [{ return ::verify(*this); }]; } def SPV_GLSLFMixOp : @@ -1227,7 +1223,7 @@ def SPV_GLSLFMixOp : attr-dict $x `:` type($x) `,` $y `:` type($y) `,` $a `:` type($a) `->` type($result) }]; - let verifier = [{ return success(); }]; + let hasVerifier = 0; } #endif // MLIR_DIALECT_SPIRV_IR_GLSL_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td index 02643e19304e..ef3868914304 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVGroupOps.td @@ -139,7 +139,7 @@ def SPV_SubgroupBallotKHROp : SPV_Op<"SubgroupBallotKHR", []> { SPV_Int32Vec4:$result ); - let verifier = [{ return success(); }]; + let hasVerifier = 0; let assemblyFormat = "$predicate attr-dict `:` type($result)"; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td index f987b962fabc..636c0d1288a1 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVImageOps.td @@ -82,8 +82,6 @@ def SPV_ImageDrefGatherOp : SPV_Op<"ImageDrefGather", [NoSideEffect]> { ( `(` $operand_arguments^ `:` type($operand_arguments) `)`)? attr-dict `->` type($result)}]; - - let verifier = [{ return ::verify(*this); }]; } // ----- @@ -141,8 +139,6 @@ def SPV_ImageQuerySizeOp : SPV_Op<"ImageQuerySize", [NoSideEffect]> { ); let assemblyFormat = "attr-dict $image `:` type($image) `->` type($result)"; - - let verifier = [{return ::verify(*this);}]; } // ----- @@ -179,8 +175,7 @@ def SPV_ImageOp : SPV_Op<"Image", let assemblyFormat = "attr-dict $sampledimage `:` type($sampledimage)"; - let verifier = ?; - + let hasVerifier = 0; } #endif // MLIR_DIALECT_SPIRV_IR_GLSL_OPS diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td index b55afaa7fe69..ab9077b82d05 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMatrixOps.td @@ -66,8 +66,6 @@ def SPV_MatrixTimesMatrixOp : SPV_Op<"MatrixTimesMatrix", [NoSideEffect]> { let assemblyFormat = [{ operands attr-dict `:` type($leftmatrix) `,` type($rightmatrix) `->` type($result) }]; - - let verifier = [{ return verifyMatrixTimesMatrix(*this); }]; } // ----- @@ -130,8 +128,6 @@ def SPV_MatrixTimesScalarOp : SPV_Op<"MatrixTimesScalar", [NoSideEffect]> { Extension<[]>, Capability<[SPV_C_Matrix]> ]; - - let verifier = [{ return verifyMatrixTimesScalar(*this); }]; } // ----- @@ -184,8 +180,6 @@ def SPV_TransposeOp : SPV_Op<"Transpose", [NoSideEffect]> { let assemblyFormat = [{ operands attr-dict `:` type($matrix) `->` type($result) }]; - - let verifier = [{ return verifyTranspose(*this); }]; } // ----- diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td index 63cd3fe0213d..e69ae9c86af8 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMemoryOps.td @@ -130,8 +130,6 @@ def SPV_CopyMemoryOp : SPV_Op<"CopyMemory", []> { let results = (outs); - let verifier = [{ return verifyCopyMemory(*this); }]; - let autogenSerialization = 0; } diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td index 498b28c5f78a..7b4f7a24d9bd 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVMiscOps.td @@ -48,7 +48,7 @@ def SPV_UndefOp : SPV_Op<"Undef", []> { SPV_Type:$result ); - let verifier = [{ return success(); }]; + let hasVerifier = 0; let hasOpcode = 0; let autogenSerialization = 0; diff --git a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td index 68d0be3f2e39..1077e8a802a8 100644 --- a/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/IR/SPIRVNonUniformOps.td @@ -30,8 +30,6 @@ class SPV_GroupNonUniformArithmeticOp { let results = (outs); - let verifier = [{ return success(); }]; + let hasVerifier = 0; let autogenSerialization = 0; @@ -296,7 +296,7 @@ def SPV_FuncOp : SPV_Op<"func", [ let regions = (region AnyRegion:$body); - let verifier = [{ return success(); }]; + let hasVerifier = 0; let builders = [ OpBuilder<(ins "StringRef":$name, "FunctionType":$type, @@ -788,7 +788,7 @@ def SPV_YieldOp : SPV_Op<"mlir.yield", [ let assemblyFormat = "attr-dict $operand `:` type($operand)"; - let verifier = [{ return success(); }]; + let hasVerifier = 0; } // ----- diff --git a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp index ff81086d5f57..b45b422c3c6f 100644 --- a/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/IR/SPIRVOps.cpp @@ -1125,8 +1125,8 @@ static LogicalResult verifyAccessChain(Op accessChainOp, ValueRange indices) { return success(); } -static LogicalResult verify(spirv::AccessChainOp accessChainOp) { - return verifyAccessChain(accessChainOp, accessChainOp.indices()); +LogicalResult spirv::AccessChainOp::verify() { + return verifyAccessChain(*this, indices()); } //===----------------------------------------------------------------------===// @@ -1138,15 +1138,15 @@ void spirv::AddressOfOp::build(OpBuilder &builder, OperationState &state, build(builder, state, var.type(), SymbolRefAttr::get(var)); } -static LogicalResult verify(spirv::AddressOfOp addressOfOp) { +LogicalResult spirv::AddressOfOp::verify() { auto varOp = dyn_cast_or_null( - SymbolTable::lookupNearestSymbolFrom(addressOfOp->getParentOp(), - addressOfOp.variableAttr())); + SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), + variableAttr())); if (!varOp) { - return addressOfOp.emitOpError("expected spv.GlobalVariable symbol"); + return emitOpError("expected spv.GlobalVariable symbol"); } - if (addressOfOp.pointer().getType() != varOp.type()) { - return addressOfOp.emitOpError( + if (pointer().getType() != varOp.type()) { + return emitOpError( "result type mismatch with the referenced global variable's type"); } return success(); @@ -1224,6 +1224,30 @@ static LogicalResult verifyAtomicCompareExchangeImpl(T atomOp) { return success(); } +//===----------------------------------------------------------------------===// +// spv.AtomicAndOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicAndOp::verify() { + return ::verifyAtomicUpdateOp(getOperation()); +} + +//===----------------------------------------------------------------------===// +// spv.AtomicCompareExchangeOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicCompareExchangeOp::verify() { + return ::verifyAtomicCompareExchangeImpl(*this); +} + +//===----------------------------------------------------------------------===// +// spv.AtomicCompareExchangeWeakOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicCompareExchangeWeakOp::verify() { + return ::verifyAtomicCompareExchangeImpl(*this); +} + //===----------------------------------------------------------------------===// // spv.AtomicExchange //===----------------------------------------------------------------------===// @@ -1260,50 +1284,136 @@ static ParseResult parseAtomicExchangeOp(OpAsmParser &parser, return parser.addTypeToList(ptrType.getPointeeType(), state.types); } -static LogicalResult verify(spirv::AtomicExchangeOp atomOp) { - if (atomOp.getType() != atomOp.value().getType()) - return atomOp.emitOpError("value operand must have the same type as the op " - "result, but found ") - << atomOp.value().getType() << " vs " << atomOp.getType(); +LogicalResult spirv::AtomicExchangeOp::verify() { + if (getType() != value().getType()) + return emitOpError("value operand must have the same type as the op " + "result, but found ") + << value().getType() << " vs " << getType(); Type pointeeType = - atomOp.pointer().getType().cast().getPointeeType(); - if (atomOp.getType() != pointeeType) - return atomOp.emitOpError( - "pointer operand's pointee type must have the same " - "as the op result type, but found ") - << pointeeType << " vs " << atomOp.getType(); + pointer().getType().cast().getPointeeType(); + if (getType() != pointeeType) + return emitOpError("pointer operand's pointee type must have the same " + "as the op result type, but found ") + << pointeeType << " vs " << getType(); return success(); } +//===----------------------------------------------------------------------===// +// spv.AtomicFAddEXTOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicIAddOp::verify() { + return ::verifyAtomicUpdateOp(getOperation()); +} + +//===----------------------------------------------------------------------===// +// spv.AtomicFAddEXTOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicFAddEXTOp::verify() { + return ::verifyAtomicUpdateOp(getOperation()); +} + +//===----------------------------------------------------------------------===// +// spv.AtomicIDecrementOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicIDecrementOp::verify() { + return ::verifyAtomicUpdateOp(getOperation()); +} + +//===----------------------------------------------------------------------===// +// spv.AtomicIIncrementOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicIIncrementOp::verify() { + return ::verifyAtomicUpdateOp(getOperation()); +} + +//===----------------------------------------------------------------------===// +// spv.AtomicISubOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicISubOp::verify() { + return ::verifyAtomicUpdateOp(getOperation()); +} + +//===----------------------------------------------------------------------===// +// spv.AtomicOrOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicOrOp::verify() { + return ::verifyAtomicUpdateOp(getOperation()); +} + +//===----------------------------------------------------------------------===// +// spv.AtomicSMaxOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicSMaxOp::verify() { + return ::verifyAtomicUpdateOp(getOperation()); +} + +//===----------------------------------------------------------------------===// +// spv.AtomicSMinOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicSMinOp::verify() { + return ::verifyAtomicUpdateOp(getOperation()); +} + +//===----------------------------------------------------------------------===// +// spv.AtomicUMaxOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicUMaxOp::verify() { + return ::verifyAtomicUpdateOp(getOperation()); +} + +//===----------------------------------------------------------------------===// +// spv.AtomicUMinOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicUMinOp::verify() { + return ::verifyAtomicUpdateOp(getOperation()); +} + +//===----------------------------------------------------------------------===// +// spv.AtomicXorOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::AtomicXorOp::verify() { + return ::verifyAtomicUpdateOp(getOperation()); +} + //===----------------------------------------------------------------------===// // spv.BitcastOp //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::BitcastOp bitcastOp) { +LogicalResult spirv::BitcastOp::verify() { // TODO: The SPIR-V spec validation rules are different for different // versions. - auto operandType = bitcastOp.operand().getType(); - auto resultType = bitcastOp.result().getType(); + auto operandType = operand().getType(); + auto resultType = result().getType(); if (operandType == resultType) { - return bitcastOp.emitError( - "result type must be different from operand type"); + return emitError("result type must be different from operand type"); } if (operandType.isa() && !resultType.isa()) { - return bitcastOp.emitError( + return emitError( "unhandled bit cast conversion from pointer type to non-pointer type"); } if (!operandType.isa() && resultType.isa()) { - return bitcastOp.emitError( + return emitError( "unhandled bit cast conversion from non-pointer type to pointer type"); } auto operandBitWidth = getBitWidth(operandType); auto resultBitWidth = getBitWidth(resultType); if (operandBitWidth != resultBitWidth) { - return bitcastOp.emitOpError("mismatch in result type bitwidth ") + return emitOpError("mismatch in result type bitwidth ") << resultBitWidth << " and operand type bitwidth " << operandBitWidth; } @@ -1401,15 +1511,15 @@ static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter &printer) { branchOp.getFalseBlockArguments()); } -static LogicalResult verify(spirv::BranchConditionalOp branchOp) { - if (auto weights = branchOp.branch_weights()) { +LogicalResult spirv::BranchConditionalOp::verify() { + if (auto weights = branch_weights()) { if (weights->getValue().size() != 2) { - return branchOp.emitOpError("must have exactly two branch weights"); + return emitOpError("must have exactly two branch weights"); } if (llvm::all_of(*weights, [](Attribute attr) { return attr.cast().getValue().isNullValue(); })) - return branchOp.emitOpError("branch weights cannot both be zero"); + return emitOpError("branch weights cannot both be zero"); } return success(); @@ -1459,26 +1569,23 @@ static void print(spirv::CompositeConstructOp compositeConstructOp, << compositeConstructOp.getResult().getType(); } -static LogicalResult verify(spirv::CompositeConstructOp compositeConstructOp) { - auto cType = compositeConstructOp.getType().cast(); - SmallVector constituents(compositeConstructOp.constituents()); +LogicalResult spirv::CompositeConstructOp::verify() { + auto cType = getType().cast(); + operand_range constituents = this->constituents(); if (cType.isa()) { if (constituents.size() != 1) - return compositeConstructOp.emitError( - "has incorrect number of operands: expected ") + return emitError("has incorrect number of operands: expected ") << "1, but provided " << constituents.size(); } else if (constituents.size() != cType.getNumElements()) { - return compositeConstructOp.emitError( - "has incorrect number of operands: expected ") + return emitError("has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided " << constituents.size(); } for (auto index : llvm::seq(0, constituents.size())) { if (constituents[index].getType() != cType.getElementType(index)) { - return compositeConstructOp.emitError( - "operand type mismatch: expected operand type ") + return emitError("operand type mismatch: expected operand type ") << cType.getElementType(index) << ", but provided " << constituents[index].getType(); } @@ -1534,16 +1641,16 @@ static void print(spirv::CompositeExtractOp compositeExtractOp, << compositeExtractOp.composite().getType(); } -static LogicalResult verify(spirv::CompositeExtractOp compExOp) { - auto indicesArrayAttr = compExOp.indices().dyn_cast(); - auto resultType = getElementType(compExOp.composite().getType(), - indicesArrayAttr, compExOp.getLoc()); +LogicalResult spirv::CompositeExtractOp::verify() { + auto indicesArrayAttr = indices().dyn_cast(); + auto resultType = + getElementType(composite().getType(), indicesArrayAttr, getLoc()); if (!resultType) return failure(); - if (resultType != compExOp.getType()) { - return compExOp.emitOpError("invalid result type: expected ") - << resultType << " but provided " << compExOp.getType(); + if (resultType != getType()) { + return emitOpError("invalid result type: expected ") + << resultType << " but provided " << getType(); } return success(); @@ -1577,25 +1684,22 @@ static ParseResult parseCompositeInsertOp(OpAsmParser &parser, parser.addTypesToList(compositeType, state.types)); } -static LogicalResult verify(spirv::CompositeInsertOp compositeInsertOp) { - auto indicesArrayAttr = compositeInsertOp.indices().dyn_cast(); +LogicalResult spirv::CompositeInsertOp::verify() { + auto indicesArrayAttr = indices().dyn_cast(); auto objectType = - getElementType(compositeInsertOp.composite().getType(), indicesArrayAttr, - compositeInsertOp.getLoc()); + getElementType(composite().getType(), indicesArrayAttr, getLoc()); if (!objectType) return failure(); - if (objectType != compositeInsertOp.object().getType()) { - return compositeInsertOp.emitOpError("object operand type should be ") - << objectType << ", but found " - << compositeInsertOp.object().getType(); + if (objectType != object().getType()) { + return emitOpError("object operand type should be ") + << objectType << ", but found " << object().getType(); } - if (compositeInsertOp.composite().getType() != compositeInsertOp.getType()) { - return compositeInsertOp.emitOpError("result type should be the same as " - "the composite type, but found ") - << compositeInsertOp.composite().getType() << " vs " - << compositeInsertOp.getType(); + if (composite().getType() != getType()) { + return emitOpError("result type should be the same as " + "the composite type, but found ") + << composite().getType() << " vs " << getType(); } return success(); @@ -1633,9 +1737,9 @@ static void print(spirv::ConstantOp constOp, OpAsmPrinter &printer) { printer << " : " << constOp.getType(); } -static LogicalResult verify(spirv::ConstantOp constOp) { - auto opType = constOp.getType(); - auto value = constOp.value(); +LogicalResult spirv::ConstantOp::verify() { + auto opType = getType(); + auto value = valueAttr(); auto valueType = value.getType(); // ODS already generates checks to make sure the result type is valid. We just @@ -1643,7 +1747,7 @@ static LogicalResult verify(spirv::ConstantOp constOp) { // with the result type. if (value.isa()) { if (valueType != opType) - return constOp.emitOpError("result type (") + return emitOpError("result type (") << opType << ") does not match value type (" << valueType << ")"; return success(); } @@ -1652,10 +1756,8 @@ static LogicalResult verify(spirv::ConstantOp constOp) { return success(); auto arrayType = opType.dyn_cast(); auto shapedType = valueType.dyn_cast(); - if (!arrayType) { - return constOp.emitOpError( - "must have spv.array result type for array value"); - } + if (!arrayType) + return emitOpError("must have spv.array result type for array value"); int numElements = arrayType.getNumElements(); auto opElemType = arrayType.getElementType(); @@ -1664,17 +1766,17 @@ static LogicalResult verify(spirv::ConstantOp constOp) { opElemType = t.getElementType(); } if (!opElemType.isIntOrFloat()) - return constOp.emitOpError("only support nested array result type"); + return emitOpError("only support nested array result type"); auto valueElemType = shapedType.getElementType(); if (valueElemType != opElemType) { - return constOp.emitOpError("result element type (") + return emitOpError("result element type (") << opElemType << ") does not match value element type (" << valueElemType << ")"; } if (numElements != shapedType.getNumElements()) { - return constOp.emitOpError("result number of elements (") + return emitOpError("result number of elements (") << numElements << ") does not match value number of elements (" << shapedType.getNumElements() << ")"; } @@ -1683,19 +1785,18 @@ static LogicalResult verify(spirv::ConstantOp constOp) { if (auto attayAttr = value.dyn_cast()) { auto arrayType = opType.dyn_cast(); if (!arrayType) - return constOp.emitOpError( - "must have spv.array result type for array value"); + return emitOpError("must have spv.array result type for array value"); Type elemType = arrayType.getElementType(); for (Attribute element : attayAttr.getValue()) { if (element.getType() != elemType) - return constOp.emitOpError("has array element whose type (") + return emitOpError("has array element whose type (") << element.getType() << ") does not match the result element type (" << elemType << ')'; } return success(); } - return constOp.emitOpError("cannot have value of type ") << valueType; + return emitOpError("cannot have value of type ") << valueType; } bool spirv::ConstantOp::isBuildableWith(Type type) { @@ -1825,6 +1926,50 @@ void mlir::spirv::AddressOfOp::getAsmResultNames( setNameFn(getResult(), specialName.str()); } +//===----------------------------------------------------------------------===// +// spv.ControlBarrierOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::ControlBarrierOp::verify() { + return verifyMemorySemantics(getOperation(), memory_semantics()); +} + +//===----------------------------------------------------------------------===// +// spv.ConvertFToSOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::ConvertFToSOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false, + /*skipBitWidthCheck=*/true); +} + +//===----------------------------------------------------------------------===// +// spv.ConvertFToUOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::ConvertFToUOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false, + /*skipBitWidthCheck=*/true); +} + +//===----------------------------------------------------------------------===// +// spv.ConvertSToFOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::ConvertSToFOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false, + /*skipBitWidthCheck=*/true); +} + +//===----------------------------------------------------------------------===// +// spv.ConvertUToFOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::ConvertUToFOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false, + /*skipBitWidthCheck=*/true); +} + //===----------------------------------------------------------------------===// // spv.EntryPoint //===----------------------------------------------------------------------===// @@ -1880,7 +2025,7 @@ static void print(spirv::EntryPointOp entryPointOp, OpAsmPrinter &printer) { } } -static LogicalResult verify(spirv::EntryPointOp entryPointOp) { +LogicalResult spirv::EntryPointOp::verify() { // Checks for fn and interface symbol reference are done in spirv::ModuleOp // verification. return success(); @@ -1937,6 +2082,30 @@ static void print(spirv::ExecutionModeOp execModeOp, OpAsmPrinter &printer) { }); } +//===----------------------------------------------------------------------===// +// spv.FConvertOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::FConvertOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false); +} + +//===----------------------------------------------------------------------===// +// spv.SConvertOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::SConvertOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false); +} + +//===----------------------------------------------------------------------===// +// spv.UConvertOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::UConvertOp::verify() { + return verifyCastOp(*this, /*requireSameBitWidth=*/false); +} + //===----------------------------------------------------------------------===// // spv.func //===----------------------------------------------------------------------===// @@ -2077,54 +2246,50 @@ ArrayRef spirv::FuncOp::getCallableResults() { // spv.FunctionCall //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::FunctionCallOp functionCallOp) { - auto fnName = functionCallOp.calleeAttr(); +LogicalResult spirv::FunctionCallOp::verify() { + auto fnName = calleeAttr(); - auto funcOp = - dyn_cast_or_null(SymbolTable::lookupNearestSymbolFrom( - functionCallOp->getParentOp(), fnName)); + auto funcOp = dyn_cast_or_null( + SymbolTable::lookupNearestSymbolFrom((*this)->getParentOp(), fnName)); if (!funcOp) { - return functionCallOp.emitOpError("callee function '") + return emitOpError("callee function '") << fnName.getValue() << "' not found in nearest symbol table"; } auto functionType = funcOp.getType(); - if (functionCallOp.getNumResults() > 1) { - return functionCallOp.emitOpError( + if (getNumResults() > 1) { + return emitOpError( "expected callee function to have 0 or 1 result, but provided ") - << functionCallOp.getNumResults(); + << getNumResults(); } - if (functionType.getNumInputs() != functionCallOp.getNumOperands()) { - return functionCallOp.emitOpError( - "has incorrect number of operands for callee: expected ") + if (functionType.getNumInputs() != getNumOperands()) { + return emitOpError("has incorrect number of operands for callee: expected ") << functionType.getNumInputs() << ", but provided " - << functionCallOp.getNumOperands(); + << getNumOperands(); } for (uint32_t i = 0, e = functionType.getNumInputs(); i != e; ++i) { - if (functionCallOp.getOperand(i).getType() != functionType.getInput(i)) { - return functionCallOp.emitOpError( - "operand type mismatch: expected operand type ") + if (getOperand(i).getType() != functionType.getInput(i)) { + return emitOpError("operand type mismatch: expected operand type ") << functionType.getInput(i) << ", but provided " - << functionCallOp.getOperand(i).getType() << " for operand number " - << i; + << getOperand(i).getType() << " for operand number " << i; } } - if (functionType.getNumResults() != functionCallOp.getNumResults()) { - return functionCallOp.emitOpError( + if (functionType.getNumResults() != getNumResults()) { + return emitOpError( "has incorrect number of results has for callee: expected ") << functionType.getNumResults() << ", but provided " - << functionCallOp.getNumResults(); + << getNumResults(); } - if (functionCallOp.getNumResults() && - (functionCallOp.getResult(0).getType() != functionType.getResult(0))) { - return functionCallOp.emitOpError("result type mismatch: expected ") + if (getNumResults() && + (getResult(0).getType() != functionType.getResult(0))) { + return emitOpError("result type mismatch: expected ") << functionType.getResult(0) << ", but provided " - << functionCallOp.getResult(0).getType(); + << getResult(0).getType(); } return success(); @@ -2222,29 +2387,29 @@ static void print(spirv::GlobalVariableOp varOp, OpAsmPrinter &printer) { printer << " : " << varOp.type(); } -static LogicalResult verify(spirv::GlobalVariableOp varOp) { +LogicalResult spirv::GlobalVariableOp::verify() { // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the // object. It cannot be Generic. It must be the same as the Storage Class // operand of the Result Type." // Also, Function storage class is reserved by spv.Variable. - auto storageClass = varOp.storageClass(); + auto storageClass = this->storageClass(); if (storageClass == spirv::StorageClass::Generic || storageClass == spirv::StorageClass::Function) { - return varOp.emitOpError("storage class cannot be '") + return emitOpError("storage class cannot be '") << stringifyStorageClass(storageClass) << "'"; } if (auto init = - varOp->getAttrOfType(kInitializerAttrName)) { + (*this)->getAttrOfType(kInitializerAttrName)) { Operation *initOp = SymbolTable::lookupNearestSymbolFrom( - varOp->getParentOp(), init.getAttr()); + (*this)->getParentOp(), init.getAttr()); // TODO: Currently only variable initialization with specialization // constants and other variables is supported. They could be normal // constants in the module scope as well. if (!initOp || !isa(initOp)) { - return varOp.emitOpError("initializer must be result of a " - "spv.SpecConstant or spv.GlobalVariable op"); + return emitOpError("initializer must be result of a " + "spv.SpecConstant or spv.GlobalVariable op"); } } @@ -2255,16 +2420,15 @@ static LogicalResult verify(spirv::GlobalVariableOp varOp) { // spv.GroupBroadcast //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) { - spirv::Scope scope = broadcastOp.execution_scope(); +LogicalResult spirv::GroupBroadcastOp::verify() { + spirv::Scope scope = execution_scope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) - return broadcastOp.emitOpError( - "execution scope must be 'Workgroup' or 'Subgroup'"); + return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); - if (auto localIdTy = broadcastOp.localid().getType().dyn_cast()) + if (auto localIdTy = localid().getType().dyn_cast()) if (!(localIdTy.getNumElements() == 2 || localIdTy.getNumElements() == 3)) - return broadcastOp.emitOpError("localid is a vector and can be with only " - " 2 or 3 components, actual number is ") + return emitOpError("localid is a vector and can be with only " + " 2 or 3 components, actual number is ") << localIdTy.getNumElements(); return success(); @@ -2274,11 +2438,10 @@ static LogicalResult verify(spirv::GroupBroadcastOp broadcastOp) { // spv.GroupNonUniformBallotOp //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) { - spirv::Scope scope = ballotOp.execution_scope(); +LogicalResult spirv::GroupNonUniformBallotOp::verify() { + spirv::Scope scope = execution_scope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) - return ballotOp.emitOpError( - "execution scope must be 'Workgroup' or 'Subgroup'"); + return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); return success(); } @@ -2287,23 +2450,22 @@ static LogicalResult verify(spirv::GroupNonUniformBallotOp ballotOp) { // spv.GroupNonUniformBroadcast //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::GroupNonUniformBroadcastOp broadcastOp) { - spirv::Scope scope = broadcastOp.execution_scope(); +LogicalResult spirv::GroupNonUniformBroadcastOp::verify() { + spirv::Scope scope = execution_scope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) - return broadcastOp.emitOpError( - "execution scope must be 'Workgroup' or 'Subgroup'"); + return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); // SPIR-V spec: "Before version 1.5, Id must come from a // constant instruction. - auto targetEnv = spirv::getDefaultTargetEnv(broadcastOp.getContext()); - if (auto spirvModule = broadcastOp->getParentOfType()) + auto targetEnv = spirv::getDefaultTargetEnv(getContext()); + if (auto spirvModule = (*this)->getParentOfType()) targetEnv = spirv::lookupTargetEnvOrDefault(spirvModule); if (targetEnv.getVersion() < spirv::Version::V_1_5) { - auto *idOp = broadcastOp.id().getDefiningOp(); + auto *idOp = id().getDefiningOp(); if (!idOp || !isa(idOp)) // for spec constant - return broadcastOp.emitOpError("id must be the result of a constant op"); + return emitOpError("id must be the result of a constant op"); } return success(); @@ -2343,9 +2505,8 @@ static void print(spirv::SubgroupBlockReadINTELOp blockReadOp, printer << " : " << blockReadOp.getType(); } -static LogicalResult verify(spirv::SubgroupBlockReadINTELOp blockReadOp) { - if (failed(verifyBlockReadWritePtrAndValTypes(blockReadOp, blockReadOp.ptr(), - blockReadOp.value()))) +LogicalResult spirv::SubgroupBlockReadINTELOp::verify() { + if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value()))) return failure(); return success(); @@ -2386,9 +2547,8 @@ static void print(spirv::SubgroupBlockWriteINTELOp blockWriteOp, printer << " : " << blockWriteOp.value().getType(); } -static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) { - if (failed(verifyBlockReadWritePtrAndValTypes( - blockWriteOp, blockWriteOp.ptr(), blockWriteOp.value()))) +LogicalResult spirv::SubgroupBlockWriteINTELOp::verify() { + if (failed(verifyBlockReadWritePtrAndValTypes(*this, ptr(), value()))) return failure(); return success(); @@ -2398,15 +2558,94 @@ static LogicalResult verify(spirv::SubgroupBlockWriteINTELOp blockWriteOp) { // spv.GroupNonUniformElectOp //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::GroupNonUniformElectOp groupOp) { - spirv::Scope scope = groupOp.execution_scope(); +LogicalResult spirv::GroupNonUniformElectOp::verify() { + spirv::Scope scope = execution_scope(); if (scope != spirv::Scope::Workgroup && scope != spirv::Scope::Subgroup) - return groupOp.emitOpError( - "execution scope must be 'Workgroup' or 'Subgroup'"); + return emitOpError("execution scope must be 'Workgroup' or 'Subgroup'"); return success(); } +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformFAddOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GroupNonUniformFAddOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformFMaxOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GroupNonUniformFMaxOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformFMinOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GroupNonUniformFMinOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformFMulOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GroupNonUniformFMulOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformIAddOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GroupNonUniformIAddOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformIMulOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GroupNonUniformIMulOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformSMaxOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GroupNonUniformSMaxOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformSMinOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GroupNonUniformSMinOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformUMaxOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GroupNonUniformUMaxOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + +//===----------------------------------------------------------------------===// +// spv.GroupNonUniformUMinOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::GroupNonUniformUMinOp::verify() { + return verifyGroupNonUniformArithmeticOp(*this); +} + //===----------------------------------------------------------------------===// // spv.LoadOp //===----------------------------------------------------------------------===// @@ -2453,15 +2692,14 @@ static void print(spirv::LoadOp loadOp, OpAsmPrinter &printer) { printer << " : " << loadOp.getType(); } -static LogicalResult verify(spirv::LoadOp loadOp) { +LogicalResult spirv::LoadOp::verify() { // SPIR-V spec : "Result Type is the type of the loaded object. It must be a // type with fixed size; i.e., it cannot be, nor include, any // OpTypeRuntimeArray types." - if (failed(verifyLoadStorePtrAndValTypes(loadOp, loadOp.ptr(), - loadOp.value()))) { + if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value()))) { return failure(); } - return verifyMemoryAccessAttribute(loadOp); + return verifyMemoryAccessAttribute(*this); } //===----------------------------------------------------------------------===// @@ -2504,8 +2742,8 @@ static inline bool hasOneBranchOpTo(Block &srcBlock, Block &dstBlock) { return branchOp && branchOp.getSuccessor() == &dstBlock; } -static LogicalResult verify(spirv::LoopOp loopOp) { - auto *op = loopOp.getOperation(); +LogicalResult spirv::LoopOp::verify() { + auto *op = getOperation(); // We need to verify that the blocks follow the following layout: // @@ -2541,27 +2779,27 @@ static LogicalResult verify(spirv::LoopOp loopOp) { // The last block is the merge block. Block &merge = region.back(); if (!isMergeBlock(merge)) - return loopOp.emitOpError( + return emitOpError( "last block must be the merge block with only one 'spv.mlir.merge' op"); if (std::next(region.begin()) == region.end()) - return loopOp.emitOpError( + return emitOpError( "must have an entry block branching to the loop header block"); // The first block is the entry block. Block &entry = region.front(); if (std::next(region.begin(), 2) == region.end()) - return loopOp.emitOpError( + return emitOpError( "must have a loop header block branched from the entry block"); // The second block is the loop header block. Block &header = *std::next(region.begin(), 1); if (!hasOneBranchOpTo(entry, header)) - return loopOp.emitOpError( + return emitOpError( "entry block must only have one 'spv.Branch' op to the second block"); if (std::next(region.begin(), 3) == region.end()) - return loopOp.emitOpError( + return emitOpError( "requires a loop continue block branching to the loop header block"); // The second to last block is the loop continue block. Block &cont = *std::prev(region.end(), 2); @@ -2571,8 +2809,8 @@ static LogicalResult verify(spirv::LoopOp loopOp) { if (llvm::none_of( llvm::seq(0, cont.getNumSuccessors()), [&](unsigned index) { return cont.getSuccessor(index) == &header; })) - return loopOp.emitOpError("second to last block must be the loop continue " - "block that branches to the loop header block"); + return emitOpError("second to last block must be the loop continue " + "block that branches to the loop header block"); // Make sure that no other blocks (except the entry and loop continue block) // branches to the loop header block. @@ -2580,8 +2818,8 @@ static LogicalResult verify(spirv::LoopOp loopOp) { std::prev(region.end(), 2))) { for (auto i : llvm::seq(0, block.getNumSuccessors())) { if (block.getSuccessor(i) == &header) { - return loopOp.emitOpError("can only have the entry and loop continue " - "block branching to the loop header block"); + return emitOpError("can only have the entry and loop continue " + "block branching to the loop header block"); } } } @@ -2623,20 +2861,28 @@ void spirv::LoopOp::addEntryAndMergeBlock() { builder.create(getLoc()); } +//===----------------------------------------------------------------------===// +// spv.MemoryBarrierOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::MemoryBarrierOp::verify() { + return verifyMemorySemantics(getOperation(), memory_semantics()); +} + //===----------------------------------------------------------------------===// // spv.mlir.merge //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::MergeOp mergeOp) { - auto *parentOp = mergeOp->getParentOp(); +LogicalResult spirv::MergeOp::verify() { + auto *parentOp = (*this)->getParentOp(); if (!parentOp || !isa(parentOp)) - return mergeOp.emitOpError( + return emitOpError( "expected parent op to be 'spv.mlir.selection' or 'spv.mlir.loop'"); - Block &parentLastBlock = mergeOp->getParentRegion()->back(); - if (mergeOp.getOperation() != parentLastBlock.getTerminator()) - return mergeOp.emitOpError("can only be used in the last block of " - "'spv.mlir.selection' or 'spv.mlir.loop'"); + Block &parentLastBlock = (*this)->getParentRegion()->back(); + if (getOperation() != parentLastBlock.getTerminator()) + return emitOpError("can only be used in the last block of " + "'spv.mlir.selection' or 'spv.mlir.loop'"); return success(); } @@ -2734,14 +2980,13 @@ static void print(spirv::ModuleOp moduleOp, OpAsmPrinter &printer) { printer.printRegion(moduleOp.getRegion()); } -static LogicalResult verify(spirv::ModuleOp moduleOp) { - auto &op = *moduleOp.getOperation(); - auto *dialect = op.getDialect(); +LogicalResult spirv::ModuleOp::verify() { + Dialect *dialect = (*this)->getDialect(); DenseMap, spirv::EntryPointOp> entryPoints; - SymbolTable table(moduleOp); + mlir::SymbolTable table(*this); - for (auto &op : *moduleOp.getBody()) { + for (auto &op : *getBody()) { if (op.getDialect() != dialect) return op.emitError("'spv.module' can only contain spv.* ops"); @@ -2801,9 +3046,9 @@ static LogicalResult verify(spirv::ModuleOp moduleOp) { // spv.mlir.referenceof //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { +LogicalResult spirv::ReferenceOfOp::verify() { auto *specConstSym = SymbolTable::lookupNearestSymbolFrom( - referenceOfOp->getParentOp(), referenceOfOp.spec_constAttr()); + (*this)->getParentOp(), spec_constAttr()); Type constType; auto specConstOp = dyn_cast_or_null(specConstSym); @@ -2816,12 +3061,12 @@ static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { constType = specConstCompositeOp.type(); if (!specConstOp && !specConstCompositeOp) - return referenceOfOp.emitOpError( + return emitOpError( "expected spv.SpecConstant or spv.SpecConstantComposite symbol"); - if (referenceOfOp.reference().getType() != constType) - return referenceOfOp.emitOpError("result type mismatch with the referenced " - "specialization constant's type"); + if (reference().getType() != constType) + return emitOpError("result type mismatch with the referenced " + "specialization constant's type"); return success(); } @@ -2830,7 +3075,7 @@ static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { // spv.Return //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::ReturnOp returnOp) { +LogicalResult spirv::ReturnOp::verify() { // Verification is performed in spv.func op. return success(); } @@ -2839,7 +3084,7 @@ static LogicalResult verify(spirv::ReturnOp returnOp) { // spv.ReturnValue //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::ReturnValueOp retValOp) { +LogicalResult spirv::ReturnValueOp::verify() { // Verification is performed in spv.func op. return success(); } @@ -2848,16 +3093,16 @@ static LogicalResult verify(spirv::ReturnValueOp retValOp) { // spv.Select //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::SelectOp op) { - if (auto conditionTy = op.condition().getType().dyn_cast()) { - auto resultVectorTy = op.result().getType().dyn_cast(); +LogicalResult spirv::SelectOp::verify() { + if (auto conditionTy = condition().getType().dyn_cast()) { + auto resultVectorTy = result().getType().dyn_cast(); if (!resultVectorTy) { - return op.emitOpError("result expected to be of vector type when " - "condition is of vector type"); + return emitOpError("result expected to be of vector type when " + "condition is of vector type"); } if (resultVectorTy.getNumElements() != conditionTy.getNumElements()) { - return op.emitOpError("result should have the same number of elements as " - "the condition when condition is of vector type"); + return emitOpError("result should have the same number of elements as " + "the condition when condition is of vector type"); } } return success(); @@ -2885,8 +3130,8 @@ static void print(spirv::SelectionOp selectionOp, OpAsmPrinter &printer) { /*printBlockTerminators=*/true); } -static LogicalResult verify(spirv::SelectionOp selectionOp) { - auto *op = selectionOp.getOperation(); +LogicalResult spirv::SelectionOp::verify() { + auto *op = getOperation(); // We need to verify that the blocks follow the following layout: // @@ -2917,11 +3162,11 @@ static LogicalResult verify(spirv::SelectionOp selectionOp) { // The last block is the merge block. if (!isMergeBlock(region.back())) - return selectionOp.emitOpError( + return emitOpError( "last block must be the merge block with only one 'spv.mlir.merge' op"); if (std::next(region.begin()) == region.end()) - return selectionOp.emitOpError("must have a selection header block"); + return emitOpError("must have a selection header block"); return success(); } @@ -3016,19 +3261,19 @@ static void print(spirv::SpecConstantOp constOp, OpAsmPrinter &printer) { printer << " = " << constOp.default_value(); } -static LogicalResult verify(spirv::SpecConstantOp constOp) { - if (auto specID = constOp->getAttrOfType(kSpecIdAttrName)) +LogicalResult spirv::SpecConstantOp::verify() { + if (auto specID = (*this)->getAttrOfType(kSpecIdAttrName)) if (specID.getValue().isNegative()) - return constOp.emitOpError("SpecId cannot be negative"); + return emitOpError("SpecId cannot be negative"); - auto value = constOp.default_value(); + auto value = default_value(); if (value.isa()) { // Make sure bitwidth is allowed. if (!value.getType().isa()) - return constOp.emitOpError("default value bitwidth disallowed"); + return emitOpError("default value bitwidth disallowed"); return success(); } - return constOp.emitOpError( + return emitOpError( "default value can only be a bool, integer, or float scalar"); } @@ -3070,27 +3315,24 @@ static void print(spirv::StoreOp storeOp, OpAsmPrinter &printer) { printer.printOptionalAttrDict(op->getAttrs(), elidedAttrs); } -static LogicalResult verify(spirv::StoreOp storeOp) { +LogicalResult spirv::StoreOp::verify() { // SPIR-V spec : "Pointer is the pointer to store through. Its type must be an // OpTypePointer whose Type operand is the same as the type of Object." - if (failed(verifyLoadStorePtrAndValTypes(storeOp, storeOp.ptr(), - storeOp.value()))) { + if (failed(verifyLoadStorePtrAndValTypes(*this, ptr(), value()))) return failure(); - } - return verifyMemoryAccessAttribute(storeOp); + return verifyMemoryAccessAttribute(*this); } //===----------------------------------------------------------------------===// // spv.Unreachable //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::UnreachableOp unreachableOp) { - auto *op = unreachableOp.getOperation(); - auto *block = op->getBlock(); +LogicalResult spirv::UnreachableOp::verify() { + auto *block = (*this)->getBlock(); // Fast track: if this is in entry block, its invalid. Otherwise, if no // predecessors, it's valid. if (block->isEntryBlock()) - return unreachableOp.emitOpError("cannot be used in reachable block"); + return emitOpError("cannot be used in reachable block"); if (block->hasNoPredecessors()) return success(); @@ -3156,34 +3398,34 @@ static void print(spirv::VariableOp varOp, OpAsmPrinter &printer) { printer << " : " << varOp.getType(); } -static LogicalResult verify(spirv::VariableOp varOp) { +LogicalResult spirv::VariableOp::verify() { // SPIR-V spec: "Storage Class is the Storage Class of the memory holding the // object. It cannot be Generic. It must be the same as the Storage Class // operand of the Result Type." - if (varOp.storage_class() != spirv::StorageClass::Function) { - return varOp.emitOpError( + if (storage_class() != spirv::StorageClass::Function) { + return emitOpError( "can only be used to model function-level variables. Use " "spv.GlobalVariable for module-level variables."); } - auto pointerType = varOp.pointer().getType().cast(); - if (varOp.storage_class() != pointerType.getStorageClass()) - return varOp.emitOpError( + auto pointerType = pointer().getType().cast(); + if (storage_class() != pointerType.getStorageClass()) + return emitOpError( "storage class must match result pointer's storage class"); - if (varOp.getNumOperands() != 0) { + if (getNumOperands() != 0) { // SPIR-V spec: "Initializer must be an from a constant instruction or // a global (module scope) OpVariable instruction". - auto *initOp = varOp.getOperand(0).getDefiningOp(); + auto *initOp = getOperand(0).getDefiningOp(); if (!initOp || !isa(initOp)) - return varOp.emitOpError("initializer must be the result of a " - "constant or spv.GlobalVariable op"); + return emitOpError("initializer must be the result of a " + "constant or spv.GlobalVariable op"); } // TODO: generate these strings using ODS. - auto *op = varOp.getOperation(); + auto *op = getOperation(); auto descriptorSetName = llvm::convertToSnakeFromCamelCase( stringifyDecoration(spirv::Decoration::DescriptorSet)); auto bindingName = llvm::convertToSnakeFromCamelCase( @@ -3193,7 +3435,7 @@ static LogicalResult verify(spirv::VariableOp varOp) { for (const auto &attr : {descriptorSetName, bindingName, builtInName}) { if (op->getAttr(attr)) - return varOp.emitOpError("cannot have '") + return emitOpError("cannot have '") << attr << "' attribute (only allowed in spv.GlobalVariable)"; } @@ -3204,26 +3446,25 @@ static LogicalResult verify(spirv::VariableOp varOp) { // spv.VectorShuffle //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::VectorShuffleOp shuffleOp) { - VectorType resultType = shuffleOp.getType().cast(); +LogicalResult spirv::VectorShuffleOp::verify() { + VectorType resultType = getType().cast(); size_t numResultElements = resultType.getNumElements(); - if (numResultElements != shuffleOp.components().size()) - return shuffleOp.emitOpError("result type element count (") + if (numResultElements != components().size()) + return emitOpError("result type element count (") << numResultElements << ") mismatch with the number of component selectors (" - << shuffleOp.components().size() << ")"; + << components().size() << ")"; size_t totalSrcElements = - shuffleOp.vector1().getType().cast().getNumElements() + - shuffleOp.vector2().getType().cast().getNumElements(); + vector1().getType().cast().getNumElements() + + vector2().getType().cast().getNumElements(); - for (const auto &selector : - shuffleOp.components().getAsValueRange()) { + for (const auto &selector : components().getAsValueRange()) { uint32_t index = selector.getZExtValue(); if (index >= totalSrcElements && index != std::numeric_limits().max()) - return shuffleOp.emitOpError("component selector ") + return emitOpError("component selector ") << index << " out of range: expected to be in [0, " << totalSrcElements << ") or 0xffffffff"; } @@ -3284,6 +3525,11 @@ static LogicalResult verifyPointerAndCoopMatrixType(Operation *op, Type pointer, return success(); } +LogicalResult spirv::CooperativeMatrixLoadNVOp::verify() { + return verifyPointerAndCoopMatrixType(*this, pointer().getType(), + result().getType()); +} + //===----------------------------------------------------------------------===// // spv.CooperativeMatrixStoreNV //===----------------------------------------------------------------------===// @@ -3321,6 +3567,11 @@ static void print(spirv::CooperativeMatrixStoreNVOp coopMatrix, << coopMatrix.getOperand(1).getType(); } +LogicalResult spirv::CooperativeMatrixStoreNVOp::verify() { + return verifyPointerAndCoopMatrixType(*this, pointer().getType(), + object().getType()); +} + //===----------------------------------------------------------------------===// // spv.CooperativeMatrixMulAddNV //===----------------------------------------------------------------------===// @@ -3347,39 +3598,43 @@ verifyCoopMatrixMulAdd(spirv::CooperativeMatrixMulAddNVOp op) { return success(); } +LogicalResult spirv::CooperativeMatrixMulAddNVOp::verify() { + return verifyCoopMatrixMulAdd(*this); +} + //===----------------------------------------------------------------------===// // spv.MatrixTimesScalar //===----------------------------------------------------------------------===// -static LogicalResult verifyMatrixTimesScalar(spirv::MatrixTimesScalarOp op) { +LogicalResult spirv::MatrixTimesScalarOp::verify() { // We already checked that result and matrix are both of matrix type in the // auto-generated verify method. - auto inputMatrix = op.matrix().getType().cast(); - auto resultMatrix = op.result().getType().cast(); + auto inputMatrix = matrix().getType().cast(); + auto resultMatrix = result().getType().cast(); // Check that the scalar type is the same as the matrix element type. - if (op.scalar().getType() != inputMatrix.getElementType()) - return op.emitError("input matrix components' type and scaling value must " - "have the same type"); + if (scalar().getType() != inputMatrix.getElementType()) + return emitError("input matrix components' type and scaling value must " + "have the same type"); // Note that the next three checks could be done using the AllTypesMatch // trait in the Op definition file but it generates a vague error message. // Check that the input and result matrices have the same columns' count if (inputMatrix.getNumColumns() != resultMatrix.getNumColumns()) - return op.emitError("input and result matrices must have the same " - "number of columns"); + return emitError("input and result matrices must have the same " + "number of columns"); // Check that the input and result matrices' have the same rows count if (inputMatrix.getNumRows() != resultMatrix.getNumRows()) - return op.emitError("input and result matrices' columns must have " - "the same size"); + return emitError("input and result matrices' columns must have " + "the same size"); // Check that the input and result matrices' have the same component type if (inputMatrix.getElementType() != resultMatrix.getElementType()) - return op.emitError("input and result matrices' columns must have " - "the same component type"); + return emitError("input and result matrices' columns must have " + "the same component type"); return success(); } @@ -3462,21 +3717,18 @@ static ParseResult parseCopyMemoryOp(OpAsmParser &parser, return success(); } -static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) { +LogicalResult spirv::CopyMemoryOp::verify() { Type targetType = - copyMemory.target().getType().cast().getPointeeType(); + target().getType().cast().getPointeeType(); Type sourceType = - copyMemory.source().getType().cast().getPointeeType(); + source().getType().cast().getPointeeType(); - if (targetType != sourceType) { - return copyMemory.emitOpError( - "both operands must be pointers to the same type"); - } + if (targetType != sourceType) + return emitOpError("both operands must be pointers to the same type"); - if (failed(verifyMemoryAccessAttribute(copyMemory))) { + if (failed(verifyMemoryAccessAttribute(*this))) return failure(); - } // TODO - According to the spec: // @@ -3486,30 +3738,30 @@ static LogicalResult verifyCopyMemory(spirv::CopyMemoryOp copyMemory) { // // Add such verification here. - return verifySourceMemoryAccessAttribute(copyMemory); + return verifySourceMemoryAccessAttribute(*this); } //===----------------------------------------------------------------------===// // spv.Transpose //===----------------------------------------------------------------------===// -static LogicalResult verifyTranspose(spirv::TransposeOp op) { - auto inputMatrix = op.matrix().getType().cast(); - auto resultMatrix = op.result().getType().cast(); +LogicalResult spirv::TransposeOp::verify() { + auto inputMatrix = matrix().getType().cast(); + auto resultMatrix = result().getType().cast(); // Verify that the input and output matrices have correct shapes. if (inputMatrix.getNumRows() != resultMatrix.getNumColumns()) - return op.emitError("input matrix rows count must be equal to " - "output matrix columns count"); + return emitError("input matrix rows count must be equal to " + "output matrix columns count"); if (inputMatrix.getNumColumns() != resultMatrix.getNumRows()) - return op.emitError("input matrix columns count must be equal to " - "output matrix rows count"); + return emitError("input matrix columns count must be equal to " + "output matrix rows count"); // Verify that the input and output matrices have the same component type if (inputMatrix.getElementType() != resultMatrix.getElementType()) - return op.emitError("input and output matrices must have the same " - "component type"); + return emitError("input and output matrices must have the same " + "component type"); return success(); } @@ -3518,35 +3770,34 @@ static LogicalResult verifyTranspose(spirv::TransposeOp op) { // spv.MatrixTimesMatrix //===----------------------------------------------------------------------===// -static LogicalResult verifyMatrixTimesMatrix(spirv::MatrixTimesMatrixOp op) { - auto leftMatrix = op.leftmatrix().getType().cast(); - auto rightMatrix = op.rightmatrix().getType().cast(); - auto resultMatrix = op.result().getType().cast(); +LogicalResult spirv::MatrixTimesMatrixOp::verify() { + auto leftMatrix = leftmatrix().getType().cast(); + auto rightMatrix = rightmatrix().getType().cast(); + auto resultMatrix = result().getType().cast(); // left matrix columns' count and right matrix rows' count must be equal if (leftMatrix.getNumColumns() != rightMatrix.getNumRows()) - return op.emitError("left matrix columns' count must be equal to " - "the right matrix rows' count"); + return emitError("left matrix columns' count must be equal to " + "the right matrix rows' count"); // right and result matrices columns' count must be the same if (rightMatrix.getNumColumns() != resultMatrix.getNumColumns()) - return op.emitError( + return emitError( "right and result matrices must have equal columns' count"); // right and result matrices component type must be the same if (rightMatrix.getElementType() != resultMatrix.getElementType()) - return op.emitError("right and result matrices' component type must" - " be the same"); + return emitError("right and result matrices' component type must" + " be the same"); // left and result matrices component type must be the same if (leftMatrix.getElementType() != resultMatrix.getElementType()) - return op.emitError("left and result matrices' component type" - " must be the same"); + return emitError("left and result matrices' component type" + " must be the same"); // left and result matrices rows count must be the same if (leftMatrix.getNumRows() != resultMatrix.getNumRows()) - return op.emitError("left and result matrices must have equal rows'" - " count"); + return emitError("left and result matrices must have equal rows' count"); return success(); } @@ -3607,19 +3858,18 @@ static void print(spirv::SpecConstantCompositeOp op, OpAsmPrinter &printer) { printer << ") : " << op.type(); } -static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) { - auto cType = constOp.type().dyn_cast(); - auto constituents = constOp.constituents().getValue(); +LogicalResult spirv::SpecConstantCompositeOp::verify() { + auto cType = type().dyn_cast(); + auto constituents = this->constituents().getValue(); if (!cType) - return constOp.emitError( - "result type must be a composite type, but provided ") - << constOp.type(); + return emitError("result type must be a composite type, but provided ") + << type(); if (cType.isa()) - return constOp.emitError("unsupported composite type ") << cType; + return emitError("unsupported composite type ") << cType; if (constituents.size() != cType.getNumElements()) - return constOp.emitError("has incorrect number of operands: expected ") + return emitError("has incorrect number of operands: expected ") << cType.getNumElements() << ", but provided " << constituents.size(); @@ -3628,11 +3878,11 @@ static LogicalResult verify(spirv::SpecConstantCompositeOp constOp) { auto constituentSpecConstOp = dyn_cast(SymbolTable::lookupNearestSymbolFrom( - constOp->getParentOp(), constituent.getAttr())); + (*this)->getParentOp(), constituent.getAttr())); if (constituentSpecConstOp.default_value().getType() != cType.getElementType(index)) - return constOp.emitError("has incorrect types of operands: expected ") + return emitError("has incorrect types of operands: expected ") << cType.getElementType(index) << ", but provided " << constituentSpecConstOp.default_value().getType(); } @@ -3676,21 +3926,21 @@ static void print(spirv::SpecConstantOperationOp op, OpAsmPrinter &printer) { printer.printGenericOp(&op.body().front().front()); } -static LogicalResult verify(spirv::SpecConstantOperationOp constOp) { - Block &block = constOp.getRegion().getBlocks().front(); +LogicalResult spirv::SpecConstantOperationOp::verify() { + Block &block = getRegion().getBlocks().front(); if (block.getOperations().size() != 2) - return constOp.emitOpError("expected exactly 2 nested ops"); + return emitOpError("expected exactly 2 nested ops"); Operation &enclosedOp = block.getOperations().front(); if (!enclosedOp.hasTrait()) - return constOp.emitOpError("invalid enclosed op"); + return emitOpError("invalid enclosed op"); for (auto operand : enclosedOp.getOperands()) if (!isa(operand.getDefiningOp())) - return constOp.emitOpError( + return emitOpError( "invalid operand, must be defined by a constant operation"); return success(); @@ -3699,39 +3949,35 @@ static LogicalResult verify(spirv::SpecConstantOperationOp constOp) { //===----------------------------------------------------------------------===// // spv.GLSL.FrexpStruct //===----------------------------------------------------------------------===// -static LogicalResult -verifyGLSLFrexpStructOp(spirv::GLSLFrexpStructOp frexpStructOp) { - spirv::StructType structTy = - frexpStructOp.result().getType().dyn_cast(); + +LogicalResult spirv::GLSLFrexpStructOp::verify() { + spirv::StructType structTy = result().getType().dyn_cast(); if (structTy.getNumElements() != 2) - return frexpStructOp.emitError("result type must be a struct type " - "with two memebers"); + return emitError("result type must be a struct type with two memebers"); Type significandTy = structTy.getElementType(0); Type exponentTy = structTy.getElementType(1); VectorType exponentVecTy = exponentTy.dyn_cast(); IntegerType exponentIntTy = exponentTy.dyn_cast(); - Type operandTy = frexpStructOp.operand().getType(); + Type operandTy = operand().getType(); VectorType operandVecTy = operandTy.dyn_cast(); FloatType operandFTy = operandTy.dyn_cast(); if (significandTy != operandTy) - return frexpStructOp.emitError("member zero of the resulting struct type " - "must be the same type as the operand"); + return emitError("member zero of the resulting struct type must be the " + "same type as the operand"); if (exponentVecTy) { IntegerType componentIntTy = exponentVecTy.getElementType().dyn_cast(); if (!(componentIntTy && componentIntTy.getWidth() == 32)) - return frexpStructOp.emitError( - "member one of the resulting struct type must" - "be a scalar or vector of 32 bit integer type"); + return emitError("member one of the resulting struct type must" + "be a scalar or vector of 32 bit integer type"); } else if (!(exponentIntTy && exponentIntTy.getWidth() == 32)) { - return frexpStructOp.emitError( - "member one of the resulting struct type " - "must be a scalar or vector of 32 bit integer type"); + return emitError("member one of the resulting struct type " + "must be a scalar or vector of 32 bit integer type"); } // Check that the two member types have the same number of components @@ -3742,21 +3988,20 @@ verifyGLSLFrexpStructOp(spirv::GLSLFrexpStructOp frexpStructOp) { if (operandFTy && exponentIntTy) return success(); - return frexpStructOp.emitError( - "member one of the resulting struct type " - "must have the same number of components as the operand type"); + return emitError("member one of the resulting struct type must have the same " + "number of components as the operand type"); } //===----------------------------------------------------------------------===// // spv.GLSL.Ldexp //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::GLSLLdexpOp ldexpOp) { - Type significandType = ldexpOp.x().getType(); - Type exponentType = ldexpOp.exp().getType(); +LogicalResult spirv::GLSLLdexpOp::verify() { + Type significandType = x().getType(); + Type exponentType = exp().getType(); if (significandType.isa() != exponentType.isa()) - return ldexpOp.emitOpError("operands must both be scalars or vectors"); + return emitOpError("operands must both be scalars or vectors"); auto getNumElements = [](Type type) -> unsigned { if (auto vectorType = type.dyn_cast()) @@ -3765,8 +4010,7 @@ static LogicalResult verify(spirv::GLSLLdexpOp ldexpOp) { }; if (getNumElements(significandType) != getNumElements(exponentType)) - return ldexpOp.emitOpError( - "operands must have the same number of elements"); + return emitOpError("operands must have the same number of elements"); return success(); } @@ -3775,22 +4019,19 @@ static LogicalResult verify(spirv::GLSLLdexpOp ldexpOp) { // spv.ImageDrefGather //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::ImageDrefGatherOp imageDrefGatherOp) { - VectorType resultType = - imageDrefGatherOp.result().getType().cast(); - auto sampledImageType = imageDrefGatherOp.sampledimage() - .getType() - .cast(); +LogicalResult spirv::ImageDrefGatherOp::verify() { + VectorType resultType = result().getType().cast(); + auto sampledImageType = + sampledimage().getType().cast(); auto imageType = sampledImageType.getImageType().cast(); if (resultType.getNumElements() != 4) - return imageDrefGatherOp.emitOpError( - "result type must be a vector of four components"); + return emitOpError("result type must be a vector of four components"); Type elementType = resultType.getElementType(); Type sampledElementType = imageType.getElementType(); if (!sampledElementType.isa() && elementType != sampledElementType) - return imageDrefGatherOp.emitOpError( + return emitOpError( "the component type of result must be the same as sampled type of the " "underlying image type"); @@ -3799,28 +4040,50 @@ static LogicalResult verify(spirv::ImageDrefGatherOp imageDrefGatherOp) { if (imageDim != spirv::Dim::Dim2D && imageDim != spirv::Dim::Cube && imageDim != spirv::Dim::Rect) - return imageDrefGatherOp.emitOpError( + return emitOpError( "the Dim operand of the underlying image type must be 2D, Cube, or " "Rect"); if (imageMS != spirv::ImageSamplingInfo::SingleSampled) - return imageDrefGatherOp.emitOpError( - "the MS operand of the underlying image type must be 0"); + return emitOpError("the MS operand of the underlying image type must be 0"); - spirv::ImageOperandsAttr attr = imageDrefGatherOp.imageoperandsAttr(); - auto operandArguments = imageDrefGatherOp.operand_arguments(); + spirv::ImageOperandsAttr attr = imageoperandsAttr(); + auto operandArguments = operand_arguments(); - return verifyImageOperands(imageDrefGatherOp, attr, operandArguments); + return verifyImageOperands(*this, attr, operandArguments); +} + +//===----------------------------------------------------------------------===// +// spv.ShiftLeftLogicalOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::ShiftLeftLogicalOp::verify() { + return verifyShiftOp(*this); +} + +//===----------------------------------------------------------------------===// +// spv.ShiftRightArithmeticOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::ShiftRightArithmeticOp::verify() { + return verifyShiftOp(*this); +} + +//===----------------------------------------------------------------------===// +// spv.ShiftRightLogicalOp +//===----------------------------------------------------------------------===// + +LogicalResult spirv::ShiftRightLogicalOp::verify() { + return verifyShiftOp(*this); } //===----------------------------------------------------------------------===// // spv.ImageQuerySize //===----------------------------------------------------------------------===// -static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) { - spirv::ImageType imageType = - imageQuerySizeOp.image().getType().cast(); - Type resultType = imageQuerySizeOp.result().getType(); +LogicalResult spirv::ImageQuerySizeOp::verify() { + spirv::ImageType imageType = image().getType().cast(); + Type resultType = result().getType(); spirv::Dim dim = imageType.getDim(); spirv::ImageSamplingInfo samplingInfo = imageType.getSamplingInfo(); @@ -3833,7 +4096,7 @@ static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) { if (!(samplingInfo == spirv::ImageSamplingInfo::MultiSampled || samplerInfo == spirv::ImageSamplerUseInfo::SamplerUnknown || samplerInfo == spirv::ImageSamplerUseInfo::NoSampler)) - return imageQuerySizeOp.emitError( + return emitError( "if Dim is 1D, 2D, 3D, or Cube, " "it must also have either an MS of 1 or a Sampled of 0 or 2"); break; @@ -3841,8 +4104,8 @@ static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) { case spirv::Dim::Rect: break; default: - return imageQuerySizeOp.emitError("the Dim operand of the image type must " - "be 1D, 2D, 3D, Buffer, Cube, or Rect"); + return emitError("the Dim operand of the image type must " + "be 1D, 2D, 3D, Buffer, Cube, or Rect"); } unsigned componentNumber = 0; @@ -3871,7 +4134,7 @@ static LogicalResult verify(spirv::ImageQuerySizeOp imageQuerySizeOp) { resultComponentNumber = resultVectorType.getNumElements(); if (componentNumber != resultComponentNumber) - return imageQuerySizeOp.emitError("expected the result to have ") + return emitError("expected the result to have ") << componentNumber << " component(s), but found " << resultComponentNumber << " component(s)"; @@ -3951,8 +4214,8 @@ static void print(spirv::InBoundsPtrAccessChainOp op, OpAsmPrinter &printer) { printAccessChain(op, concatElemAndIndices(op), printer); } -static LogicalResult verify(spirv::InBoundsPtrAccessChainOp accessChainOp) { - return verifyAccessChain(accessChainOp, accessChainOp.indices()); +LogicalResult spirv::InBoundsPtrAccessChainOp::verify() { + return verifyAccessChain(*this, indices()); } //===----------------------------------------------------------------------===// @@ -3977,8 +4240,8 @@ static void print(spirv::PtrAccessChainOp op, OpAsmPrinter &printer) { printAccessChain(op, concatElemAndIndices(op), printer); } -static LogicalResult verify(spirv::PtrAccessChainOp accessChainOp) { - return verifyAccessChain(accessChainOp, accessChainOp.indices()); +LogicalResult spirv::PtrAccessChainOp::verify() { + return verifyAccessChain(*this, indices()); } // TableGen'erated operation interfaces for querying versions, extensions, and