diff --git a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td index 346a2eecedbf..2af71109a786 100644 --- a/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td +++ b/mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td @@ -28,7 +28,6 @@ def MemRefTypeAttr class MemRef_Op traits = []> : Op { let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -93,6 +92,7 @@ class AllocLikeOp { let results = (outs); let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -162,6 +163,7 @@ def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, []> { memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1> ``` }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -205,6 +207,7 @@ def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> { an alignment on any convenient boundary compatible with the type will be chosen. }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -253,6 +256,7 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope", let results = (outs Variadic:$results); let regions = (region SizedRegion<1>:$bodyRegion); + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -279,11 +283,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return", let arguments = (ins Variadic:$results); let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>]; - let assemblyFormat = - [{ attr-dict ($results^ `:` type($results))? }]; - - // No custom verification needed. - let verifier = ?; + let assemblyFormat = "attr-dict ($results^ `:` type($results))?"; } //===----------------------------------------------------------------------===// @@ -355,7 +355,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [ let arguments = (ins AnyRankedOrUnrankedMemRef:$source); let results = (outs AnyRankedOrUnrankedMemRef:$dest); let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; - let verifier = "return impl::verifyCastOp(*this, areCastCompatible);"; let builders = [ OpBuilder<(ins "Value":$source, "Type":$destType), [{ impl::buildCastOp($_builder, $_state, source, destType); @@ -370,6 +369,7 @@ def MemRef_CastOp : MemRef_Op<"cast", [ }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -408,7 +408,6 @@ def CopyOp : MemRef_Op<"copy", let hasCanonicalizer = 1; let hasFolder = 1; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -434,7 +433,6 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> { let arguments = (ins Arg:$memref); let hasFolder = 1; - let verifier = ?; let assemblyFormat = "$memref attr-dict `:` type($memref)"; } @@ -488,6 +486,7 @@ def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> { let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -646,6 +645,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> { } }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -697,6 +697,7 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> { Value getNumElements() { return numElements(); } }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -757,6 +758,7 @@ def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [ return memref().getType().cast(); } }]; + let hasVerifier = 1; } def AtomicYieldOp : MemRef_Op<"atomic_yield", [ @@ -772,6 +774,7 @@ def AtomicYieldOp : MemRef_Op<"atomic_yield", [ let arguments = (ins AnyType:$result); let assemblyFormat = "$result attr-dict `:` type($result)"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -797,9 +800,6 @@ def MemRef_GetGlobalOp : MemRef_Op<"get_global", let arguments = (ins FlatSymbolRefAttr:$name); let results = (outs AnyStaticShapeMemRef:$result); let assemblyFormat = "$name `:` type($result) attr-dict"; - - // `GetGlobalOp` is fully verified by its traits. - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -866,6 +866,7 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> { return !isExternal() && initial_value().getValue().isa(); } }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -939,6 +940,7 @@ def LoadOp : MemRef_Op<"load", }]; let hasFolder = 1; + let hasVerifier = 1; let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)"; } @@ -982,6 +984,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> { }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1034,6 +1037,7 @@ def MemRef_ReinterpretCastOp: let parser = ?; let printer = ?; + let hasVerifier = 1; let builders = [ // Build a ReinterpretCastOp with mixed static and dynamic entries. @@ -1096,7 +1100,6 @@ def MemRef_RankOp : MemRef_Op<"rank", [NoSideEffect]> { let arguments = (ins AnyRankedOrUnrankedMemRef:$memref); let results = (outs Index); - let verifier = ?; let hasFolder = 1; let assemblyFormat = "$memref attr-dict `:` type($memref)"; } @@ -1161,6 +1164,7 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [ let assemblyFormat = [{ $source `(` $shape `)` attr-dict `:` functional-type(operands, results) }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1226,6 +1230,7 @@ class MemRef_ReassociativeReshapeOp traits = []> : let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseReshapeLikeOp(parser, result); }]; } @@ -1265,6 +1270,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> { ``` }]; let extraClassDeclaration = commonExtraClassDeclaration; + let hasVerifier = 1; } def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> { @@ -1302,6 +1308,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> { ``` }]; let extraClassDeclaration = commonExtraClassDeclaration; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1369,6 +1376,7 @@ def MemRef_StoreOp : MemRef_Op<"store", }]; let hasFolder = 1; + let hasVerifier = 1; let assemblyFormat = [{ $value `,` $memref `[` $indices `]` attr-dict `:` type($memref) @@ -1617,6 +1625,7 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides< let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1645,8 +1654,6 @@ def TensorStoreOp : MemRef_Op<"tensor_store", let arguments = (ins AnyTensor:$tensor, Arg:$memref); - // TensorStoreOp is fully verified by traits. - let verifier = ?; let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)"; } @@ -1681,6 +1688,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [NoSideEffect]>, }]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1749,6 +1757,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [ }]; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -1796,6 +1805,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [ } }]; let hasFolder = 1; + let hasVerifier = 1; } #endif // MEMREF_OPS diff --git a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td index c87ee778665b..278fedbbd3cb 100644 --- a/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td +++ b/mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td @@ -21,7 +21,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td" class SparseTensor_Op traits = []> : Op { let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -50,6 +49,7 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>, ``` }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)"; + let hasVerifier = 1; } def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>, @@ -72,6 +72,7 @@ def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>, ``` }]; let assemblyFormat = "`[` $sizes `]` attr-dict `:` type($result)"; + let hasVerifier = 1; } def SparseTensor_ConvertOp : SparseTensor_Op<"convert", @@ -113,6 +114,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert", }]; let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; let hasFolder = 1; + let hasVerifier = 1; } def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>, @@ -137,6 +139,7 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>, }]; let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)" " `to` type($result)"; + let hasVerifier = 1; } def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>, @@ -161,6 +164,7 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>, }]; let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)" " `to` type($result)"; + let hasVerifier = 1; } def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>, @@ -183,6 +187,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>, ``` }]; let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)"; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -217,6 +222,7 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>, }]; let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`" " type($tensor) `,` type($indices) `,` type($value)"; + let hasVerifier = 1; } def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>, @@ -258,6 +264,7 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>, }]; let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)" " `,` type($filled) `,` type($added) `,` type($count)"; + let hasVerifier = 1; } def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>, @@ -292,6 +299,7 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>, " $added `,` $count attr-dict `:` type($tensor) `,`" " type($indices) `,` type($values) `,` type($filled) `,`" " type($added) `,` type($count)"; + let hasVerifier = 1; } def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>, @@ -324,6 +332,7 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>, ``` }]; let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)"; + let hasVerifier = 1; } def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>, @@ -349,6 +358,7 @@ def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>, ``` }]; let assemblyFormat = "$tensor attr-dict `:` type($tensor)"; + let hasVerifier = 1; } def SparseTensor_OutOp : SparseTensor_Op<"out", []>, @@ -369,6 +379,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>, ``` }]; let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)"; + let hasVerifier = 1; } #endif // SPARSETENSOR_OPS diff --git a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td index 54af06b42505..a2f15b380b2e 100644 --- a/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td +++ b/mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td @@ -20,7 +20,6 @@ include "mlir/Interfaces/ViewLikeInterface.td" class Tensor_Op traits = []> : Op { let printer = [{ return ::print(p, *this); }]; - let verifier = [{ return ::verify(*this); }]; let parser = [{ return ::parse$cppClass(parser, result); }]; } @@ -59,7 +58,6 @@ def Tensor_CastOp : Tensor_Op<"cast", [ let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)"; let hasCanonicalizer = 1; - let verifier = ?; } //===----------------------------------------------------------------------===// @@ -111,6 +109,7 @@ def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> { let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -151,6 +150,7 @@ def Tensor_ExtractOp : Tensor_Op<"extract", }]>]; let hasFolder = 1; + let hasVerifier = 1; } @@ -303,6 +303,7 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides< let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -339,9 +340,6 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [ let assemblyFormat = "$elements attr-dict `:` type($result)"; - // This op is fully verified by its traits. - let verifier = ?; - let skipDefaultBuilders = 1; let builders = [ OpBuilder<(ins "Type":$resultType, "ValueRange":$elements)>, @@ -394,6 +392,7 @@ def Tensor_GenerateOp : Tensor_Op<"generate", ]; let hasCanonicalizer = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -445,6 +444,7 @@ def Tensor_InsertOp : Tensor_Op<"insert", }]>]; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -564,6 +564,7 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides< let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -586,7 +587,6 @@ def Tensor_RankOp : Tensor_Op<"rank", [NoSideEffect]> { let arguments = (ins AnyTensor:$tensor); let results = (outs Index); - let verifier = ?; let hasFolder = 1; let assemblyFormat = "$tensor attr-dict `:` type($tensor)"; } @@ -650,6 +650,7 @@ def Tensor_ReshapeOp: Tensor_Op<"reshape", [NoSideEffect]> { let assemblyFormat = [{ $source `(` $shape `)` attr-dict `:` functional-type(operands, results) }]; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -718,6 +719,7 @@ class Tensor_ReassociativeReshapeOp traits = []> : let hasFolder = 1; let hasCanonicalizer = 1; + let hasVerifier = 1; let printer = [{ return ::print(p, *this); }]; let parser = [{ return ::parseReshapeLikeOp(parser, result); }]; } @@ -748,6 +750,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> { ``` }]; let extraClassDeclaration = commonExtraClassDeclaration; + let hasVerifier = 1; } def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> { @@ -776,6 +779,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> { ``` }]; let extraClassDeclaration = commonExtraClassDeclaration; + let hasVerifier = 1; } //===----------------------------------------------------------------------===// @@ -961,6 +965,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect, let hasCanonicalizer = 1; let hasFolder = 1; + let hasVerifier = 1; } @@ -984,7 +989,6 @@ def Tensor_YieldOp : Tensor_Op<"yield", // Dummy builder to appease code in templated ensureTerminator that // GenerateOp's auto-generated parser calls. let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>]; - let verifier = ?; } #endif // TENSOR_OPS diff --git a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp index d84d7089f173..a282bd2b8ae1 100644 --- a/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp +++ b/mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp @@ -67,6 +67,10 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) { return NoneType::get(type.getContext()); } +LogicalResult memref::CastOp::verify() { + return impl::verifyCastOp(*this, areCastCompatible); +} + //===----------------------------------------------------------------------===// // AllocOp / AllocaOp //===----------------------------------------------------------------------===// @@ -95,15 +99,15 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) { return success(); } -static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } +LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); } -static LogicalResult verify(AllocaOp op) { +LogicalResult AllocaOp::verify() { // An alloca op needs to have an ancestor with an allocation scope trait. - if (!op->getParentWithTrait()) - return op.emitOpError( + if (!(*this)->getParentWithTrait()) + return emitOpError( "requires an ancestor op with AutomaticAllocationScope trait"); - return verifyAllocLikeOp(op); + return verifyAllocLikeOp(*this); } namespace { @@ -246,11 +250,8 @@ static ParseResult parseAllocaScopeOp(OpAsmParser &parser, return success(); } -static LogicalResult verify(AllocaScopeOp op) { - if (failed(RegionBranchOpInterface::verifyTypes(op))) - return failure(); - - return success(); +LogicalResult AllocaScopeOp::verify() { + return RegionBranchOpInterface::verifyTypes(*this); } void AllocaScopeOp::getSuccessorRegions( @@ -268,10 +269,9 @@ void AllocaScopeOp::getSuccessorRegions( // AssumeAlignmentOp //===----------------------------------------------------------------------===// -static LogicalResult verify(AssumeAlignmentOp op) { - unsigned alignment = op.alignment(); - if (!llvm::isPowerOf2_32(alignment)) - return op.emitOpError("alignment must be power of 2"); +LogicalResult AssumeAlignmentOp::verify() { + if (!llvm::isPowerOf2_32(alignment())) + return emitOpError("alignment must be power of 2"); return success(); } @@ -556,17 +556,17 @@ Optional DimOp::getConstantIndex() { return {}; } -static LogicalResult verify(DimOp op) { +LogicalResult DimOp::verify() { // Assume unknown index to be in range. - Optional index = op.getConstantIndex(); + Optional index = getConstantIndex(); if (!index.hasValue()) return success(); // Check that constant index is not knowingly out of range. - auto type = op.source().getType(); + auto type = source().getType(); if (auto memrefType = type.dyn_cast()) { if (index.getValue() >= memrefType.getRank()) - return op.emitOpError("index is out of range"); + return emitOpError("index is out of range"); } else if (type.isa()) { // Assume index to be in range. } else { @@ -866,67 +866,66 @@ static ParseResult parseDmaStartOp(OpAsmParser &parser, return success(); } -static LogicalResult verify(DmaStartOp op) { - unsigned numOperands = op.getNumOperands(); +LogicalResult DmaStartOp::verify() { + unsigned numOperands = getNumOperands(); // Mandatory non-variadic operands are: src memref, dst memref, tag memref and // the number of elements. if (numOperands < 4) - return op.emitOpError("expected at least 4 operands"); + return emitOpError("expected at least 4 operands"); // Check types of operands. The order of these calls is important: the later // calls rely on some type properties to compute the operand position. // 1. Source memref. - if (!op.getSrcMemRef().getType().isa()) - return op.emitOpError("expected source to be of memref type"); - if (numOperands < op.getSrcMemRefRank() + 4) - return op.emitOpError() - << "expected at least " << op.getSrcMemRefRank() + 4 << " operands"; - if (!op.getSrcIndices().empty() && - !llvm::all_of(op.getSrcIndices().getTypes(), + if (!getSrcMemRef().getType().isa()) + return emitOpError("expected source to be of memref type"); + if (numOperands < getSrcMemRefRank() + 4) + return emitOpError() << "expected at least " << getSrcMemRefRank() + 4 + << " operands"; + if (!getSrcIndices().empty() && + !llvm::all_of(getSrcIndices().getTypes(), [](Type t) { return t.isIndex(); })) - return op.emitOpError("expected source indices to be of index type"); + return emitOpError("expected source indices to be of index type"); // 2. Destination memref. - if (!op.getDstMemRef().getType().isa()) - return op.emitOpError("expected destination to be of memref type"); - unsigned numExpectedOperands = - op.getSrcMemRefRank() + op.getDstMemRefRank() + 4; + if (!getDstMemRef().getType().isa()) + return emitOpError("expected destination to be of memref type"); + unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4; if (numOperands < numExpectedOperands) - return op.emitOpError() - << "expected at least " << numExpectedOperands << " operands"; - if (!op.getDstIndices().empty() && - !llvm::all_of(op.getDstIndices().getTypes(), + return emitOpError() << "expected at least " << numExpectedOperands + << " operands"; + if (!getDstIndices().empty() && + !llvm::all_of(getDstIndices().getTypes(), [](Type t) { return t.isIndex(); })) - return op.emitOpError("expected destination indices to be of index type"); + return emitOpError("expected destination indices to be of index type"); // 3. Number of elements. - if (!op.getNumElements().getType().isIndex()) - return op.emitOpError("expected num elements to be of index type"); + if (!getNumElements().getType().isIndex()) + return emitOpError("expected num elements to be of index type"); // 4. Tag memref. - if (!op.getTagMemRef().getType().isa()) - return op.emitOpError("expected tag to be of memref type"); - numExpectedOperands += op.getTagMemRefRank(); + if (!getTagMemRef().getType().isa()) + return emitOpError("expected tag to be of memref type"); + numExpectedOperands += getTagMemRefRank(); if (numOperands < numExpectedOperands) - return op.emitOpError() - << "expected at least " << numExpectedOperands << " operands"; - if (!op.getTagIndices().empty() && - !llvm::all_of(op.getTagIndices().getTypes(), + return emitOpError() << "expected at least " << numExpectedOperands + << " operands"; + if (!getTagIndices().empty() && + !llvm::all_of(getTagIndices().getTypes(), [](Type t) { return t.isIndex(); })) - return op.emitOpError("expected tag indices to be of index type"); + return emitOpError("expected tag indices to be of index type"); // Optional stride-related operands must be either both present or both // absent. if (numOperands != numExpectedOperands && numOperands != numExpectedOperands + 2) - return op.emitOpError("incorrect number of operands"); + return emitOpError("incorrect number of operands"); // 5. Strides. - if (op.isStrided()) { - if (!op.getStride().getType().isIndex() || - !op.getNumElementsPerStride().getType().isIndex()) - return op.emitOpError( + if (isStrided()) { + if (!getStride().getType().isIndex() || + !getNumElementsPerStride().getType().isIndex()) + return emitOpError( "expected stride and num elements per stride to be of type index"); } @@ -949,14 +948,14 @@ LogicalResult DmaWaitOp::fold(ArrayRef cstOperands, return foldMemRefCast(*this); } -static LogicalResult verify(DmaWaitOp op) { +LogicalResult DmaWaitOp::verify() { // Check that the number of tag indices matches the tagMemRef rank. - unsigned numTagIndices = op.tagIndices().size(); - unsigned tagMemRefRank = op.getTagMemRefRank(); + unsigned numTagIndices = tagIndices().size(); + unsigned tagMemRefRank = getTagMemRefRank(); if (numTagIndices != tagMemRefRank) - return op.emitOpError() << "expected tagIndices to have the same number of " - "elements as the tagMemRef rank, expected " - << tagMemRefRank << ", but got " << numTagIndices; + return emitOpError() << "expected tagIndices to have the same number of " + "elements as the tagMemRef rank, expected " + << tagMemRefRank << ", but got " << numTagIndices; return success(); } @@ -979,14 +978,13 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result, } } -static LogicalResult verify(GenericAtomicRMWOp op) { - auto &body = op.getRegion(); +LogicalResult GenericAtomicRMWOp::verify() { + auto &body = getRegion(); if (body.getNumArguments() != 1) - return op.emitOpError("expected single number of entry block arguments"); + return emitOpError("expected single number of entry block arguments"); - if (op.getResult().getType() != body.getArgument(0).getType()) - return op.emitOpError( - "expected block argument of the same type result type"); + if (getResult().getType() != body.getArgument(0).getType()) + return emitOpError("expected block argument of the same type result type"); bool hasSideEffects = body.walk([&](Operation *nestedOp) { @@ -1034,12 +1032,12 @@ static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) { // AtomicYieldOp //===----------------------------------------------------------------------===// -static LogicalResult verify(AtomicYieldOp op) { - Type parentType = op->getParentOp()->getResultTypes().front(); - Type resultType = op.result().getType(); +LogicalResult AtomicYieldOp::verify() { + Type parentType = (*this)->getParentOp()->getResultTypes().front(); + Type resultType = result().getType(); if (parentType != resultType) - return op.emitOpError() << "types mismatch between yield op: " << resultType - << " and its parent: " << parentType; + return emitOpError() << "types mismatch between yield op: " << resultType + << " and its parent: " << parentType; return success(); } @@ -1090,19 +1088,19 @@ parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr, return success(); } -static LogicalResult verify(GlobalOp op) { - auto memrefType = op.type().dyn_cast(); +LogicalResult GlobalOp::verify() { + auto memrefType = type().dyn_cast(); if (!memrefType || !memrefType.hasStaticShape()) - return op.emitOpError("type should be static shaped memref, but got ") - << op.type(); + return emitOpError("type should be static shaped memref, but got ") + << type(); // Verify that the initial value, if present, is either a unit attribute or // an elements attribute. - if (op.initial_value().hasValue()) { - Attribute initValue = op.initial_value().getValue(); + if (initial_value().hasValue()) { + Attribute initValue = initial_value().getValue(); if (!initValue.isa() && !initValue.isa()) - return op.emitOpError("initial value should be a unit or elements " - "attribute, but got ") + return emitOpError("initial value should be a unit or elements " + "attribute, but got ") << initValue; // Check that the type of the initial value is compatible with the type of @@ -1111,17 +1109,17 @@ static LogicalResult verify(GlobalOp op) { Type initType = initValue.getType(); Type tensorType = getTensorTypeFromMemRefType(memrefType); if (initType != tensorType) - return op.emitOpError("initial value expected to be of type ") + return emitOpError("initial value expected to be of type ") << tensorType << ", but was of type " << initType; } } - if (Optional alignAttr = op.alignment()) { + if (Optional alignAttr = alignment()) { uint64_t alignment = alignAttr.getValue(); if (!llvm::isPowerOf2_64(alignment)) - return op->emitError() << "alignment attribute value " << alignment - << " is not a power of 2"; + return emitError() << "alignment attribute value " << alignment + << " is not a power of 2"; } // TODO: verify visibility for declarations. @@ -1154,9 +1152,9 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) { // LoadOp //===----------------------------------------------------------------------===// -static LogicalResult verify(LoadOp op) { - if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) - return op.emitOpError("incorrect number of indices for load"); +LogicalResult LoadOp::verify() { + if (getNumOperands() != 1 + getMemRefType().getRank()) + return emitOpError("incorrect number of indices for load"); return success(); } @@ -1224,9 +1222,9 @@ static ParseResult parsePrefetchOp(OpAsmParser &parser, return success(); } -static LogicalResult verify(PrefetchOp op) { - if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) - return op.emitOpError("too few indices"); +LogicalResult PrefetchOp::verify() { + if (getNumOperands() != 1 + getMemRefType().getRank()) + return emitOpError("too few indices"); return success(); } @@ -1306,26 +1304,25 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result, // TODO: ponder whether we want to allow missing trailing sizes/strides that are // completed automatically, like we have for subview and extract_slice. -static LogicalResult verify(ReinterpretCastOp op) { +LogicalResult ReinterpretCastOp::verify() { // The source and result memrefs should be in the same memory space. - auto srcType = op.source().getType().cast(); - auto resultType = op.getType().cast(); + auto srcType = source().getType().cast(); + auto resultType = getType().cast(); if (srcType.getMemorySpace() != resultType.getMemorySpace()) - return op.emitError("different memory spaces specified for source type ") + return emitError("different memory spaces specified for source type ") << srcType << " and result memref type " << resultType; if (srcType.getElementType() != resultType.getElementType()) - return op.emitError("different element types specified for source type ") + return emitError("different element types specified for source type ") << srcType << " and result memref type " << resultType; // Match sizes in result memref type and in static_sizes attribute. - for (auto &en : - llvm::enumerate(llvm::zip(resultType.getShape(), - extractFromI64ArrayAttr(op.static_sizes())))) { + for (auto &en : llvm::enumerate(llvm::zip( + resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) { int64_t resultSize = std::get<0>(en.value()); int64_t expectedSize = std::get<1>(en.value()); if (!ShapedType::isDynamic(resultSize) && !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize) - return op.emitError("expected result type with size = ") + return emitError("expected result type with size = ") << expectedSize << " instead of " << resultSize << " in dim = " << en.index(); } @@ -1336,27 +1333,26 @@ static LogicalResult verify(ReinterpretCastOp op) { int64_t resultOffset; SmallVector resultStrides; if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) - return op.emitError( - "expected result type to have strided layout but found ") + return emitError("expected result type to have strided layout but found ") << resultType; // Match offset in result memref type and in static_offsets attribute. - int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front(); + int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front(); if (!ShapedType::isDynamicStrideOrOffset(resultOffset) && !ShapedType::isDynamicStrideOrOffset(expectedOffset) && resultOffset != expectedOffset) - return op.emitError("expected result type with offset = ") + return emitError("expected result type with offset = ") << resultOffset << " instead of " << expectedOffset; // Match strides in result memref type and in static_strides attribute. for (auto &en : llvm::enumerate(llvm::zip( - resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { + resultStrides, extractFromI64ArrayAttr(static_strides())))) { int64_t resultStride = std::get<0>(en.value()); int64_t expectedStride = std::get<1>(en.value()); if (!ShapedType::isDynamicStrideOrOffset(resultStride) && !ShapedType::isDynamicStrideOrOffset(expectedStride) && resultStride != expectedStride) - return op.emitError("expected result type with stride = ") + return emitError("expected result type with stride = ") << expectedStride << " instead of " << resultStride << " in dim = " << en.index(); } @@ -1532,8 +1528,8 @@ static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType, return success(); } -static LogicalResult verify(ExpandShapeOp op) { - return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); +LogicalResult ExpandShapeOp::verify() { + return verifyReshapeOp(*this, getResultType(), getSrcType()); } void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, @@ -1542,8 +1538,8 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, CollapseMixedReshapeOps>(context); } -static LogicalResult verify(CollapseShapeOp op) { - return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); +LogicalResult CollapseShapeOp::verify() { + return verifyReshapeOp(*this, getSrcType(), getResultType()); } struct CollapseShapeOpMemRefCastFolder @@ -1593,32 +1589,30 @@ OpFoldResult CollapseShapeOp::fold(ArrayRef operands) { // ReshapeOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ReshapeOp op) { - Type operandType = op.source().getType(); - Type resultType = op.result().getType(); +LogicalResult ReshapeOp::verify() { + Type operandType = source().getType(); + Type resultType = result().getType(); Type operandElementType = operandType.cast().getElementType(); Type resultElementType = resultType.cast().getElementType(); if (operandElementType != resultElementType) - return op.emitOpError("element types of source and destination memref " - "types should be the same"); + return emitOpError("element types of source and destination memref " + "types should be the same"); if (auto operandMemRefType = operandType.dyn_cast()) if (!operandMemRefType.getLayout().isIdentity()) - return op.emitOpError( - "source memref type should have identity affine map"); + return emitOpError("source memref type should have identity affine map"); - int64_t shapeSize = op.shape().getType().cast().getDimSize(0); + int64_t shapeSize = shape().getType().cast().getDimSize(0); auto resultMemRefType = resultType.dyn_cast(); if (resultMemRefType) { if (!resultMemRefType.getLayout().isIdentity()) - return op.emitOpError( - "result memref type should have identity affine map"); + return emitOpError("result memref type should have identity affine map"); if (shapeSize == ShapedType::kDynamicSize) - return op.emitOpError("cannot use shape operand with dynamic length to " - "reshape to statically-ranked memref type"); + return emitOpError("cannot use shape operand with dynamic length to " + "reshape to statically-ranked memref type"); if (shapeSize != resultMemRefType.getRank()) - return op.emitOpError( + return emitOpError( "length of shape operand differs from the result's memref rank"); } return success(); @@ -1628,9 +1622,9 @@ static LogicalResult verify(ReshapeOp op) { // StoreOp //===----------------------------------------------------------------------===// -static LogicalResult verify(StoreOp op) { - if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) - return op.emitOpError("store index operand count not equal to memref rank"); +LogicalResult StoreOp::verify() { + if (getNumOperands() != 2 + getMemRefType().getRank()) + return emitOpError("store index operand count not equal to memref rank"); return success(); } @@ -1951,29 +1945,29 @@ static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result, } /// Verifier for SubViewOp. -static LogicalResult verify(SubViewOp op) { - MemRefType baseType = op.getSourceType(); - MemRefType subViewType = op.getType(); +LogicalResult SubViewOp::verify() { + MemRefType baseType = getSourceType(); + MemRefType subViewType = getType(); // The base memref and the view memref should be in the same memory space. if (baseType.getMemorySpace() != subViewType.getMemorySpace()) - return op.emitError("different memory spaces specified for base memref " - "type ") + return emitError("different memory spaces specified for base memref " + "type ") << baseType << " and subview memref type " << subViewType; // Verify that the base memref type has a strided layout map. if (!isStrided(baseType)) - return op.emitError("base type ") << baseType << " is not strided"; + return emitError("base type ") << baseType << " is not strided"; // Verify result type against inferred type. auto expectedType = SubViewOp::inferResultType( - baseType, extractFromI64ArrayAttr(op.static_offsets()), - extractFromI64ArrayAttr(op.static_sizes()), - extractFromI64ArrayAttr(op.static_strides())); + baseType, extractFromI64ArrayAttr(static_offsets()), + extractFromI64ArrayAttr(static_sizes()), + extractFromI64ArrayAttr(static_strides())); auto result = isRankReducedMemRefType(expectedType.cast(), - subViewType, op.getMixedSizes()); - return produceSubViewErrorMsg(result, op, expectedType); + subViewType, getMixedSizes()); + return produceSubViewErrorMsg(result, *this, expectedType); } raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { @@ -2278,18 +2272,17 @@ static ParseResult parseTransposeOp(OpAsmParser &parser, return success(); } -static LogicalResult verify(TransposeOp op) { - if (!op.permutation().isPermutation()) - return op.emitOpError("expected a permutation map"); - if (op.permutation().getNumDims() != op.getShapedType().getRank()) - return op.emitOpError( - "expected a permutation map of same rank as the input"); +LogicalResult TransposeOp::verify() { + if (!permutation().isPermutation()) + return emitOpError("expected a permutation map"); + if (permutation().getNumDims() != getShapedType().getRank()) + return emitOpError("expected a permutation map of same rank as the input"); - auto srcType = op.in().getType().cast(); - auto dstType = op.getType().cast(); - auto transposedType = inferTransposeResultType(srcType, op.permutation()); + auto srcType = in().getType().cast(); + auto dstType = getType().cast(); + auto transposedType = inferTransposeResultType(srcType, permutation()); if (dstType != transposedType) - return op.emitOpError("output type ") + return emitOpError("output type ") << dstType << " does not match transposed input type " << srcType << ", " << transposedType; return success(); @@ -2338,29 +2331,28 @@ static void print(OpAsmPrinter &p, ViewOp op) { p << " : " << op.getOperand(0).getType() << " to " << op.getType(); } -static LogicalResult verify(ViewOp op) { - auto baseType = op.getOperand(0).getType().cast(); - auto viewType = op.getType(); +LogicalResult ViewOp::verify() { + auto baseType = getOperand(0).getType().cast(); + auto viewType = getType(); // The base memref should have identity layout map (or none). if (!baseType.getLayout().isIdentity()) - return op.emitError("unsupported map for base memref type ") << baseType; + return emitError("unsupported map for base memref type ") << baseType; // The result memref should have identity layout map (or none). if (!viewType.getLayout().isIdentity()) - return op.emitError("unsupported map for result memref type ") << viewType; + return emitError("unsupported map for result memref type ") << viewType; // The base memref and the view memref should be in the same memory space. if (baseType.getMemorySpace() != viewType.getMemorySpace()) - return op.emitError("different memory spaces specified for base memref " - "type ") + return emitError("different memory spaces specified for base memref " + "type ") << baseType << " and view memref type " << viewType; // Verify that we have the correct number of sizes for the result type. unsigned numDynamicDims = viewType.getNumDynamicDims(); - if (op.sizes().size() != numDynamicDims) - return op.emitError("incorrect number of size operands for type ") - << viewType; + if (sizes().size() != numDynamicDims) + return emitError("incorrect number of size operands for type ") << viewType; return success(); } @@ -2467,19 +2459,19 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results, // AtomicRMWOp //===----------------------------------------------------------------------===// -static LogicalResult verify(AtomicRMWOp op) { - if (op.getMemRefType().getRank() != op.getNumOperands() - 2) - return op.emitOpError( +LogicalResult AtomicRMWOp::verify() { + if (getMemRefType().getRank() != getNumOperands() - 2) + return emitOpError( "expects the number of subscripts to be equal to memref rank"); - switch (op.kind()) { + switch (kind()) { case arith::AtomicRMWKind::addf: case arith::AtomicRMWKind::maxf: case arith::AtomicRMWKind::minf: case arith::AtomicRMWKind::mulf: - if (!op.value().getType().isa()) - return op.emitOpError() - << "with kind '" << arith::stringifyAtomicRMWKind(op.kind()) - << "' expects a floating-point type"; + if (!value().getType().isa()) + return emitOpError() << "with kind '" + << arith::stringifyAtomicRMWKind(kind()) + << "' expects a floating-point type"; break; case arith::AtomicRMWKind::addi: case arith::AtomicRMWKind::maxs: @@ -2489,10 +2481,10 @@ static LogicalResult verify(AtomicRMWOp op) { case arith::AtomicRMWKind::muli: case arith::AtomicRMWKind::ori: case arith::AtomicRMWKind::andi: - if (!op.value().getType().isa()) - return op.emitOpError() - << "with kind '" << arith::stringifyAtomicRMWKind(op.kind()) - << "' expects an integer type"; + if (!value().getType().isa()) + return emitOpError() << "with kind '" + << arith::stringifyAtomicRMWKind(kind()) + << "' expects an integer type"; break; default: break; diff --git a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp index 5b0ee4656c49..ecbc989a2c14 100644 --- a/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp +++ b/mlir/lib/Dialect/SparseTensor/IR/SparseTensorDialect.cpp @@ -209,53 +209,51 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) { return failure(); } -static LogicalResult verify(NewOp op) { - if (!getSparseTensorEncoding(op.result().getType())) - return op.emitError("expected a sparse tensor result"); +LogicalResult NewOp::verify() { + if (!getSparseTensorEncoding(result().getType())) + return emitError("expected a sparse tensor result"); return success(); } -static LogicalResult verify(InitOp op) { - if (!getSparseTensorEncoding(op.result().getType())) - return op.emitError("expected a sparse tensor result"); - RankedTensorType ttp = op.getType().cast(); +LogicalResult InitOp::verify() { + if (!getSparseTensorEncoding(result().getType())) + return emitError("expected a sparse tensor result"); + RankedTensorType ttp = getType().cast(); unsigned rank = ttp.getRank(); - if (rank != op.sizes().size()) - return op.emitError("unexpected mismatch between tensor rank and sizes: ") - << rank << " vs. " << op.sizes().size(); + if (rank != sizes().size()) + return emitError("unexpected mismatch between tensor rank and sizes: ") + << rank << " vs. " << sizes().size(); auto shape = ttp.getShape(); for (unsigned i = 0; i < rank; i++) { if (shape[i] == ShapedType::kDynamicSize) continue; IntegerAttr constantAttr; - if (!matchPattern(op.sizes()[i], m_Constant(&constantAttr)) || + if (!matchPattern(sizes()[i], m_Constant(&constantAttr)) || constantAttr.getInt() != shape[i]) { - return op.emitError("unexpected mismatch with static dimension size ") + return emitError("unexpected mismatch with static dimension size ") << shape[i]; } } return success(); } -static LogicalResult verify(ConvertOp op) { - if (auto tp1 = op.source().getType().dyn_cast()) { - if (auto tp2 = op.dest().getType().dyn_cast()) { +LogicalResult ConvertOp::verify() { + if (auto tp1 = source().getType().dyn_cast()) { + if (auto tp2 = dest().getType().dyn_cast()) { if (tp1.getRank() != tp2.getRank()) - return op.emitError("unexpected conversion mismatch in rank"); + return emitError("unexpected conversion mismatch in rank"); auto shape1 = tp1.getShape(); auto shape2 = tp2.getShape(); // Accept size matches between the source and the destination type // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10). - for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) { + for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize) - return op.emitError("unexpected conversion mismatch in dimension ") - << d; - } + return emitError("unexpected conversion mismatch in dimension ") << d; return success(); } } - return op.emitError("unexpected type in convert"); + return emitError("unexpected type in convert"); } OpFoldResult ConvertOp::fold(ArrayRef operands) { @@ -264,35 +262,35 @@ OpFoldResult ConvertOp::fold(ArrayRef operands) { return {}; } -static LogicalResult verify(ToPointersOp op) { - if (auto e = getSparseTensorEncoding(op.tensor().getType())) { - if (failed(isInBounds(op.dim(), op.tensor()))) - return op.emitError("requested pointers dimension out of bounds"); - if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth()))) - return op.emitError("unexpected type for pointers"); +LogicalResult ToPointersOp::verify() { + if (auto e = getSparseTensorEncoding(tensor().getType())) { + if (failed(isInBounds(dim(), tensor()))) + return emitError("requested pointers dimension out of bounds"); + if (failed(isMatchingWidth(result(), e.getPointerBitWidth()))) + return emitError("unexpected type for pointers"); return success(); } - return op.emitError("expected a sparse tensor to get pointers"); + return emitError("expected a sparse tensor to get pointers"); } -static LogicalResult verify(ToIndicesOp op) { - if (auto e = getSparseTensorEncoding(op.tensor().getType())) { - if (failed(isInBounds(op.dim(), op.tensor()))) - return op.emitError("requested indices dimension out of bounds"); - if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth()))) - return op.emitError("unexpected type for indices"); +LogicalResult ToIndicesOp::verify() { + if (auto e = getSparseTensorEncoding(tensor().getType())) { + if (failed(isInBounds(dim(), tensor()))) + return emitError("requested indices dimension out of bounds"); + if (failed(isMatchingWidth(result(), e.getIndexBitWidth()))) + return emitError("unexpected type for indices"); return success(); } - return op.emitError("expected a sparse tensor to get indices"); + return emitError("expected a sparse tensor to get indices"); } -static LogicalResult verify(ToValuesOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor to get values"); - RankedTensorType ttp = op.tensor().getType().cast(); - MemRefType mtp = op.result().getType().cast(); +LogicalResult ToValuesOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor to get values"); + RankedTensorType ttp = tensor().getType().cast(); + MemRefType mtp = result().getType().cast(); if (ttp.getElementType() != mtp.getElementType()) - return op.emitError("unexpected mismatch in element types"); + return emitError("unexpected mismatch in element types"); return success(); } @@ -300,39 +298,39 @@ static LogicalResult verify(ToValuesOp op) { // TensorDialect Management Operations. //===----------------------------------------------------------------------===// -static LogicalResult verify(LexInsertOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor for insertion"); +LogicalResult LexInsertOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor for insertion"); return success(); } -static LogicalResult verify(ExpandOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor for expansion"); +LogicalResult ExpandOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor for expansion"); return success(); } -static LogicalResult verify(CompressOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor for compression"); +LogicalResult CompressOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor for compression"); return success(); } -static LogicalResult verify(LoadOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor to materialize"); +LogicalResult LoadOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor to materialize"); return success(); } -static LogicalResult verify(ReleaseOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor to release"); +LogicalResult ReleaseOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor to release"); return success(); } -static LogicalResult verify(OutOp op) { - if (!getSparseTensorEncoding(op.tensor().getType())) - return op.emitError("expected a sparse tensor for output"); +LogicalResult OutOp::verify() { + if (!getSparseTensorEncoding(tensor().getType())) + return emitError("expected a sparse tensor for output"); return success(); } diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 2cc927555c3d..91dfddb2dfe9 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -228,17 +228,17 @@ Optional DimOp::getConstantIndex() { return {}; } -static LogicalResult verify(DimOp op) { +LogicalResult DimOp::verify() { // Assume unknown index to be in range. - Optional index = op.getConstantIndex(); + Optional index = getConstantIndex(); if (!index.hasValue()) return success(); // Check that constant index is not knowingly out of range. - auto type = op.source().getType(); + auto type = source().getType(); if (auto tensorType = type.dyn_cast()) { if (index.getValue() >= tensorType.getRank()) - return op.emitOpError("index is out of range"); + return emitOpError("index is out of range"); } else if (type.isa()) { // Assume index to be in range. } else { @@ -328,11 +328,11 @@ void DimOp::getCanonicalizationPatterns(RewritePatternSet &results, // ExtractOp //===----------------------------------------------------------------------===// -static LogicalResult verify(ExtractOp op) { +LogicalResult ExtractOp::verify() { // Verify the # indices match if we have a ranked type. - if (auto tensorType = op.tensor().getType().dyn_cast()) - if (tensorType.getRank() != static_cast(op.indices().size())) - return op.emitOpError("incorrect number of indices for extract_element"); + if (auto tensorType = tensor().getType().dyn_cast()) + if (tensorType.getRank() != static_cast(indices().size())) + return emitOpError("incorrect number of indices for extract_element"); return success(); } @@ -480,11 +480,11 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, // InsertOp //===----------------------------------------------------------------------===// -static LogicalResult verify(InsertOp op) { +LogicalResult InsertOp::verify() { // Verify the # indices match if we have a ranked type. - if (auto destType = op.dest().getType().dyn_cast()) - if (destType.getRank() != static_cast(op.indices().size())) - return op.emitOpError("incorrect number of indices"); + if (auto destType = dest().getType().dyn_cast()) + if (destType.getRank() != static_cast(indices().size())) + return emitOpError("incorrect number of indices"); return success(); } @@ -502,27 +502,26 @@ OpFoldResult InsertOp::fold(ArrayRef operands) { // GenerateOp //===----------------------------------------------------------------------===// -static LogicalResult verify(GenerateOp op) { +LogicalResult GenerateOp::verify() { // Ensure that the tensor type has as many dynamic dimensions as are specified // by the operands. - RankedTensorType resultTy = op.getType().cast(); - if (op.getNumOperands() != resultTy.getNumDynamicDims()) - return op.emitError("must have as many index operands as dynamic extents " - "in the result type"); + RankedTensorType resultTy = getType().cast(); + if (getNumOperands() != resultTy.getNumDynamicDims()) + return emitError("must have as many index operands as dynamic extents " + "in the result type"); // Ensure that region arguments span the index space. - if (!llvm::all_of(op.body().getArgumentTypes(), + if (!llvm::all_of(body().getArgumentTypes(), [](Type ty) { return ty.isIndex(); })) - return op.emitError("all body arguments must be index"); - if (op.body().getNumArguments() != resultTy.getRank()) - return op.emitError("must have one body argument per input dimension"); + return emitError("all body arguments must be index"); + if (body().getNumArguments() != resultTy.getRank()) + return emitError("must have one body argument per input dimension"); // Ensure that the region yields an element of the right type. - auto yieldOp = - llvm::cast(op.body().getBlocks().front().getTerminator()); + auto yieldOp = cast(body().getBlocks().front().getTerminator()); if (yieldOp.value().getType() != resultTy.getElementType()) - return op.emitOpError( + return emitOpError( "body must be terminated with a `yield` operation of the tensor " "element type"); @@ -686,16 +685,15 @@ static int64_t getNumElements(ShapedType type) { return numElements; } -static LogicalResult verify(ReshapeOp op) { - TensorType operandType = op.source().getType().cast(); - TensorType resultType = op.result().getType().cast(); +LogicalResult ReshapeOp::verify() { + TensorType operandType = source().getType().cast(); + TensorType resultType = result().getType().cast(); if (operandType.getElementType() != resultType.getElementType()) - return op.emitOpError("element types of source and destination tensor " - "types should be the same"); + return emitOpError("element types of source and destination tensor " + "types should be the same"); - int64_t shapeSize = - op.shape().getType().cast().getDimSize(0); + int64_t shapeSize = shape().getType().cast().getDimSize(0); auto resultRankedType = resultType.dyn_cast(); auto operandRankedType = operandType.dyn_cast(); @@ -703,14 +701,14 @@ static LogicalResult verify(ReshapeOp op) { if (operandRankedType && resultRankedType.hasStaticShape() && operandRankedType.hasStaticShape()) { if (getNumElements(operandRankedType) != getNumElements(resultRankedType)) - return op.emitOpError("source and destination tensor should have the " - "same number of elements"); + return emitOpError("source and destination tensor should have the " + "same number of elements"); } if (ShapedType::isDynamic(shapeSize)) - return op.emitOpError("cannot use shape operand with dynamic length to " - "reshape to statically-ranked tensor type"); + return emitOpError("cannot use shape operand with dynamic length to " + "reshape to statically-ranked tensor type"); if (shapeSize != resultRankedType.getRank()) - return op.emitOpError( + return emitOpError( "length of shape operand differs from the result's tensor rank"); } return success(); @@ -814,12 +812,12 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op, return success(); } -static LogicalResult verify(ExpandShapeOp op) { - return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType()); +LogicalResult ExpandShapeOp::verify() { + return verifyTensorReshapeOp(*this, getResultType(), getSrcType()); } -static LogicalResult verify(CollapseShapeOp op) { - return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType()); +LogicalResult CollapseShapeOp::verify() { + return verifyTensorReshapeOp(*this, getSrcType(), getResultType()); } namespace { @@ -1052,14 +1050,12 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result, } /// Verifier for ExtractSliceOp. -static LogicalResult verify(ExtractSliceOp op) { +LogicalResult ExtractSliceOp::verify() { // Verify result type against inferred type. - auto expectedType = - ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(), - op.getMixedSizes(), op.getMixedStrides()); - auto result = - isRankReducedType(expectedType.cast(), op.getType()); - return produceSliceErrorMsg(result, op, expectedType); + auto expectedType = ExtractSliceOp::inferResultType( + getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides()); + auto result = isRankReducedType(expectedType.cast(), getType()); + return produceSliceErrorMsg(result, *this, expectedType); } /// Infer the canonical type of the result of an extract_slice op. Returns a @@ -1308,16 +1304,16 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source, } /// Verifier for InsertSliceOp. -static LogicalResult verify(InsertSliceOp op) { +LogicalResult InsertSliceOp::verify() { // insert_slice is the inverse of extract_slice, use the same type inference. auto expectedType = ExtractSliceOp::inferRankReducedResultType( - op.getSourceType().getRank(), op.getType(), - extractFromI64ArrayAttr(op.static_offsets()), - extractFromI64ArrayAttr(op.static_sizes()), - extractFromI64ArrayAttr(op.static_strides())); + getSourceType().getRank(), getType(), + extractFromI64ArrayAttr(static_offsets()), + extractFromI64ArrayAttr(static_sizes()), + extractFromI64ArrayAttr(static_strides())); auto result = - isRankReducedType(expectedType.cast(), op.getSourceType()); - return produceSliceErrorMsg(result, op, expectedType); + isRankReducedType(expectedType.cast(), getSourceType()); + return produceSliceErrorMsg(result, *this, expectedType); } /// If we have two consecutive InsertSliceOp writing to the same slice, we @@ -1569,40 +1565,40 @@ ParseResult parseInferType(OpAsmParser &parser, return success(); } -static LogicalResult verify(PadOp op) { - auto sourceType = op.source().getType().cast(); - auto resultType = op.result().getType().cast(); - auto expectedType = PadOp::inferResultType( - sourceType, extractFromI64ArrayAttr(op.static_low()), - extractFromI64ArrayAttr(op.static_high())); +LogicalResult PadOp::verify() { + auto sourceType = source().getType().cast(); + auto resultType = result().getType().cast(); + auto expectedType = + PadOp::inferResultType(sourceType, extractFromI64ArrayAttr(static_low()), + extractFromI64ArrayAttr(static_high())); for (int i = 0, e = sourceType.getRank(); i < e; ++i) { if (resultType.getDimSize(i) == expectedType.getDimSize(i)) continue; if (expectedType.isDynamicDim(i)) continue; - return op.emitError("specified type ") + return emitError("specified type ") << resultType << " does not match the inferred type " << expectedType; } - auto ®ion = op.region(); + auto ®ion = getRegion(); unsigned rank = resultType.getRank(); Block &block = region.front(); if (block.getNumArguments() != rank) - return op.emitError("expected the block to have ") << rank << " arguments"; + return emitError("expected the block to have ") << rank << " arguments"; // Note: the number and type of yield values are checked in the YieldOp. for (const auto &en : llvm::enumerate(block.getArgumentTypes())) { if (!en.value().isIndex()) - return op.emitOpError("expected block argument ") + return emitOpError("expected block argument ") << (en.index() + 1) << " to be an index"; } // Ensure that the region yields an element of the right type. auto yieldOp = llvm::cast(block.getTerminator()); if (yieldOp.value().getType() != - op.getType().cast().getElementType()) - return op.emitOpError("expected yield type to match shape element type"); + getType().cast().getElementType()) + return emitOpError("expected yield type to match shape element type"); return success(); }