[mlir][NFC] Update MemRef/Tensor operations to use `hasVerifier` instead of `verifier`

The verifier field is deprecated, and slated for removal.

Differential Revision: https://reviews.llvm.org/D118821
This commit is contained in:
River Riddle 2022-02-02 10:18:06 -08:00
parent bdc7ce975a
commit b98dc0351a
6 changed files with 325 additions and 314 deletions

View File

@ -28,7 +28,6 @@ def MemRefTypeAttr
class MemRef_Op<string mnemonic, list<Trait> traits = []>
: Op<MemRef_Dialect, mnemonic, traits> {
let printer = [{ return ::print(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
@ -93,6 +92,7 @@ class AllocLikeOp<string mnemonic,
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -115,6 +115,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
let results = (outs);
let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -162,6 +163,7 @@ def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, []> {
memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
```
}];
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -205,6 +207,7 @@ def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> {
an alignment on any convenient boundary compatible with the type will be
chosen.
}];
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -253,6 +256,7 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
let results = (outs Variadic<AnyType>:$results);
let regions = (region SizedRegion<1>:$bodyRegion);
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -279,11 +283,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
let arguments = (ins Variadic<AnyType>:$results);
let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>];
let assemblyFormat =
[{ attr-dict ($results^ `:` type($results))? }];
// No custom verification needed.
let verifier = ?;
let assemblyFormat = "attr-dict ($results^ `:` type($results))?";
}
//===----------------------------------------------------------------------===//
@ -355,7 +355,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
let results = (outs AnyRankedOrUnrankedMemRef:$dest);
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";
let builders = [
OpBuilder<(ins "Value":$source, "Type":$destType), [{
impl::buildCastOp($_builder, $_state, source, destType);
@ -370,6 +369,7 @@ def MemRef_CastOp : MemRef_Op<"cast", [
}];
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -408,7 +408,6 @@ def CopyOp : MemRef_Op<"copy",
let hasCanonicalizer = 1;
let hasFolder = 1;
let verifier = ?;
}
//===----------------------------------------------------------------------===//
@ -434,7 +433,6 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "", [MemFree]>:$memref);
let hasFolder = 1;
let verifier = ?;
let assemblyFormat = "$memref attr-dict `:` type($memref)";
}
@ -488,6 +486,7 @@ def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> {
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -646,6 +645,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
}
}];
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -697,6 +697,7 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
Value getNumElements() { return numElements(); }
}];
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -757,6 +758,7 @@ def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [
return memref().getType().cast<MemRefType>();
}
}];
let hasVerifier = 1;
}
def AtomicYieldOp : MemRef_Op<"atomic_yield", [
@ -772,6 +774,7 @@ def AtomicYieldOp : MemRef_Op<"atomic_yield", [
let arguments = (ins AnyType:$result);
let assemblyFormat = "$result attr-dict `:` type($result)";
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -797,9 +800,6 @@ def MemRef_GetGlobalOp : MemRef_Op<"get_global",
let arguments = (ins FlatSymbolRefAttr:$name);
let results = (outs AnyStaticShapeMemRef:$result);
let assemblyFormat = "$name `:` type($result) attr-dict";
// `GetGlobalOp` is fully verified by its traits.
let verifier = ?;
}
//===----------------------------------------------------------------------===//
@ -866,6 +866,7 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
return !isExternal() && initial_value().getValue().isa<UnitAttr>();
}
}];
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -939,6 +940,7 @@ def LoadOp : MemRef_Op<"load",
}];
let hasFolder = 1;
let hasVerifier = 1;
let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
}
@ -982,6 +984,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
}];
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -1034,6 +1037,7 @@ def MemRef_ReinterpretCastOp:
let parser = ?;
let printer = ?;
let hasVerifier = 1;
let builders = [
// Build a ReinterpretCastOp with mixed static and dynamic entries.
@ -1096,7 +1100,6 @@ def MemRef_RankOp : MemRef_Op<"rank", [NoSideEffect]> {
let arguments = (ins AnyRankedOrUnrankedMemRef:$memref);
let results = (outs Index);
let verifier = ?;
let hasFolder = 1;
let assemblyFormat = "$memref attr-dict `:` type($memref)";
}
@ -1161,6 +1164,7 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
let assemblyFormat = [{
$source `(` $shape `)` attr-dict `:` functional-type(operands, results)
}];
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -1226,6 +1230,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
}
@ -1265,6 +1270,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
```
}];
let extraClassDeclaration = commonExtraClassDeclaration;
let hasVerifier = 1;
}
def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
@ -1302,6 +1308,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
```
}];
let extraClassDeclaration = commonExtraClassDeclaration;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -1369,6 +1376,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
}];
let hasFolder = 1;
let hasVerifier = 1;
let assemblyFormat = [{
$value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
@ -1617,6 +1625,7 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -1645,8 +1654,6 @@ def TensorStoreOp : MemRef_Op<"tensor_store",
let arguments = (ins AnyTensor:$tensor, Arg<AnyRankedOrUnrankedMemRef,
"the reference to store to", [MemWrite]>:$memref);
// TensorStoreOp is fully verified by traits.
let verifier = ?;
let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
}
@ -1681,6 +1688,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [NoSideEffect]>,
}];
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -1749,6 +1757,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
}];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -1796,6 +1805,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
}
}];
let hasFolder = 1;
let hasVerifier = 1;
}
#endif // MEMREF_OPS

View File

@ -21,7 +21,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
class SparseTensor_Op<string mnemonic, list<Trait> traits = []>
: Op<SparseTensor_Dialect, mnemonic, traits> {
let printer = [{ return ::print(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
@ -50,6 +49,7 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
```
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
let hasVerifier = 1;
}
def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>,
@ -72,6 +72,7 @@ def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>,
```
}];
let assemblyFormat = "`[` $sizes `]` attr-dict `:` type($result)";
let hasVerifier = 1;
}
def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
@ -113,6 +114,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
}];
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let hasFolder = 1;
let hasVerifier = 1;
}
def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
@ -137,6 +139,7 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
}];
let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)"
" `to` type($result)";
let hasVerifier = 1;
}
def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
@ -161,6 +164,7 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
}];
let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)"
" `to` type($result)";
let hasVerifier = 1;
}
def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
@ -183,6 +187,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
```
}];
let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -217,6 +222,7 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
}];
let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`"
" type($tensor) `,` type($indices) `,` type($value)";
let hasVerifier = 1;
}
def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
@ -258,6 +264,7 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
}];
let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)"
" `,` type($filled) `,` type($added) `,` type($count)";
let hasVerifier = 1;
}
def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
@ -292,6 +299,7 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
" $added `,` $count attr-dict `:` type($tensor) `,`"
" type($indices) `,` type($values) `,` type($filled) `,`"
" type($added) `,` type($count)";
let hasVerifier = 1;
}
def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
@ -324,6 +332,7 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
```
}];
let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
let hasVerifier = 1;
}
def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
@ -349,6 +358,7 @@ def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
```
}];
let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
let hasVerifier = 1;
}
def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
@ -369,6 +379,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
```
}];
let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
let hasVerifier = 1;
}
#endif // SPARSETENSOR_OPS

View File

@ -20,7 +20,6 @@ include "mlir/Interfaces/ViewLikeInterface.td"
class Tensor_Op<string mnemonic, list<Trait> traits = []>
: Op<Tensor_Dialect, mnemonic, traits> {
let printer = [{ return ::print(p, *this); }];
let verifier = [{ return ::verify(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
}
@ -59,7 +58,6 @@ def Tensor_CastOp : Tensor_Op<"cast", [
let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
let hasCanonicalizer = 1;
let verifier = ?;
}
//===----------------------------------------------------------------------===//
@ -111,6 +109,7 @@ def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> {
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -151,6 +150,7 @@ def Tensor_ExtractOp : Tensor_Op<"extract",
}]>];
let hasFolder = 1;
let hasVerifier = 1;
}
@ -303,6 +303,7 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -339,9 +340,6 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
let assemblyFormat = "$elements attr-dict `:` type($result)";
// This op is fully verified by its traits.
let verifier = ?;
let skipDefaultBuilders = 1;
let builders = [
OpBuilder<(ins "Type":$resultType, "ValueRange":$elements)>,
@ -394,6 +392,7 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
];
let hasCanonicalizer = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -445,6 +444,7 @@ def Tensor_InsertOp : Tensor_Op<"insert",
}]>];
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -564,6 +564,7 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -586,7 +587,6 @@ def Tensor_RankOp : Tensor_Op<"rank", [NoSideEffect]> {
let arguments = (ins AnyTensor:$tensor);
let results = (outs Index);
let verifier = ?;
let hasFolder = 1;
let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
}
@ -650,6 +650,7 @@ def Tensor_ReshapeOp: Tensor_Op<"reshape", [NoSideEffect]> {
let assemblyFormat = [{
$source `(` $shape `)` attr-dict `:` functional-type(operands, results)
}];
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -718,6 +719,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
let hasFolder = 1;
let hasCanonicalizer = 1;
let hasVerifier = 1;
let printer = [{ return ::print(p, *this); }];
let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
}
@ -748,6 +750,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
```
}];
let extraClassDeclaration = commonExtraClassDeclaration;
let hasVerifier = 1;
}
def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
@ -776,6 +779,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
```
}];
let extraClassDeclaration = commonExtraClassDeclaration;
let hasVerifier = 1;
}
//===----------------------------------------------------------------------===//
@ -961,6 +965,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
let hasCanonicalizer = 1;
let hasFolder = 1;
let hasVerifier = 1;
}
@ -984,7 +989,6 @@ def Tensor_YieldOp : Tensor_Op<"yield",
// Dummy builder to appease code in templated ensureTerminator that
// GenerateOp's auto-generated parser calls.
let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
let verifier = ?;
}
#endif // TENSOR_OPS

View File

@ -67,6 +67,10 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
return NoneType::get(type.getContext());
}
LogicalResult memref::CastOp::verify() {
return impl::verifyCastOp(*this, areCastCompatible);
}
//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//
@ -95,15 +99,15 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
return success();
}
static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
static LogicalResult verify(AllocaOp op) {
LogicalResult AllocaOp::verify() {
// An alloca op needs to have an ancestor with an allocation scope trait.
if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
return op.emitOpError(
if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
return emitOpError(
"requires an ancestor op with AutomaticAllocationScope trait");
return verifyAllocLikeOp(op);
return verifyAllocLikeOp(*this);
}
namespace {
@ -246,11 +250,8 @@ static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
return success();
}
static LogicalResult verify(AllocaScopeOp op) {
if (failed(RegionBranchOpInterface::verifyTypes(op)))
return failure();
return success();
LogicalResult AllocaScopeOp::verify() {
return RegionBranchOpInterface::verifyTypes(*this);
}
void AllocaScopeOp::getSuccessorRegions(
@ -268,10 +269,9 @@ void AllocaScopeOp::getSuccessorRegions(
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(AssumeAlignmentOp op) {
unsigned alignment = op.alignment();
if (!llvm::isPowerOf2_32(alignment))
return op.emitOpError("alignment must be power of 2");
LogicalResult AssumeAlignmentOp::verify() {
if (!llvm::isPowerOf2_32(alignment()))
return emitOpError("alignment must be power of 2");
return success();
}
@ -556,17 +556,17 @@ Optional<int64_t> DimOp::getConstantIndex() {
return {};
}
static LogicalResult verify(DimOp op) {
LogicalResult DimOp::verify() {
// Assume unknown index to be in range.
Optional<int64_t> index = op.getConstantIndex();
Optional<int64_t> index = getConstantIndex();
if (!index.hasValue())
return success();
// Check that constant index is not knowingly out of range.
auto type = op.source().getType();
auto type = source().getType();
if (auto memrefType = type.dyn_cast<MemRefType>()) {
if (index.getValue() >= memrefType.getRank())
return op.emitOpError("index is out of range");
return emitOpError("index is out of range");
} else if (type.isa<UnrankedMemRefType>()) {
// Assume index to be in range.
} else {
@ -866,67 +866,66 @@ static ParseResult parseDmaStartOp(OpAsmParser &parser,
return success();
}
static LogicalResult verify(DmaStartOp op) {
unsigned numOperands = op.getNumOperands();
LogicalResult DmaStartOp::verify() {
unsigned numOperands = getNumOperands();
// Mandatory non-variadic operands are: src memref, dst memref, tag memref and
// the number of elements.
if (numOperands < 4)
return op.emitOpError("expected at least 4 operands");
return emitOpError("expected at least 4 operands");
// Check types of operands. The order of these calls is important: the later
// calls rely on some type properties to compute the operand position.
// 1. Source memref.
if (!op.getSrcMemRef().getType().isa<MemRefType>())
return op.emitOpError("expected source to be of memref type");
if (numOperands < op.getSrcMemRefRank() + 4)
return op.emitOpError()
<< "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
if (!op.getSrcIndices().empty() &&
!llvm::all_of(op.getSrcIndices().getTypes(),
if (!getSrcMemRef().getType().isa<MemRefType>())
return emitOpError("expected source to be of memref type");
if (numOperands < getSrcMemRefRank() + 4)
return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
<< " operands";
if (!getSrcIndices().empty() &&
!llvm::all_of(getSrcIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
return op.emitOpError("expected source indices to be of index type");
return emitOpError("expected source indices to be of index type");
// 2. Destination memref.
if (!op.getDstMemRef().getType().isa<MemRefType>())
return op.emitOpError("expected destination to be of memref type");
unsigned numExpectedOperands =
op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
if (!getDstMemRef().getType().isa<MemRefType>())
return emitOpError("expected destination to be of memref type");
unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
if (numOperands < numExpectedOperands)
return op.emitOpError()
<< "expected at least " << numExpectedOperands << " operands";
if (!op.getDstIndices().empty() &&
!llvm::all_of(op.getDstIndices().getTypes(),
return emitOpError() << "expected at least " << numExpectedOperands
<< " operands";
if (!getDstIndices().empty() &&
!llvm::all_of(getDstIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
return op.emitOpError("expected destination indices to be of index type");
return emitOpError("expected destination indices to be of index type");
// 3. Number of elements.
if (!op.getNumElements().getType().isIndex())
return op.emitOpError("expected num elements to be of index type");
if (!getNumElements().getType().isIndex())
return emitOpError("expected num elements to be of index type");
// 4. Tag memref.
if (!op.getTagMemRef().getType().isa<MemRefType>())
return op.emitOpError("expected tag to be of memref type");
numExpectedOperands += op.getTagMemRefRank();
if (!getTagMemRef().getType().isa<MemRefType>())
return emitOpError("expected tag to be of memref type");
numExpectedOperands += getTagMemRefRank();
if (numOperands < numExpectedOperands)
return op.emitOpError()
<< "expected at least " << numExpectedOperands << " operands";
if (!op.getTagIndices().empty() &&
!llvm::all_of(op.getTagIndices().getTypes(),
return emitOpError() << "expected at least " << numExpectedOperands
<< " operands";
if (!getTagIndices().empty() &&
!llvm::all_of(getTagIndices().getTypes(),
[](Type t) { return t.isIndex(); }))
return op.emitOpError("expected tag indices to be of index type");
return emitOpError("expected tag indices to be of index type");
// Optional stride-related operands must be either both present or both
// absent.
if (numOperands != numExpectedOperands &&
numOperands != numExpectedOperands + 2)
return op.emitOpError("incorrect number of operands");
return emitOpError("incorrect number of operands");
// 5. Strides.
if (op.isStrided()) {
if (!op.getStride().getType().isIndex() ||
!op.getNumElementsPerStride().getType().isIndex())
return op.emitOpError(
if (isStrided()) {
if (!getStride().getType().isIndex() ||
!getNumElementsPerStride().getType().isIndex())
return emitOpError(
"expected stride and num elements per stride to be of type index");
}
@ -949,14 +948,14 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
return foldMemRefCast(*this);
}
static LogicalResult verify(DmaWaitOp op) {
LogicalResult DmaWaitOp::verify() {
// Check that the number of tag indices matches the tagMemRef rank.
unsigned numTagIndices = op.tagIndices().size();
unsigned tagMemRefRank = op.getTagMemRefRank();
unsigned numTagIndices = tagIndices().size();
unsigned tagMemRefRank = getTagMemRefRank();
if (numTagIndices != tagMemRefRank)
return op.emitOpError() << "expected tagIndices to have the same number of "
"elements as the tagMemRef rank, expected "
<< tagMemRefRank << ", but got " << numTagIndices;
return emitOpError() << "expected tagIndices to have the same number of "
"elements as the tagMemRef rank, expected "
<< tagMemRefRank << ", but got " << numTagIndices;
return success();
}
@ -979,14 +978,13 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
}
}
static LogicalResult verify(GenericAtomicRMWOp op) {
auto &body = op.getRegion();
LogicalResult GenericAtomicRMWOp::verify() {
auto &body = getRegion();
if (body.getNumArguments() != 1)
return op.emitOpError("expected single number of entry block arguments");
return emitOpError("expected single number of entry block arguments");
if (op.getResult().getType() != body.getArgument(0).getType())
return op.emitOpError(
"expected block argument of the same type result type");
if (getResult().getType() != body.getArgument(0).getType())
return emitOpError("expected block argument of the same type result type");
bool hasSideEffects =
body.walk([&](Operation *nestedOp) {
@ -1034,12 +1032,12 @@ static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
// AtomicYieldOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(AtomicYieldOp op) {
Type parentType = op->getParentOp()->getResultTypes().front();
Type resultType = op.result().getType();
LogicalResult AtomicYieldOp::verify() {
Type parentType = (*this)->getParentOp()->getResultTypes().front();
Type resultType = result().getType();
if (parentType != resultType)
return op.emitOpError() << "types mismatch between yield op: " << resultType
<< " and its parent: " << parentType;
return emitOpError() << "types mismatch between yield op: " << resultType
<< " and its parent: " << parentType;
return success();
}
@ -1090,19 +1088,19 @@ parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
return success();
}
static LogicalResult verify(GlobalOp op) {
auto memrefType = op.type().dyn_cast<MemRefType>();
LogicalResult GlobalOp::verify() {
auto memrefType = type().dyn_cast<MemRefType>();
if (!memrefType || !memrefType.hasStaticShape())
return op.emitOpError("type should be static shaped memref, but got ")
<< op.type();
return emitOpError("type should be static shaped memref, but got ")
<< type();
// Verify that the initial value, if present, is either a unit attribute or
// an elements attribute.
if (op.initial_value().hasValue()) {
Attribute initValue = op.initial_value().getValue();
if (initial_value().hasValue()) {
Attribute initValue = initial_value().getValue();
if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
return op.emitOpError("initial value should be a unit or elements "
"attribute, but got ")
return emitOpError("initial value should be a unit or elements "
"attribute, but got ")
<< initValue;
// Check that the type of the initial value is compatible with the type of
@ -1111,17 +1109,17 @@ static LogicalResult verify(GlobalOp op) {
Type initType = initValue.getType();
Type tensorType = getTensorTypeFromMemRefType(memrefType);
if (initType != tensorType)
return op.emitOpError("initial value expected to be of type ")
return emitOpError("initial value expected to be of type ")
<< tensorType << ", but was of type " << initType;
}
}
if (Optional<uint64_t> alignAttr = op.alignment()) {
if (Optional<uint64_t> alignAttr = alignment()) {
uint64_t alignment = alignAttr.getValue();
if (!llvm::isPowerOf2_64(alignment))
return op->emitError() << "alignment attribute value " << alignment
<< " is not a power of 2";
return emitError() << "alignment attribute value " << alignment
<< " is not a power of 2";
}
// TODO: verify visibility for declarations.
@ -1154,9 +1152,9 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// LoadOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(LoadOp op) {
if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
return op.emitOpError("incorrect number of indices for load");
LogicalResult LoadOp::verify() {
if (getNumOperands() != 1 + getMemRefType().getRank())
return emitOpError("incorrect number of indices for load");
return success();
}
@ -1224,9 +1222,9 @@ static ParseResult parsePrefetchOp(OpAsmParser &parser,
return success();
}
static LogicalResult verify(PrefetchOp op) {
if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
return op.emitOpError("too few indices");
LogicalResult PrefetchOp::verify() {
if (getNumOperands() != 1 + getMemRefType().getRank())
return emitOpError("too few indices");
return success();
}
@ -1306,26 +1304,25 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
// TODO: ponder whether we want to allow missing trailing sizes/strides that are
// completed automatically, like we have for subview and extract_slice.
static LogicalResult verify(ReinterpretCastOp op) {
LogicalResult ReinterpretCastOp::verify() {
// The source and result memrefs should be in the same memory space.
auto srcType = op.source().getType().cast<BaseMemRefType>();
auto resultType = op.getType().cast<MemRefType>();
auto srcType = source().getType().cast<BaseMemRefType>();
auto resultType = getType().cast<MemRefType>();
if (srcType.getMemorySpace() != resultType.getMemorySpace())
return op.emitError("different memory spaces specified for source type ")
return emitError("different memory spaces specified for source type ")
<< srcType << " and result memref type " << resultType;
if (srcType.getElementType() != resultType.getElementType())
return op.emitError("different element types specified for source type ")
return emitError("different element types specified for source type ")
<< srcType << " and result memref type " << resultType;
// Match sizes in result memref type and in static_sizes attribute.
for (auto &en :
llvm::enumerate(llvm::zip(resultType.getShape(),
extractFromI64ArrayAttr(op.static_sizes())))) {
for (auto &en : llvm::enumerate(llvm::zip(
resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) {
int64_t resultSize = std::get<0>(en.value());
int64_t expectedSize = std::get<1>(en.value());
if (!ShapedType::isDynamic(resultSize) &&
!ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
return op.emitError("expected result type with size = ")
return emitError("expected result type with size = ")
<< expectedSize << " instead of " << resultSize
<< " in dim = " << en.index();
}
@ -1336,27 +1333,26 @@ static LogicalResult verify(ReinterpretCastOp op) {
int64_t resultOffset;
SmallVector<int64_t, 4> resultStrides;
if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
return op.emitError(
"expected result type to have strided layout but found ")
return emitError("expected result type to have strided layout but found ")
<< resultType;
// Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front();
if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
!ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
resultOffset != expectedOffset)
return op.emitError("expected result type with offset = ")
return emitError("expected result type with offset = ")
<< resultOffset << " instead of " << expectedOffset;
// Match strides in result memref type and in static_strides attribute.
for (auto &en : llvm::enumerate(llvm::zip(
resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
resultStrides, extractFromI64ArrayAttr(static_strides())))) {
int64_t resultStride = std::get<0>(en.value());
int64_t expectedStride = std::get<1>(en.value());
if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
!ShapedType::isDynamicStrideOrOffset(expectedStride) &&
resultStride != expectedStride)
return op.emitError("expected result type with stride = ")
return emitError("expected result type with stride = ")
<< expectedStride << " instead of " << resultStride
<< " in dim = " << en.index();
}
@ -1532,8 +1528,8 @@ static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
return success();
}
static LogicalResult verify(ExpandShapeOp op) {
return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
LogicalResult ExpandShapeOp::verify() {
return verifyReshapeOp(*this, getResultType(), getSrcType());
}
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@ -1542,8 +1538,8 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
}
static LogicalResult verify(CollapseShapeOp op) {
return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
LogicalResult CollapseShapeOp::verify() {
return verifyReshapeOp(*this, getSrcType(), getResultType());
}
struct CollapseShapeOpMemRefCastFolder
@ -1593,32 +1589,30 @@ OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
// ReshapeOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(ReshapeOp op) {
Type operandType = op.source().getType();
Type resultType = op.result().getType();
LogicalResult ReshapeOp::verify() {
Type operandType = source().getType();
Type resultType = result().getType();
Type operandElementType = operandType.cast<ShapedType>().getElementType();
Type resultElementType = resultType.cast<ShapedType>().getElementType();
if (operandElementType != resultElementType)
return op.emitOpError("element types of source and destination memref "
"types should be the same");
return emitOpError("element types of source and destination memref "
"types should be the same");
if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
if (!operandMemRefType.getLayout().isIdentity())
return op.emitOpError(
"source memref type should have identity affine map");
return emitOpError("source memref type should have identity affine map");
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0);
auto resultMemRefType = resultType.dyn_cast<MemRefType>();
if (resultMemRefType) {
if (!resultMemRefType.getLayout().isIdentity())
return op.emitOpError(
"result memref type should have identity affine map");
return emitOpError("result memref type should have identity affine map");
if (shapeSize == ShapedType::kDynamicSize)
return op.emitOpError("cannot use shape operand with dynamic length to "
"reshape to statically-ranked memref type");
return emitOpError("cannot use shape operand with dynamic length to "
"reshape to statically-ranked memref type");
if (shapeSize != resultMemRefType.getRank())
return op.emitOpError(
return emitOpError(
"length of shape operand differs from the result's memref rank");
}
return success();
@ -1628,9 +1622,9 @@ static LogicalResult verify(ReshapeOp op) {
// StoreOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(StoreOp op) {
if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
return op.emitOpError("store index operand count not equal to memref rank");
LogicalResult StoreOp::verify() {
if (getNumOperands() != 2 + getMemRefType().getRank())
return emitOpError("store index operand count not equal to memref rank");
return success();
}
@ -1951,29 +1945,29 @@ static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
}
/// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) {
MemRefType baseType = op.getSourceType();
MemRefType subViewType = op.getType();
LogicalResult SubViewOp::verify() {
MemRefType baseType = getSourceType();
MemRefType subViewType = getType();
// The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != subViewType.getMemorySpace())
return op.emitError("different memory spaces specified for base memref "
"type ")
return emitError("different memory spaces specified for base memref "
"type ")
<< baseType << " and subview memref type " << subViewType;
// Verify that the base memref type has a strided layout map.
if (!isStrided(baseType))
return op.emitError("base type ") << baseType << " is not strided";
return emitError("base type ") << baseType << " is not strided";
// Verify result type against inferred type.
auto expectedType = SubViewOp::inferResultType(
baseType, extractFromI64ArrayAttr(op.static_offsets()),
extractFromI64ArrayAttr(op.static_sizes()),
extractFromI64ArrayAttr(op.static_strides()));
baseType, extractFromI64ArrayAttr(static_offsets()),
extractFromI64ArrayAttr(static_sizes()),
extractFromI64ArrayAttr(static_strides()));
auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
subViewType, op.getMixedSizes());
return produceSubViewErrorMsg(result, op, expectedType);
subViewType, getMixedSizes());
return produceSubViewErrorMsg(result, *this, expectedType);
}
raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
@ -2278,18 +2272,17 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
return success();
}
static LogicalResult verify(TransposeOp op) {
if (!op.permutation().isPermutation())
return op.emitOpError("expected a permutation map");
if (op.permutation().getNumDims() != op.getShapedType().getRank())
return op.emitOpError(
"expected a permutation map of same rank as the input");
LogicalResult TransposeOp::verify() {
if (!permutation().isPermutation())
return emitOpError("expected a permutation map");
if (permutation().getNumDims() != getShapedType().getRank())
return emitOpError("expected a permutation map of same rank as the input");
auto srcType = op.in().getType().cast<MemRefType>();
auto dstType = op.getType().cast<MemRefType>();
auto transposedType = inferTransposeResultType(srcType, op.permutation());
auto srcType = in().getType().cast<MemRefType>();
auto dstType = getType().cast<MemRefType>();
auto transposedType = inferTransposeResultType(srcType, permutation());
if (dstType != transposedType)
return op.emitOpError("output type ")
return emitOpError("output type ")
<< dstType << " does not match transposed input type " << srcType
<< ", " << transposedType;
return success();
@ -2338,29 +2331,28 @@ static void print(OpAsmPrinter &p, ViewOp op) {
p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}
static LogicalResult verify(ViewOp op) {
auto baseType = op.getOperand(0).getType().cast<MemRefType>();
auto viewType = op.getType();
LogicalResult ViewOp::verify() {
auto baseType = getOperand(0).getType().cast<MemRefType>();
auto viewType = getType();
// The base memref should have identity layout map (or none).
if (!baseType.getLayout().isIdentity())
return op.emitError("unsupported map for base memref type ") << baseType;
return emitError("unsupported map for base memref type ") << baseType;
// The result memref should have identity layout map (or none).
if (!viewType.getLayout().isIdentity())
return op.emitError("unsupported map for result memref type ") << viewType;
return emitError("unsupported map for result memref type ") << viewType;
// The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != viewType.getMemorySpace())
return op.emitError("different memory spaces specified for base memref "
"type ")
return emitError("different memory spaces specified for base memref "
"type ")
<< baseType << " and view memref type " << viewType;
// Verify that we have the correct number of sizes for the result type.
unsigned numDynamicDims = viewType.getNumDynamicDims();
if (op.sizes().size() != numDynamicDims)
return op.emitError("incorrect number of size operands for type ")
<< viewType;
if (sizes().size() != numDynamicDims)
return emitError("incorrect number of size operands for type ") << viewType;
return success();
}
@ -2467,19 +2459,19 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
// AtomicRMWOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(AtomicRMWOp op) {
if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
return op.emitOpError(
LogicalResult AtomicRMWOp::verify() {
if (getMemRefType().getRank() != getNumOperands() - 2)
return emitOpError(
"expects the number of subscripts to be equal to memref rank");
switch (op.kind()) {
switch (kind()) {
case arith::AtomicRMWKind::addf:
case arith::AtomicRMWKind::maxf:
case arith::AtomicRMWKind::minf:
case arith::AtomicRMWKind::mulf:
if (!op.value().getType().isa<FloatType>())
return op.emitOpError()
<< "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
<< "' expects a floating-point type";
if (!value().getType().isa<FloatType>())
return emitOpError() << "with kind '"
<< arith::stringifyAtomicRMWKind(kind())
<< "' expects a floating-point type";
break;
case arith::AtomicRMWKind::addi:
case arith::AtomicRMWKind::maxs:
@ -2489,10 +2481,10 @@ static LogicalResult verify(AtomicRMWOp op) {
case arith::AtomicRMWKind::muli:
case arith::AtomicRMWKind::ori:
case arith::AtomicRMWKind::andi:
if (!op.value().getType().isa<IntegerType>())
return op.emitOpError()
<< "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
<< "' expects an integer type";
if (!value().getType().isa<IntegerType>())
return emitOpError() << "with kind '"
<< arith::stringifyAtomicRMWKind(kind())
<< "' expects an integer type";
break;
default:
break;

View File

@ -209,53 +209,51 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) {
return failure();
}
static LogicalResult verify(NewOp op) {
if (!getSparseTensorEncoding(op.result().getType()))
return op.emitError("expected a sparse tensor result");
LogicalResult NewOp::verify() {
if (!getSparseTensorEncoding(result().getType()))
return emitError("expected a sparse tensor result");
return success();
}
static LogicalResult verify(InitOp op) {
if (!getSparseTensorEncoding(op.result().getType()))
return op.emitError("expected a sparse tensor result");
RankedTensorType ttp = op.getType().cast<RankedTensorType>();
LogicalResult InitOp::verify() {
if (!getSparseTensorEncoding(result().getType()))
return emitError("expected a sparse tensor result");
RankedTensorType ttp = getType().cast<RankedTensorType>();
unsigned rank = ttp.getRank();
if (rank != op.sizes().size())
return op.emitError("unexpected mismatch between tensor rank and sizes: ")
<< rank << " vs. " << op.sizes().size();
if (rank != sizes().size())
return emitError("unexpected mismatch between tensor rank and sizes: ")
<< rank << " vs. " << sizes().size();
auto shape = ttp.getShape();
for (unsigned i = 0; i < rank; i++) {
if (shape[i] == ShapedType::kDynamicSize)
continue;
IntegerAttr constantAttr;
if (!matchPattern(op.sizes()[i], m_Constant(&constantAttr)) ||
if (!matchPattern(sizes()[i], m_Constant(&constantAttr)) ||
constantAttr.getInt() != shape[i]) {
return op.emitError("unexpected mismatch with static dimension size ")
return emitError("unexpected mismatch with static dimension size ")
<< shape[i];
}
}
return success();
}
static LogicalResult verify(ConvertOp op) {
if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
LogicalResult ConvertOp::verify() {
if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {
if (tp1.getRank() != tp2.getRank())
return op.emitError("unexpected conversion mismatch in rank");
return emitError("unexpected conversion mismatch in rank");
auto shape1 = tp1.getShape();
auto shape2 = tp2.getShape();
// Accept size matches between the source and the destination type
// (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
// matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
return op.emitError("unexpected conversion mismatch in dimension ")
<< d;
}
return emitError("unexpected conversion mismatch in dimension ") << d;
return success();
}
}
return op.emitError("unexpected type in convert");
return emitError("unexpected type in convert");
}
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
@ -264,35 +262,35 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
return {};
}
static LogicalResult verify(ToPointersOp op) {
if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
if (failed(isInBounds(op.dim(), op.tensor())))
return op.emitError("requested pointers dimension out of bounds");
if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
return op.emitError("unexpected type for pointers");
LogicalResult ToPointersOp::verify() {
if (auto e = getSparseTensorEncoding(tensor().getType())) {
if (failed(isInBounds(dim(), tensor())))
return emitError("requested pointers dimension out of bounds");
if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
return emitError("unexpected type for pointers");
return success();
}
return op.emitError("expected a sparse tensor to get pointers");
return emitError("expected a sparse tensor to get pointers");
}
static LogicalResult verify(ToIndicesOp op) {
if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
if (failed(isInBounds(op.dim(), op.tensor())))
return op.emitError("requested indices dimension out of bounds");
if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
return op.emitError("unexpected type for indices");
LogicalResult ToIndicesOp::verify() {
if (auto e = getSparseTensorEncoding(tensor().getType())) {
if (failed(isInBounds(dim(), tensor())))
return emitError("requested indices dimension out of bounds");
if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
return emitError("unexpected type for indices");
return success();
}
return op.emitError("expected a sparse tensor to get indices");
return emitError("expected a sparse tensor to get indices");
}
static LogicalResult verify(ToValuesOp op) {
if (!getSparseTensorEncoding(op.tensor().getType()))
return op.emitError("expected a sparse tensor to get values");
RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
MemRefType mtp = op.result().getType().cast<MemRefType>();
LogicalResult ToValuesOp::verify() {
if (!getSparseTensorEncoding(tensor().getType()))
return emitError("expected a sparse tensor to get values");
RankedTensorType ttp = tensor().getType().cast<RankedTensorType>();
MemRefType mtp = result().getType().cast<MemRefType>();
if (ttp.getElementType() != mtp.getElementType())
return op.emitError("unexpected mismatch in element types");
return emitError("unexpected mismatch in element types");
return success();
}
@ -300,39 +298,39 @@ static LogicalResult verify(ToValuesOp op) {
// TensorDialect Management Operations.
//===----------------------------------------------------------------------===//
static LogicalResult verify(LexInsertOp op) {
if (!getSparseTensorEncoding(op.tensor().getType()))
return op.emitError("expected a sparse tensor for insertion");
LogicalResult LexInsertOp::verify() {
if (!getSparseTensorEncoding(tensor().getType()))
return emitError("expected a sparse tensor for insertion");
return success();
}
static LogicalResult verify(ExpandOp op) {
if (!getSparseTensorEncoding(op.tensor().getType()))
return op.emitError("expected a sparse tensor for expansion");
LogicalResult ExpandOp::verify() {
if (!getSparseTensorEncoding(tensor().getType()))
return emitError("expected a sparse tensor for expansion");
return success();
}
static LogicalResult verify(CompressOp op) {
if (!getSparseTensorEncoding(op.tensor().getType()))
return op.emitError("expected a sparse tensor for compression");
LogicalResult CompressOp::verify() {
if (!getSparseTensorEncoding(tensor().getType()))
return emitError("expected a sparse tensor for compression");
return success();
}
static LogicalResult verify(LoadOp op) {
if (!getSparseTensorEncoding(op.tensor().getType()))
return op.emitError("expected a sparse tensor to materialize");
LogicalResult LoadOp::verify() {
if (!getSparseTensorEncoding(tensor().getType()))
return emitError("expected a sparse tensor to materialize");
return success();
}
static LogicalResult verify(ReleaseOp op) {
if (!getSparseTensorEncoding(op.tensor().getType()))
return op.emitError("expected a sparse tensor to release");
LogicalResult ReleaseOp::verify() {
if (!getSparseTensorEncoding(tensor().getType()))
return emitError("expected a sparse tensor to release");
return success();
}
static LogicalResult verify(OutOp op) {
if (!getSparseTensorEncoding(op.tensor().getType()))
return op.emitError("expected a sparse tensor for output");
LogicalResult OutOp::verify() {
if (!getSparseTensorEncoding(tensor().getType()))
return emitError("expected a sparse tensor for output");
return success();
}

View File

@ -228,17 +228,17 @@ Optional<int64_t> DimOp::getConstantIndex() {
return {};
}
static LogicalResult verify(DimOp op) {
LogicalResult DimOp::verify() {
// Assume unknown index to be in range.
Optional<int64_t> index = op.getConstantIndex();
Optional<int64_t> index = getConstantIndex();
if (!index.hasValue())
return success();
// Check that constant index is not knowingly out of range.
auto type = op.source().getType();
auto type = source().getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
if (index.getValue() >= tensorType.getRank())
return op.emitOpError("index is out of range");
return emitOpError("index is out of range");
} else if (type.isa<UnrankedTensorType>()) {
// Assume index to be in range.
} else {
@ -328,11 +328,11 @@ void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
// ExtractOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(ExtractOp op) {
LogicalResult ExtractOp::verify() {
// Verify the # indices match if we have a ranked type.
if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
return op.emitOpError("incorrect number of indices for extract_element");
if (auto tensorType = tensor().getType().dyn_cast<RankedTensorType>())
if (tensorType.getRank() != static_cast<int64_t>(indices().size()))
return emitOpError("incorrect number of indices for extract_element");
return success();
}
@ -480,11 +480,11 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
// InsertOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(InsertOp op) {
LogicalResult InsertOp::verify() {
// Verify the # indices match if we have a ranked type.
if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
return op.emitOpError("incorrect number of indices");
if (auto destType = dest().getType().dyn_cast<RankedTensorType>())
if (destType.getRank() != static_cast<int64_t>(indices().size()))
return emitOpError("incorrect number of indices");
return success();
}
@ -502,27 +502,26 @@ OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
// GenerateOp
//===----------------------------------------------------------------------===//
static LogicalResult verify(GenerateOp op) {
LogicalResult GenerateOp::verify() {
// Ensure that the tensor type has as many dynamic dimensions as are specified
// by the operands.
RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
if (op.getNumOperands() != resultTy.getNumDynamicDims())
return op.emitError("must have as many index operands as dynamic extents "
"in the result type");
RankedTensorType resultTy = getType().cast<RankedTensorType>();
if (getNumOperands() != resultTy.getNumDynamicDims())
return emitError("must have as many index operands as dynamic extents "
"in the result type");
// Ensure that region arguments span the index space.
if (!llvm::all_of(op.body().getArgumentTypes(),
if (!llvm::all_of(body().getArgumentTypes(),
[](Type ty) { return ty.isIndex(); }))
return op.emitError("all body arguments must be index");
if (op.body().getNumArguments() != resultTy.getRank())
return op.emitError("must have one body argument per input dimension");
return emitError("all body arguments must be index");
if (body().getNumArguments() != resultTy.getRank())
return emitError("must have one body argument per input dimension");
// Ensure that the region yields an element of the right type.
auto yieldOp =
llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
auto yieldOp = cast<YieldOp>(body().getBlocks().front().getTerminator());
if (yieldOp.value().getType() != resultTy.getElementType())
return op.emitOpError(
return emitOpError(
"body must be terminated with a `yield` operation of the tensor "
"element type");
@ -686,16 +685,15 @@ static int64_t getNumElements(ShapedType type) {
return numElements;
}
static LogicalResult verify(ReshapeOp op) {
TensorType operandType = op.source().getType().cast<TensorType>();
TensorType resultType = op.result().getType().cast<TensorType>();
LogicalResult ReshapeOp::verify() {
TensorType operandType = source().getType().cast<TensorType>();
TensorType resultType = result().getType().cast<TensorType>();
if (operandType.getElementType() != resultType.getElementType())
return op.emitOpError("element types of source and destination tensor "
"types should be the same");
return emitOpError("element types of source and destination tensor "
"types should be the same");
int64_t shapeSize =
op.shape().getType().cast<RankedTensorType>().getDimSize(0);
int64_t shapeSize = shape().getType().cast<RankedTensorType>().getDimSize(0);
auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
@ -703,14 +701,14 @@ static LogicalResult verify(ReshapeOp op) {
if (operandRankedType && resultRankedType.hasStaticShape() &&
operandRankedType.hasStaticShape()) {
if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
return op.emitOpError("source and destination tensor should have the "
"same number of elements");
return emitOpError("source and destination tensor should have the "
"same number of elements");
}
if (ShapedType::isDynamic(shapeSize))
return op.emitOpError("cannot use shape operand with dynamic length to "
"reshape to statically-ranked tensor type");
return emitOpError("cannot use shape operand with dynamic length to "
"reshape to statically-ranked tensor type");
if (shapeSize != resultRankedType.getRank())
return op.emitOpError(
return emitOpError(
"length of shape operand differs from the result's tensor rank");
}
return success();
@ -814,12 +812,12 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
return success();
}
static LogicalResult verify(ExpandShapeOp op) {
return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType());
LogicalResult ExpandShapeOp::verify() {
return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
}
static LogicalResult verify(CollapseShapeOp op) {
return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType());
LogicalResult CollapseShapeOp::verify() {
return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
}
namespace {
@ -1052,14 +1050,12 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
}
/// Verifier for ExtractSliceOp.
static LogicalResult verify(ExtractSliceOp op) {
LogicalResult ExtractSliceOp::verify() {
// Verify result type against inferred type.
auto expectedType =
ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(),
op.getMixedSizes(), op.getMixedStrides());
auto result =
isRankReducedType(expectedType.cast<ShapedType>(), op.getType());
return produceSliceErrorMsg(result, op, expectedType);
auto expectedType = ExtractSliceOp::inferResultType(
getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
auto result = isRankReducedType(expectedType.cast<ShapedType>(), getType());
return produceSliceErrorMsg(result, *this, expectedType);
}
/// Infer the canonical type of the result of an extract_slice op. Returns a
@ -1308,16 +1304,16 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
}
/// Verifier for InsertSliceOp.
static LogicalResult verify(InsertSliceOp op) {
LogicalResult InsertSliceOp::verify() {
// insert_slice is the inverse of extract_slice, use the same type inference.
auto expectedType = ExtractSliceOp::inferRankReducedResultType(
op.getSourceType().getRank(), op.getType(),
extractFromI64ArrayAttr(op.static_offsets()),
extractFromI64ArrayAttr(op.static_sizes()),
extractFromI64ArrayAttr(op.static_strides()));
getSourceType().getRank(), getType(),
extractFromI64ArrayAttr(static_offsets()),
extractFromI64ArrayAttr(static_sizes()),
extractFromI64ArrayAttr(static_strides()));
auto result =
isRankReducedType(expectedType.cast<ShapedType>(), op.getSourceType());
return produceSliceErrorMsg(result, op, expectedType);
isRankReducedType(expectedType.cast<ShapedType>(), getSourceType());
return produceSliceErrorMsg(result, *this, expectedType);
}
/// If we have two consecutive InsertSliceOp writing to the same slice, we
@ -1569,40 +1565,40 @@ ParseResult parseInferType(OpAsmParser &parser,
return success();
}
static LogicalResult verify(PadOp op) {
auto sourceType = op.source().getType().cast<RankedTensorType>();
auto resultType = op.result().getType().cast<RankedTensorType>();
auto expectedType = PadOp::inferResultType(
sourceType, extractFromI64ArrayAttr(op.static_low()),
extractFromI64ArrayAttr(op.static_high()));
LogicalResult PadOp::verify() {
auto sourceType = source().getType().cast<RankedTensorType>();
auto resultType = result().getType().cast<RankedTensorType>();
auto expectedType =
PadOp::inferResultType(sourceType, extractFromI64ArrayAttr(static_low()),
extractFromI64ArrayAttr(static_high()));
for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
if (resultType.getDimSize(i) == expectedType.getDimSize(i))
continue;
if (expectedType.isDynamicDim(i))
continue;
return op.emitError("specified type ")
return emitError("specified type ")
<< resultType << " does not match the inferred type "
<< expectedType;
}
auto &region = op.region();
auto &region = getRegion();
unsigned rank = resultType.getRank();
Block &block = region.front();
if (block.getNumArguments() != rank)
return op.emitError("expected the block to have ") << rank << " arguments";
return emitError("expected the block to have ") << rank << " arguments";
// Note: the number and type of yield values are checked in the YieldOp.
for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
if (!en.value().isIndex())
return op.emitOpError("expected block argument ")
return emitOpError("expected block argument ")
<< (en.index() + 1) << " to be an index";
}
// Ensure that the region yields an element of the right type.
auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
if (yieldOp.value().getType() !=
op.getType().cast<ShapedType>().getElementType())
return op.emitOpError("expected yield type to match shape element type");
getType().cast<ShapedType>().getElementType())
return emitOpError("expected yield type to match shape element type");
return success();
}