[mlir][NFC] Update MemRef/Tensor operations to use `hasVerifier` instead of `verifier`

The `verifier` field is deprecated and slated for removal.

Differential Revision: https://reviews.llvm.org/D118821
River Riddle 2022-02-02 10:18:06 -08:00
parent bdc7ce975a
commit b98dc0351a
6 changed files with 325 additions and 314 deletions
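
Every hunk in this patch applies the same mechanical transformation, so it is worth stating once up front: setting `let hasVerifier = 1;` in ODS makes TableGen declare a `LogicalResult verify();` member on the generated op class, replacing the deprecated inline `verifier` code block that dispatched to a free `static LogicalResult verify(OpTy)` function. A minimal sketch of the before/after shape, using a hypothetical `MyOp` (not an op from this patch) and assuming the usual MLIR includes and a TableGen-generated op class:

```
// Before: the .td definition carried an inline snippet,
//   let verifier = [{ return ::verify(*this); }];
// which dispatched to a free function taking the op by value.
static LogicalResult verify(MyOp op) {
  if (op.getNumOperands() != 2)
    return op.emitOpError("expected exactly 2 operands");
  return success();
}

// After: the .td definition states `let hasVerifier = 1;` and the check
// becomes a member function; accessor and diagnostic calls drop the `op.`
// prefix because they now resolve against `this`.
LogicalResult MyOp::verify() {
  if (getNumOperands() != 2)
    return emitOpError("expected exactly 2 operands");
  return success();
}
```

Ops that are fully verified by their traits simply drop the old `let verifier = ?;` opt-out, since with `hasVerifier` the default is to generate no custom verifier hook at all.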

[File: MemRef dialect op definitions (TableGen)]

@@ -28,7 +28,6 @@ def MemRefTypeAttr
 class MemRef_Op<string mnemonic, list<Trait> traits = []>
     : Op<MemRef_Dialect, mnemonic, traits> {
   let printer = [{ return ::print(p, *this); }];
-  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
@@ -93,6 +92,7 @@ class AllocLikeOp<string mnemonic,
   }];
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -115,6 +115,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
   let results = (outs);
   let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -162,6 +163,7 @@ def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, []> {
     memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
     ```
   }];
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -205,6 +207,7 @@ def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> {
     an alignment on any convenient boundary compatible with the type will be
     chosen.
   }];
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -253,6 +256,7 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$bodyRegion);
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -279,11 +283,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
   let arguments = (ins Variadic<AnyType>:$results);
   let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>];
-  let assemblyFormat =
-      [{ attr-dict ($results^ `:` type($results))? }];
-
-  // No custom verification needed.
-  let verifier = ?;
+  let assemblyFormat = "attr-dict ($results^ `:` type($results))?";
 }
 //===----------------------------------------------------------------------===//
@@ -355,7 +355,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
   let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
   let results = (outs AnyRankedOrUnrankedMemRef:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
-  let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";
   let builders = [
     OpBuilder<(ins "Value":$source, "Type":$destType), [{
       impl::buildCastOp($_builder, $_state, source, destType);
@@ -370,6 +369,7 @@ def MemRef_CastOp : MemRef_Op<"cast", [
   }];
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -408,7 +408,6 @@ def CopyOp : MemRef_Op<"copy",
   let hasCanonicalizer = 1;
   let hasFolder = 1;
-  let verifier = ?;
 }
 //===----------------------------------------------------------------------===//
@@ -434,7 +433,6 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
   let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "", [MemFree]>:$memref);
   let hasFolder = 1;
-  let verifier = ?;
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
 }
@@ -488,6 +486,7 @@ def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> {
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -646,6 +645,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
     }
   }];
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -697,6 +697,7 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
     Value getNumElements() { return numElements(); }
   }];
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -757,6 +758,7 @@ def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [
       return memref().getType().cast<MemRefType>();
     }
   }];
+  let hasVerifier = 1;
 }
 def AtomicYieldOp : MemRef_Op<"atomic_yield", [
@@ -772,6 +774,7 @@ def AtomicYieldOp : MemRef_Op<"atomic_yield", [
   let arguments = (ins AnyType:$result);
   let assemblyFormat = "$result attr-dict `:` type($result)";
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -797,9 +800,6 @@ def MemRef_GetGlobalOp : MemRef_Op<"get_global",
   let arguments = (ins FlatSymbolRefAttr:$name);
   let results = (outs AnyStaticShapeMemRef:$result);
   let assemblyFormat = "$name `:` type($result) attr-dict";
-
-  // `GetGlobalOp` is fully verified by its traits.
-  let verifier = ?;
 }
 //===----------------------------------------------------------------------===//
@@ -866,6 +866,7 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
       return !isExternal() && initial_value().getValue().isa<UnitAttr>();
     }
   }];
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -939,6 +940,7 @@ def LoadOp : MemRef_Op<"load",
   }];
   let hasFolder = 1;
+  let hasVerifier = 1;
   let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
 }
@@ -982,6 +984,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
   }];
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -1034,6 +1037,7 @@ def MemRef_ReinterpretCastOp:
   let parser = ?;
   let printer = ?;
+  let hasVerifier = 1;
   let builders = [
     // Build a ReinterpretCastOp with mixed static and dynamic entries.
@@ -1096,7 +1100,6 @@ def MemRef_RankOp : MemRef_Op<"rank", [NoSideEffect]> {
   let arguments = (ins AnyRankedOrUnrankedMemRef:$memref);
   let results = (outs Index);
-  let verifier = ?;
   let hasFolder = 1;
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
 }
@@ -1161,6 +1164,7 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
   let assemblyFormat = [{
     $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
   }];
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -1226,6 +1230,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
 }
@@ -1265,6 +1270,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
     ```
   }];
   let extraClassDeclaration = commonExtraClassDeclaration;
+  let hasVerifier = 1;
 }
 def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
@@ -1302,6 +1308,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
     ```
   }];
   let extraClassDeclaration = commonExtraClassDeclaration;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -1369,6 +1376,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
   }];
   let hasFolder = 1;
+  let hasVerifier = 1;
   let assemblyFormat = [{
     $value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
@@ -1617,6 +1625,7 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -1645,8 +1654,6 @@ def TensorStoreOp : MemRef_Op<"tensor_store",
   let arguments = (ins AnyTensor:$tensor, Arg<AnyRankedOrUnrankedMemRef,
                        "the reference to store to", [MemWrite]>:$memref);
-
-  // TensorStoreOp is fully verified by traits.
-  let verifier = ?;
   let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
 }
@@ -1681,6 +1688,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [NoSideEffect]>,
   }];
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -1749,6 +1757,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
   }];
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -1796,6 +1805,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
     }
   }];
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 #endif // MEMREF_OPS

[File: SparseTensor dialect op definitions (TableGen)]

@@ -21,7 +21,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 class SparseTensor_Op<string mnemonic, list<Trait> traits = []>
     : Op<SparseTensor_Dialect, mnemonic, traits> {
   let printer = [{ return ::print(p, *this); }];
-  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
@@ -50,6 +49,7 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
     ```
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+  let hasVerifier = 1;
 }
 def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>,
@@ -72,6 +72,7 @@ def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>,
     ```
   }];
   let assemblyFormat = "`[` $sizes `]` attr-dict `:` type($result)";
+  let hasVerifier = 1;
 }
 def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
@@ -113,6 +114,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
@@ -137,6 +139,7 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
   }];
   let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)"
                        " `to` type($result)";
+  let hasVerifier = 1;
 }
 def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
@@ -161,6 +164,7 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
   }];
   let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)"
                        " `to` type($result)";
+  let hasVerifier = 1;
 }
 def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
@@ -183,6 +187,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
     ```
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -217,6 +222,7 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
   }];
   let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`"
                        " type($tensor) `,` type($indices) `,` type($value)";
+  let hasVerifier = 1;
 }
 def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
@@ -258,6 +264,7 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)"
                        " `,` type($filled) `,` type($added) `,` type($count)";
+  let hasVerifier = 1;
 }
 def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
@@ -292,6 +299,7 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
                        " $added `,` $count attr-dict `:` type($tensor) `,`"
                        " type($indices) `,` type($values) `,` type($filled) `,`"
                        " type($added) `,` type($count)";
+  let hasVerifier = 1;
 }
 def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
@@ -324,6 +332,7 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
     ```
   }];
   let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
+  let hasVerifier = 1;
 }
 def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
@@ -349,6 +358,7 @@ def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
     ```
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
+  let hasVerifier = 1;
 }
 def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
@@ -369,6 +379,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
     ```
   }];
   let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
+  let hasVerifier = 1;
 }
 #endif // SPARSETENSOR_OPS

[File: Tensor dialect op definitions (TableGen)]

@@ -20,7 +20,6 @@ include "mlir/Interfaces/ViewLikeInterface.td"
 class Tensor_Op<string mnemonic, list<Trait> traits = []>
     : Op<Tensor_Dialect, mnemonic, traits> {
   let printer = [{ return ::print(p, *this); }];
-  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
@@ -59,7 +58,6 @@ def Tensor_CastOp : Tensor_Op<"cast", [
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
   let hasCanonicalizer = 1;
-  let verifier = ?;
 }
 //===----------------------------------------------------------------------===//
@@ -111,6 +109,7 @@ def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> {
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -151,6 +150,7 @@ def Tensor_ExtractOp : Tensor_Op<"extract",
   }]>];
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
@@ -303,6 +303,7 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -339,9 +340,6 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
   let assemblyFormat = "$elements attr-dict `:` type($result)";
-
-  // This op is fully verified by its traits.
-  let verifier = ?;
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "Type":$resultType, "ValueRange":$elements)>,
@@ -394,6 +392,7 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
   ];
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -445,6 +444,7 @@ def Tensor_InsertOp : Tensor_Op<"insert",
   }]>];
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -564,6 +564,7 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -586,7 +587,6 @@ def Tensor_RankOp : Tensor_Op<"rank", [NoSideEffect]> {
   let arguments = (ins AnyTensor:$tensor);
   let results = (outs Index);
-  let verifier = ?;
   let hasFolder = 1;
   let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
 }
@@ -650,6 +650,7 @@ def Tensor_ReshapeOp: Tensor_Op<"reshape", [NoSideEffect]> {
   let assemblyFormat = [{
     $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
   }];
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -718,6 +719,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
 }
@@ -748,6 +750,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
     ```
   }];
   let extraClassDeclaration = commonExtraClassDeclaration;
+  let hasVerifier = 1;
 }
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
@@ -776,6 +779,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
     ```
   }];
   let extraClassDeclaration = commonExtraClassDeclaration;
+  let hasVerifier = 1;
 }
 //===----------------------------------------------------------------------===//
@@ -961,6 +965,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
@@ -984,7 +989,6 @@ def Tensor_YieldOp : Tensor_Op<"yield",
   // Dummy builder to appease code in templated ensureTerminator that
   // GenerateOp's auto-generated parser calls.
   let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
-  let verifier = ?;
 }
 #endif // TENSOR_OPS

[File: MemRef dialect op implementations (C++)]

@ -67,6 +67,10 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
return NoneType::get(type.getContext()); return NoneType::get(type.getContext());
} }
LogicalResult memref::CastOp::verify() {
return impl::verifyCastOp(*this, areCastCompatible);
}
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
// AllocOp / AllocaOp // AllocOp / AllocaOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
@ -95,15 +99,15 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
return success(); return success();
} }
static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); } LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
static LogicalResult verify(AllocaOp op) { LogicalResult AllocaOp::verify() {
// An alloca op needs to have an ancestor with an allocation scope trait. // An alloca op needs to have an ancestor with an allocation scope trait.
if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>()) if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
return op.emitOpError( return emitOpError(
"requires an ancestor op with AutomaticAllocationScope trait"); "requires an ancestor op with AutomaticAllocationScope trait");
return verifyAllocLikeOp(op); return verifyAllocLikeOp(*this);
} }
namespace { namespace {
@ -246,11 +250,8 @@ static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
return success(); return success();
} }
static LogicalResult verify(AllocaScopeOp op) { LogicalResult AllocaScopeOp::verify() {
if (failed(RegionBranchOpInterface::verifyTypes(op))) return RegionBranchOpInterface::verifyTypes(*this);
return failure();
return success();
} }
void AllocaScopeOp::getSuccessorRegions( void AllocaScopeOp::getSuccessorRegions(
@ -268,10 +269,9 @@ void AllocaScopeOp::getSuccessorRegions(
// AssumeAlignmentOp // AssumeAlignmentOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verify(AssumeAlignmentOp op) { LogicalResult AssumeAlignmentOp::verify() {
unsigned alignment = op.alignment(); if (!llvm::isPowerOf2_32(alignment()))
if (!llvm::isPowerOf2_32(alignment)) return emitOpError("alignment must be power of 2");
return op.emitOpError("alignment must be power of 2");
return success(); return success();
} }
@ -556,17 +556,17 @@ Optional<int64_t> DimOp::getConstantIndex() {
return {}; return {};
} }
static LogicalResult verify(DimOp op) { LogicalResult DimOp::verify() {
// Assume unknown index to be in range. // Assume unknown index to be in range.
Optional<int64_t> index = op.getConstantIndex(); Optional<int64_t> index = getConstantIndex();
if (!index.hasValue()) if (!index.hasValue())
return success(); return success();
// Check that constant index is not knowingly out of range. // Check that constant index is not knowingly out of range.
auto type = op.source().getType(); auto type = source().getType();
if (auto memrefType = type.dyn_cast<MemRefType>()) { if (auto memrefType = type.dyn_cast<MemRefType>()) {
if (index.getValue() >= memrefType.getRank()) if (index.getValue() >= memrefType.getRank())
return op.emitOpError("index is out of range"); return emitOpError("index is out of range");
} else if (type.isa<UnrankedMemRefType>()) { } else if (type.isa<UnrankedMemRefType>()) {
// Assume index to be in range. // Assume index to be in range.
} else { } else {
@ -866,67 +866,66 @@ static ParseResult parseDmaStartOp(OpAsmParser &parser,
return success(); return success();
} }
static LogicalResult verify(DmaStartOp op) { LogicalResult DmaStartOp::verify() {
unsigned numOperands = op.getNumOperands(); unsigned numOperands = getNumOperands();
// Mandatory non-variadic operands are: src memref, dst memref, tag memref and // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
// the number of elements. // the number of elements.
if (numOperands < 4) if (numOperands < 4)
return op.emitOpError("expected at least 4 operands"); return emitOpError("expected at least 4 operands");
// Check types of operands. The order of these calls is important: the later // Check types of operands. The order of these calls is important: the later
// calls rely on some type properties to compute the operand position. // calls rely on some type properties to compute the operand position.
// 1. Source memref. // 1. Source memref.
if (!op.getSrcMemRef().getType().isa<MemRefType>()) if (!getSrcMemRef().getType().isa<MemRefType>())
return op.emitOpError("expected source to be of memref type"); return emitOpError("expected source to be of memref type");
if (numOperands < op.getSrcMemRefRank() + 4) if (numOperands < getSrcMemRefRank() + 4)
return op.emitOpError() return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
<< "expected at least " << op.getSrcMemRefRank() + 4 << " operands"; << " operands";
if (!op.getSrcIndices().empty() && if (!getSrcIndices().empty() &&
!llvm::all_of(op.getSrcIndices().getTypes(), !llvm::all_of(getSrcIndices().getTypes(),
[](Type t) { return t.isIndex(); })) [](Type t) { return t.isIndex(); }))
return op.emitOpError("expected source indices to be of index type"); return emitOpError("expected source indices to be of index type");
// 2. Destination memref. // 2. Destination memref.
if (!op.getDstMemRef().getType().isa<MemRefType>()) if (!getDstMemRef().getType().isa<MemRefType>())
return op.emitOpError("expected destination to be of memref type"); return emitOpError("expected destination to be of memref type");
unsigned numExpectedOperands = unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
if (numOperands < numExpectedOperands) if (numOperands < numExpectedOperands)
return op.emitOpError() return emitOpError() << "expected at least " << numExpectedOperands
<< "expected at least " << numExpectedOperands << " operands"; << " operands";
if (!op.getDstIndices().empty() && if (!getDstIndices().empty() &&
!llvm::all_of(op.getDstIndices().getTypes(), !llvm::all_of(getDstIndices().getTypes(),
[](Type t) { return t.isIndex(); })) [](Type t) { return t.isIndex(); }))
return op.emitOpError("expected destination indices to be of index type"); return emitOpError("expected destination indices to be of index type");
// 3. Number of elements. // 3. Number of elements.
if (!op.getNumElements().getType().isIndex()) if (!getNumElements().getType().isIndex())
return op.emitOpError("expected num elements to be of index type"); return emitOpError("expected num elements to be of index type");
// 4. Tag memref. // 4. Tag memref.
if (!op.getTagMemRef().getType().isa<MemRefType>()) if (!getTagMemRef().getType().isa<MemRefType>())
return op.emitOpError("expected tag to be of memref type"); return emitOpError("expected tag to be of memref type");
numExpectedOperands += op.getTagMemRefRank(); numExpectedOperands += getTagMemRefRank();
if (numOperands < numExpectedOperands) if (numOperands < numExpectedOperands)
return op.emitOpError() return emitOpError() << "expected at least " << numExpectedOperands
<< "expected at least " << numExpectedOperands << " operands"; << " operands";
if (!op.getTagIndices().empty() && if (!getTagIndices().empty() &&
!llvm::all_of(op.getTagIndices().getTypes(), !llvm::all_of(getTagIndices().getTypes(),
[](Type t) { return t.isIndex(); })) [](Type t) { return t.isIndex(); }))
return op.emitOpError("expected tag indices to be of index type"); return emitOpError("expected tag indices to be of index type");
// Optional stride-related operands must be either both present or both // Optional stride-related operands must be either both present or both
// absent. // absent.
if (numOperands != numExpectedOperands && if (numOperands != numExpectedOperands &&
numOperands != numExpectedOperands + 2) numOperands != numExpectedOperands + 2)
return op.emitOpError("incorrect number of operands"); return emitOpError("incorrect number of operands");
// 5. Strides. // 5. Strides.
if (op.isStrided()) { if (isStrided()) {
if (!op.getStride().getType().isIndex() || if (!getStride().getType().isIndex() ||
!op.getNumElementsPerStride().getType().isIndex()) !getNumElementsPerStride().getType().isIndex())
return op.emitOpError( return emitOpError(
"expected stride and num elements per stride to be of type index"); "expected stride and num elements per stride to be of type index");
} }
@ -949,14 +948,14 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
return foldMemRefCast(*this); return foldMemRefCast(*this);
} }
static LogicalResult verify(DmaWaitOp op) { LogicalResult DmaWaitOp::verify() {
// Check that the number of tag indices matches the tagMemRef rank. // Check that the number of tag indices matches the tagMemRef rank.
unsigned numTagIndices = op.tagIndices().size(); unsigned numTagIndices = tagIndices().size();
unsigned tagMemRefRank = op.getTagMemRefRank(); unsigned tagMemRefRank = getTagMemRefRank();
if (numTagIndices != tagMemRefRank) if (numTagIndices != tagMemRefRank)
return op.emitOpError() << "expected tagIndices to have the same number of " return emitOpError() << "expected tagIndices to have the same number of "
"elements as the tagMemRef rank, expected " "elements as the tagMemRef rank, expected "
<< tagMemRefRank << ", but got " << numTagIndices; << tagMemRefRank << ", but got " << numTagIndices;
return success(); return success();
} }
@ -979,14 +978,13 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
} }
} }
static LogicalResult verify(GenericAtomicRMWOp op) { LogicalResult GenericAtomicRMWOp::verify() {
auto &body = op.getRegion(); auto &body = getRegion();
if (body.getNumArguments() != 1) if (body.getNumArguments() != 1)
return op.emitOpError("expected single number of entry block arguments"); return emitOpError("expected single number of entry block arguments");
if (op.getResult().getType() != body.getArgument(0).getType()) if (getResult().getType() != body.getArgument(0).getType())
return op.emitOpError( return emitOpError("expected block argument of the same type result type");
"expected block argument of the same type result type");
bool hasSideEffects = bool hasSideEffects =
body.walk([&](Operation *nestedOp) { body.walk([&](Operation *nestedOp) {
@ -1034,12 +1032,12 @@ static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
// AtomicYieldOp // AtomicYieldOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verify(AtomicYieldOp op) { LogicalResult AtomicYieldOp::verify() {
Type parentType = op->getParentOp()->getResultTypes().front(); Type parentType = (*this)->getParentOp()->getResultTypes().front();
Type resultType = op.result().getType(); Type resultType = result().getType();
if (parentType != resultType) if (parentType != resultType)
return op.emitOpError() << "types mismatch between yield op: " << resultType return emitOpError() << "types mismatch between yield op: " << resultType
<< " and its parent: " << parentType; << " and its parent: " << parentType;
return success(); return success();
} }
@ -1090,19 +1088,19 @@ parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
return success(); return success();
} }
static LogicalResult verify(GlobalOp op) { LogicalResult GlobalOp::verify() {
auto memrefType = op.type().dyn_cast<MemRefType>(); auto memrefType = type().dyn_cast<MemRefType>();
if (!memrefType || !memrefType.hasStaticShape()) if (!memrefType || !memrefType.hasStaticShape())
return op.emitOpError("type should be static shaped memref, but got ") return emitOpError("type should be static shaped memref, but got ")
<< op.type(); << type();
// Verify that the initial value, if present, is either a unit attribute or // Verify that the initial value, if present, is either a unit attribute or
// an elements attribute. // an elements attribute.
if (op.initial_value().hasValue()) { if (initial_value().hasValue()) {
Attribute initValue = op.initial_value().getValue(); Attribute initValue = initial_value().getValue();
if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>()) if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
return op.emitOpError("initial value should be a unit or elements " return emitOpError("initial value should be a unit or elements "
"attribute, but got ") "attribute, but got ")
<< initValue; << initValue;
// Check that the type of the initial value is compatible with the type of // Check that the type of the initial value is compatible with the type of
@ -1111,17 +1109,17 @@ static LogicalResult verify(GlobalOp op) {
Type initType = initValue.getType(); Type initType = initValue.getType();
Type tensorType = getTensorTypeFromMemRefType(memrefType); Type tensorType = getTensorTypeFromMemRefType(memrefType);
if (initType != tensorType) if (initType != tensorType)
return op.emitOpError("initial value expected to be of type ") return emitOpError("initial value expected to be of type ")
<< tensorType << ", but was of type " << initType; << tensorType << ", but was of type " << initType;
} }
} }
if (Optional<uint64_t> alignAttr = op.alignment()) { if (Optional<uint64_t> alignAttr = alignment()) {
uint64_t alignment = alignAttr.getValue(); uint64_t alignment = alignAttr.getValue();
if (!llvm::isPowerOf2_64(alignment)) if (!llvm::isPowerOf2_64(alignment))
return op->emitError() << "alignment attribute value " << alignment return emitError() << "alignment attribute value " << alignment
<< " is not a power of 2"; << " is not a power of 2";
} }
// TODO: verify visibility for declarations. // TODO: verify visibility for declarations.
@ -1154,9 +1152,9 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
// LoadOp // LoadOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verify(LoadOp op) { LogicalResult LoadOp::verify() {
if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) if (getNumOperands() != 1 + getMemRefType().getRank())
return op.emitOpError("incorrect number of indices for load"); return emitOpError("incorrect number of indices for load");
return success(); return success();
} }
@ -1224,9 +1222,9 @@ static ParseResult parsePrefetchOp(OpAsmParser &parser,
return success(); return success();
} }
static LogicalResult verify(PrefetchOp op) { LogicalResult PrefetchOp::verify() {
if (op.getNumOperands() != 1 + op.getMemRefType().getRank()) if (getNumOperands() != 1 + getMemRefType().getRank())
return op.emitOpError("too few indices"); return emitOpError("too few indices");
return success(); return success();
} }
@ -1306,26 +1304,25 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
// TODO: ponder whether we want to allow missing trailing sizes/strides that are // TODO: ponder whether we want to allow missing trailing sizes/strides that are
// completed automatically, like we have for subview and extract_slice. // completed automatically, like we have for subview and extract_slice.
static LogicalResult verify(ReinterpretCastOp op) { LogicalResult ReinterpretCastOp::verify() {
// The source and result memrefs should be in the same memory space. // The source and result memrefs should be in the same memory space.
auto srcType = op.source().getType().cast<BaseMemRefType>(); auto srcType = source().getType().cast<BaseMemRefType>();
auto resultType = op.getType().cast<MemRefType>(); auto resultType = getType().cast<MemRefType>();
if (srcType.getMemorySpace() != resultType.getMemorySpace()) if (srcType.getMemorySpace() != resultType.getMemorySpace())
return op.emitError("different memory spaces specified for source type ") return emitError("different memory spaces specified for source type ")
<< srcType << " and result memref type " << resultType; << srcType << " and result memref type " << resultType;
if (srcType.getElementType() != resultType.getElementType()) if (srcType.getElementType() != resultType.getElementType())
return op.emitError("different element types specified for source type ") return emitError("different element types specified for source type ")
<< srcType << " and result memref type " << resultType; << srcType << " and result memref type " << resultType;
// Match sizes in result memref type and in static_sizes attribute. // Match sizes in result memref type and in static_sizes attribute.
for (auto &en : for (auto &en : llvm::enumerate(llvm::zip(
llvm::enumerate(llvm::zip(resultType.getShape(), resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) {
extractFromI64ArrayAttr(op.static_sizes())))) {
int64_t resultSize = std::get<0>(en.value()); int64_t resultSize = std::get<0>(en.value());
int64_t expectedSize = std::get<1>(en.value()); int64_t expectedSize = std::get<1>(en.value());
if (!ShapedType::isDynamic(resultSize) && if (!ShapedType::isDynamic(resultSize) &&
!ShapedType::isDynamic(expectedSize) && resultSize != expectedSize) !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
return op.emitError("expected result type with size = ") return emitError("expected result type with size = ")
<< expectedSize << " instead of " << resultSize << expectedSize << " instead of " << resultSize
<< " in dim = " << en.index(); << " in dim = " << en.index();
} }
@ -1336,27 +1333,26 @@ static LogicalResult verify(ReinterpretCastOp op) {
int64_t resultOffset; int64_t resultOffset;
SmallVector<int64_t, 4> resultStrides; SmallVector<int64_t, 4> resultStrides;
if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset))) if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
return op.emitError( return emitError("expected result type to have strided layout but found ")
"expected result type to have strided layout but found ")
<< resultType; << resultType;
// Match offset in result memref type and in static_offsets attribute. // Match offset in result memref type and in static_offsets attribute.
int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front(); int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front();
if (!ShapedType::isDynamicStrideOrOffset(resultOffset) && if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
!ShapedType::isDynamicStrideOrOffset(expectedOffset) && !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
resultOffset != expectedOffset) resultOffset != expectedOffset)
return op.emitError("expected result type with offset = ") return emitError("expected result type with offset = ")
<< resultOffset << " instead of " << expectedOffset; << resultOffset << " instead of " << expectedOffset;
// Match strides in result memref type and in static_strides attribute. // Match strides in result memref type and in static_strides attribute.
for (auto &en : llvm::enumerate(llvm::zip( for (auto &en : llvm::enumerate(llvm::zip(
resultStrides, extractFromI64ArrayAttr(op.static_strides())))) { resultStrides, extractFromI64ArrayAttr(static_strides())))) {
int64_t resultStride = std::get<0>(en.value()); int64_t resultStride = std::get<0>(en.value());
int64_t expectedStride = std::get<1>(en.value()); int64_t expectedStride = std::get<1>(en.value());
if (!ShapedType::isDynamicStrideOrOffset(resultStride) && if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
!ShapedType::isDynamicStrideOrOffset(expectedStride) && !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
resultStride != expectedStride) resultStride != expectedStride)
return op.emitError("expected result type with stride = ") return emitError("expected result type with stride = ")
<< expectedStride << " instead of " << resultStride << expectedStride << " instead of " << resultStride
<< " in dim = " << en.index(); << " in dim = " << en.index();
} }
@ -1532,8 +1528,8 @@ static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
return success(); return success();
} }
static LogicalResult verify(ExpandShapeOp op) { LogicalResult ExpandShapeOp::verify() {
return verifyReshapeOp(op, op.getResultType(), op.getSrcType()); return verifyReshapeOp(*this, getResultType(), getSrcType());
} }
void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results, void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@ -1542,8 +1538,8 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context); CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
} }
static LogicalResult verify(CollapseShapeOp op) { LogicalResult CollapseShapeOp::verify() {
return verifyReshapeOp(op, op.getSrcType(), op.getResultType()); return verifyReshapeOp(*this, getSrcType(), getResultType());
} }
struct CollapseShapeOpMemRefCastFolder struct CollapseShapeOpMemRefCastFolder
@ -1593,32 +1589,30 @@ OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
// ReshapeOp // ReshapeOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verify(ReshapeOp op) { LogicalResult ReshapeOp::verify() {
Type operandType = op.source().getType(); Type operandType = source().getType();
Type resultType = op.result().getType(); Type resultType = result().getType();
Type operandElementType = operandType.cast<ShapedType>().getElementType(); Type operandElementType = operandType.cast<ShapedType>().getElementType();
Type resultElementType = resultType.cast<ShapedType>().getElementType(); Type resultElementType = resultType.cast<ShapedType>().getElementType();
if (operandElementType != resultElementType) if (operandElementType != resultElementType)
return op.emitOpError("element types of source and destination memref " return emitOpError("element types of source and destination memref "
"types should be the same"); "types should be the same");
if (auto operandMemRefType = operandType.dyn_cast<MemRefType>()) if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
if (!operandMemRefType.getLayout().isIdentity()) if (!operandMemRefType.getLayout().isIdentity())
return op.emitOpError( return emitOpError("source memref type should have identity affine map");
"source memref type should have identity affine map");
int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0); int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0);
auto resultMemRefType = resultType.dyn_cast<MemRefType>(); auto resultMemRefType = resultType.dyn_cast<MemRefType>();
if (resultMemRefType) { if (resultMemRefType) {
if (!resultMemRefType.getLayout().isIdentity()) if (!resultMemRefType.getLayout().isIdentity())
return op.emitOpError( return emitOpError("result memref type should have identity affine map");
"result memref type should have identity affine map");
if (shapeSize == ShapedType::kDynamicSize) if (shapeSize == ShapedType::kDynamicSize)
return op.emitOpError("cannot use shape operand with dynamic length to " return emitOpError("cannot use shape operand with dynamic length to "
"reshape to statically-ranked memref type"); "reshape to statically-ranked memref type");
if (shapeSize != resultMemRefType.getRank()) if (shapeSize != resultMemRefType.getRank())
return op.emitOpError( return emitOpError(
"length of shape operand differs from the result's memref rank"); "length of shape operand differs from the result's memref rank");
} }
return success(); return success();
@ -1628,9 +1622,9 @@ static LogicalResult verify(ReshapeOp op) {
// StoreOp // StoreOp
//===----------------------------------------------------------------------===// //===----------------------------------------------------------------------===//
static LogicalResult verify(StoreOp op) { LogicalResult StoreOp::verify() {
if (op.getNumOperands() != 2 + op.getMemRefType().getRank()) if (getNumOperands() != 2 + getMemRefType().getRank())
return op.emitOpError("store index operand count not equal to memref rank"); return emitOpError("store index operand count not equal to memref rank");
return success(); return success();
} }
@ -1951,29 +1945,29 @@ static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
} }
/// Verifier for SubViewOp. /// Verifier for SubViewOp.
static LogicalResult verify(SubViewOp op) { LogicalResult SubViewOp::verify() {
MemRefType baseType = op.getSourceType(); MemRefType baseType = getSourceType();
MemRefType subViewType = op.getType(); MemRefType subViewType = getType();
// The base memref and the view memref should be in the same memory space. // The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != subViewType.getMemorySpace()) if (baseType.getMemorySpace() != subViewType.getMemorySpace())
return op.emitError("different memory spaces specified for base memref " return emitError("different memory spaces specified for base memref "
"type ") "type ")
<< baseType << " and subview memref type " << subViewType; << baseType << " and subview memref type " << subViewType;
// Verify that the base memref type has a strided layout map. // Verify that the base memref type has a strided layout map.
if (!isStrided(baseType)) if (!isStrided(baseType))
return op.emitError("base type ") << baseType << " is not strided"; return emitError("base type ") << baseType << " is not strided";
// Verify result type against inferred type. // Verify result type against inferred type.
auto expectedType = SubViewOp::inferResultType( auto expectedType = SubViewOp::inferResultType(
baseType, extractFromI64ArrayAttr(op.static_offsets()), baseType, extractFromI64ArrayAttr(static_offsets()),
extractFromI64ArrayAttr(op.static_sizes()), extractFromI64ArrayAttr(static_sizes()),
extractFromI64ArrayAttr(op.static_strides())); extractFromI64ArrayAttr(static_strides()));
auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(), auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
subViewType, op.getMixedSizes()); subViewType, getMixedSizes());
return produceSubViewErrorMsg(result, op, expectedType); return produceSubViewErrorMsg(result, *this, expectedType);
} }
raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) { raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
@ -2278,18 +2272,17 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
return success(); return success();
} }
static LogicalResult verify(TransposeOp op) { LogicalResult TransposeOp::verify() {
if (!op.permutation().isPermutation()) if (!permutation().isPermutation())
return op.emitOpError("expected a permutation map"); return emitOpError("expected a permutation map");
if (op.permutation().getNumDims() != op.getShapedType().getRank()) if (permutation().getNumDims() != getShapedType().getRank())
return op.emitOpError( return emitOpError("expected a permutation map of same rank as the input");
"expected a permutation map of same rank as the input");
auto srcType = op.in().getType().cast<MemRefType>(); auto srcType = in().getType().cast<MemRefType>();
auto dstType = op.getType().cast<MemRefType>(); auto dstType = getType().cast<MemRefType>();
auto transposedType = inferTransposeResultType(srcType, op.permutation()); auto transposedType = inferTransposeResultType(srcType, permutation());
if (dstType != transposedType) if (dstType != transposedType)
return op.emitOpError("output type ") return emitOpError("output type ")
<< dstType << " does not match transposed input type " << srcType << dstType << " does not match transposed input type " << srcType
<< ", " << transposedType; << ", " << transposedType;
return success(); return success();
@ -2338,29 +2331,28 @@ static void print(OpAsmPrinter &p, ViewOp op) {
p << " : " << op.getOperand(0).getType() << " to " << op.getType(); p << " : " << op.getOperand(0).getType() << " to " << op.getType();
} }
static LogicalResult verify(ViewOp op) { LogicalResult ViewOp::verify() {
auto baseType = op.getOperand(0).getType().cast<MemRefType>(); auto baseType = getOperand(0).getType().cast<MemRefType>();
auto viewType = op.getType(); auto viewType = getType();
// The base memref should have identity layout map (or none). // The base memref should have identity layout map (or none).
if (!baseType.getLayout().isIdentity()) if (!baseType.getLayout().isIdentity())
return op.emitError("unsupported map for base memref type ") << baseType; return emitError("unsupported map for base memref type ") << baseType;
// The result memref should have identity layout map (or none). // The result memref should have identity layout map (or none).
if (!viewType.getLayout().isIdentity()) if (!viewType.getLayout().isIdentity())
return op.emitError("unsupported map for result memref type ") << viewType; return emitError("unsupported map for result memref type ") << viewType;
// The base memref and the view memref should be in the same memory space. // The base memref and the view memref should be in the same memory space.
if (baseType.getMemorySpace() != viewType.getMemorySpace()) if (baseType.getMemorySpace() != viewType.getMemorySpace())
return op.emitError("different memory spaces specified for base memref " return emitError("different memory spaces specified for base memref "
"type ") "type ")
<< baseType << " and view memref type " << viewType; << baseType << " and view memref type " << viewType;
// Verify that we have the correct number of sizes for the result type. // Verify that we have the correct number of sizes for the result type.
unsigned numDynamicDims = viewType.getNumDynamicDims(); unsigned numDynamicDims = viewType.getNumDynamicDims();
if (op.sizes().size() != numDynamicDims) if (sizes().size() != numDynamicDims)
return op.emitError("incorrect number of size operands for type ") return emitError("incorrect number of size operands for type ") << viewType;
<< viewType;
return success(); return success();
} }
@@ -2467,19 +2459,19 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // AtomicRMWOp
 //===----------------------------------------------------------------------===//
-static LogicalResult verify(AtomicRMWOp op) {
-  if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
-    return op.emitOpError(
+LogicalResult AtomicRMWOp::verify() {
+  if (getMemRefType().getRank() != getNumOperands() - 2)
+    return emitOpError(
         "expects the number of subscripts to be equal to memref rank");
-  switch (op.kind()) {
+  switch (kind()) {
   case arith::AtomicRMWKind::addf:
   case arith::AtomicRMWKind::maxf:
   case arith::AtomicRMWKind::minf:
   case arith::AtomicRMWKind::mulf:
-    if (!op.value().getType().isa<FloatType>())
-      return op.emitOpError()
-             << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
-             << "' expects a floating-point type";
+    if (!value().getType().isa<FloatType>())
+      return emitOpError() << "with kind '"
+                           << arith::stringifyAtomicRMWKind(kind())
+                           << "' expects a floating-point type";
     break;
   case arith::AtomicRMWKind::addi:
   case arith::AtomicRMWKind::maxs:
@@ -2489,10 +2481,10 @@ static LogicalResult verify(AtomicRMWOp op) {
   case arith::AtomicRMWKind::muli:
   case arith::AtomicRMWKind::ori:
   case arith::AtomicRMWKind::andi:
-    if (!op.value().getType().isa<IntegerType>())
-      return op.emitOpError()
-             << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
-             << "' expects an integer type";
+    if (!value().getType().isa<IntegerType>())
+      return emitOpError() << "with kind '"
+                           << arith::stringifyAtomicRMWKind(kind())
+                           << "' expects an integer type";
     break;
   default:
     break;
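A note on the `return emitOpError() << ...` idiom above: `emitOpError()` returns an `InFlightDiagnostic`, streaming into it appends to the message, and the diagnostic converts implicitly to a failed `LogicalResult` on return. Spelled out as a sketch (hypothetical `MyOp`; the literal stands in for `arith::stringifyAtomicRMWKind(kind())`):

LogicalResult MyOp::verify() {
  StringRef kindName = "addf"; // stand-in for stringifyAtomicRMWKind(kind())
  InFlightDiagnostic diag = emitOpError(); // message prefixed with the op name
  diag << "with kind '" << kindName << "' expects a floating-point type";
  return diag; // implicit conversion to a failed LogicalResult
}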


@@ -209,53 +209,51 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) {
   return failure();
 }
-static LogicalResult verify(NewOp op) {
-  if (!getSparseTensorEncoding(op.result().getType()))
-    return op.emitError("expected a sparse tensor result");
+LogicalResult NewOp::verify() {
+  if (!getSparseTensorEncoding(result().getType()))
+    return emitError("expected a sparse tensor result");
   return success();
 }
-static LogicalResult verify(InitOp op) {
-  if (!getSparseTensorEncoding(op.result().getType()))
-    return op.emitError("expected a sparse tensor result");
-  RankedTensorType ttp = op.getType().cast<RankedTensorType>();
+LogicalResult InitOp::verify() {
+  if (!getSparseTensorEncoding(result().getType()))
+    return emitError("expected a sparse tensor result");
+  RankedTensorType ttp = getType().cast<RankedTensorType>();
   unsigned rank = ttp.getRank();
-  if (rank != op.sizes().size())
-    return op.emitError("unexpected mismatch between tensor rank and sizes: ")
-           << rank << " vs. " << op.sizes().size();
+  if (rank != sizes().size())
+    return emitError("unexpected mismatch between tensor rank and sizes: ")
+           << rank << " vs. " << sizes().size();
   auto shape = ttp.getShape();
   for (unsigned i = 0; i < rank; i++) {
     if (shape[i] == ShapedType::kDynamicSize)
       continue;
     IntegerAttr constantAttr;
-    if (!matchPattern(op.sizes()[i], m_Constant(&constantAttr)) ||
+    if (!matchPattern(sizes()[i], m_Constant(&constantAttr)) ||
         constantAttr.getInt() != shape[i]) {
-      return op.emitError("unexpected mismatch with static dimension size ")
+      return emitError("unexpected mismatch with static dimension size ")
              << shape[i];
     }
   }
   return success();
 }
-static LogicalResult verify(ConvertOp op) {
-  if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
-    if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
+LogicalResult ConvertOp::verify() {
+  if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
+    if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {
       if (tp1.getRank() != tp2.getRank())
-        return op.emitError("unexpected conversion mismatch in rank");
+        return emitError("unexpected conversion mismatch in rank");
       auto shape1 = tp1.getShape();
       auto shape2 = tp2.getShape();
       // Accept size matches between the source and the destination type
       // (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
       // matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
-      for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
+      for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
         if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
-          return op.emitError("unexpected conversion mismatch in dimension ")
-                 << d;
-      }
+          return emitError("unexpected conversion mismatch in dimension ") << d;
       return success();
     }
   }
-  return op.emitError("unexpected type in convert");
+  return emitError("unexpected type in convert");
 }
 OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
@@ -264,35 +262,35 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
   return {};
 }
-static LogicalResult verify(ToPointersOp op) {
-  if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
-    if (failed(isInBounds(op.dim(), op.tensor())))
-      return op.emitError("requested pointers dimension out of bounds");
-    if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
-      return op.emitError("unexpected type for pointers");
+LogicalResult ToPointersOp::verify() {
+  if (auto e = getSparseTensorEncoding(tensor().getType())) {
+    if (failed(isInBounds(dim(), tensor())))
+      return emitError("requested pointers dimension out of bounds");
+    if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
+      return emitError("unexpected type for pointers");
     return success();
   }
-  return op.emitError("expected a sparse tensor to get pointers");
+  return emitError("expected a sparse tensor to get pointers");
 }
-static LogicalResult verify(ToIndicesOp op) {
-  if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
-    if (failed(isInBounds(op.dim(), op.tensor())))
-      return op.emitError("requested indices dimension out of bounds");
-    if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
-      return op.emitError("unexpected type for indices");
+LogicalResult ToIndicesOp::verify() {
+  if (auto e = getSparseTensorEncoding(tensor().getType())) {
+    if (failed(isInBounds(dim(), tensor())))
+      return emitError("requested indices dimension out of bounds");
+    if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
+      return emitError("unexpected type for indices");
     return success();
   }
-  return op.emitError("expected a sparse tensor to get indices");
+  return emitError("expected a sparse tensor to get indices");
 }
-static LogicalResult verify(ToValuesOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor to get values");
-  RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
-  MemRefType mtp = op.result().getType().cast<MemRefType>();
+LogicalResult ToValuesOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor to get values");
+  RankedTensorType ttp = tensor().getType().cast<RankedTensorType>();
+  MemRefType mtp = result().getType().cast<MemRefType>();
   if (ttp.getElementType() != mtp.getElementType())
-    return op.emitError("unexpected mismatch in element types");
+    return emitError("unexpected mismatch in element types");
   return success();
 }
@@ -300,39 +298,39 @@ static LogicalResult verify(ToValuesOp op) {
 // TensorDialect Management Operations.
 //===----------------------------------------------------------------------===//
-static LogicalResult verify(LexInsertOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor for insertion");
+LogicalResult LexInsertOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor for insertion");
   return success();
 }
-static LogicalResult verify(ExpandOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor for expansion");
+LogicalResult ExpandOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor for expansion");
   return success();
 }
-static LogicalResult verify(CompressOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor for compression");
+LogicalResult CompressOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor for compression");
   return success();
 }
-static LogicalResult verify(LoadOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor to materialize");
+LogicalResult LoadOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor to materialize");
   return success();
 }
-static LogicalResult verify(ReleaseOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor to release");
+LogicalResult ReleaseOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor to release");
   return success();
 }
-static LogicalResult verify(OutOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor for output");
+LogicalResult OutOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor for output");
   return success();
 }
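The six management-op verifiers above differ only in their diagnostic suffix; if the repetition ever becomes a maintenance burden, a small file-static helper could factor out the shared check. A hedged sketch (the helper name is hypothetical, not part of this patch; it assumes each op exposes a `tensor()` accessor):

// Hypothetical helper: factors the check common to the six verifiers above.
template <typename OpTy>
static LogicalResult verifySparseTensorOperand(OpTy op, StringRef what) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor ") << what;
  return success();
}

// Usage, e.g.: LogicalResult OutOp::verify() {
//   return verifySparseTensorOperand(*this, "for output");
// }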


@@ -228,17 +228,17 @@ Optional<int64_t> DimOp::getConstantIndex() {
   return {};
 }
-static LogicalResult verify(DimOp op) {
+LogicalResult DimOp::verify() {
   // Assume unknown index to be in range.
-  Optional<int64_t> index = op.getConstantIndex();
+  Optional<int64_t> index = getConstantIndex();
   if (!index.hasValue())
     return success();
   // Check that constant index is not knowingly out of range.
-  auto type = op.source().getType();
+  auto type = source().getType();
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
     if (index.getValue() >= tensorType.getRank())
-      return op.emitOpError("index is out of range");
+      return emitOpError("index is out of range");
   } else if (type.isa<UnrankedTensorType>()) {
     // Assume index to be in range.
   } else {
@@ -328,11 +328,11 @@ void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // ExtractOp
 //===----------------------------------------------------------------------===//
-static LogicalResult verify(ExtractOp op) {
+LogicalResult ExtractOp::verify() {
   // Verify the # indices match if we have a ranked type.
-  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
-    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
-      return op.emitOpError("incorrect number of indices for extract_element");
+  if (auto tensorType = tensor().getType().dyn_cast<RankedTensorType>())
+    if (tensorType.getRank() != static_cast<int64_t>(indices().size()))
+      return emitOpError("incorrect number of indices for extract_element");
   return success();
 }
@@ -480,11 +480,11 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // InsertOp
 //===----------------------------------------------------------------------===//
-static LogicalResult verify(InsertOp op) {
+LogicalResult InsertOp::verify() {
   // Verify the # indices match if we have a ranked type.
-  if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
-    if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
-      return op.emitOpError("incorrect number of indices");
+  if (auto destType = dest().getType().dyn_cast<RankedTensorType>())
+    if (destType.getRank() != static_cast<int64_t>(indices().size()))
+      return emitOpError("incorrect number of indices");
   return success();
 }
@@ -502,27 +502,26 @@ OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
 // GenerateOp
 //===----------------------------------------------------------------------===//
-static LogicalResult verify(GenerateOp op) {
+LogicalResult GenerateOp::verify() {
   // Ensure that the tensor type has as many dynamic dimensions as are specified
   // by the operands.
-  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
-  if (op.getNumOperands() != resultTy.getNumDynamicDims())
-    return op.emitError("must have as many index operands as dynamic extents "
-                        "in the result type");
+  RankedTensorType resultTy = getType().cast<RankedTensorType>();
+  if (getNumOperands() != resultTy.getNumDynamicDims())
+    return emitError("must have as many index operands as dynamic extents "
+                     "in the result type");
   // Ensure that region arguments span the index space.
-  if (!llvm::all_of(op.body().getArgumentTypes(),
+  if (!llvm::all_of(body().getArgumentTypes(),
                     [](Type ty) { return ty.isIndex(); }))
-    return op.emitError("all body arguments must be index");
-  if (op.body().getNumArguments() != resultTy.getRank())
-    return op.emitError("must have one body argument per input dimension");
+    return emitError("all body arguments must be index");
+  if (body().getNumArguments() != resultTy.getRank())
+    return emitError("must have one body argument per input dimension");
   // Ensure that the region yields an element of the right type.
-  auto yieldOp =
-      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
+  auto yieldOp = cast<YieldOp>(body().getBlocks().front().getTerminator());
   if (yieldOp.value().getType() != resultTy.getElementType())
-    return op.emitOpError(
+    return emitOpError(
         "body must be terminated with a `yield` operation of the tensor "
         "element type");
@@ -686,16 +685,15 @@ static int64_t getNumElements(ShapedType type) {
   return numElements;
 }
-static LogicalResult verify(ReshapeOp op) {
-  TensorType operandType = op.source().getType().cast<TensorType>();
-  TensorType resultType = op.result().getType().cast<TensorType>();
+LogicalResult ReshapeOp::verify() {
+  TensorType operandType = source().getType().cast<TensorType>();
+  TensorType resultType = result().getType().cast<TensorType>();
   if (operandType.getElementType() != resultType.getElementType())
-    return op.emitOpError("element types of source and destination tensor "
-                          "types should be the same");
+    return emitOpError("element types of source and destination tensor "
+                       "types should be the same");
-  int64_t shapeSize =
-      op.shape().getType().cast<RankedTensorType>().getDimSize(0);
+  int64_t shapeSize = shape().getType().cast<RankedTensorType>().getDimSize(0);
   auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
   auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
@@ -703,14 +701,14 @@ static LogicalResult verify(ReshapeOp op) {
   if (operandRankedType && resultRankedType.hasStaticShape() &&
       operandRankedType.hasStaticShape()) {
     if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
-      return op.emitOpError("source and destination tensor should have the "
-                            "same number of elements");
+      return emitOpError("source and destination tensor should have the "
+                         "same number of elements");
   }
   if (ShapedType::isDynamic(shapeSize))
-    return op.emitOpError("cannot use shape operand with dynamic length to "
-                          "reshape to statically-ranked tensor type");
+    return emitOpError("cannot use shape operand with dynamic length to "
+                       "reshape to statically-ranked tensor type");
   if (shapeSize != resultRankedType.getRank())
-    return op.emitOpError(
+    return emitOpError(
         "length of shape operand differs from the result's tensor rank");
   }
   return success();
@@ -814,12 +812,12 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
   return success();
 }
-static LogicalResult verify(ExpandShapeOp op) {
-  return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType());
+LogicalResult ExpandShapeOp::verify() {
+  return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
 }
-static LogicalResult verify(CollapseShapeOp op) {
-  return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType());
+LogicalResult CollapseShapeOp::verify() {
+  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
 }
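Where several ops share one verification routine, the member `verify()` simply forwards `*this` to the shared free helper, exactly as ExpandShapeOp and CollapseShapeOp do with `verifyTensorReshapeOp`. The generic shape of that pattern, with hypothetical names (not the real helper):

// Hypothetical shared helper; both member verifiers would dispatch here.
template <typename ReshapeLikeOp>
static LogicalResult verifyReshapeLike(ReshapeLikeOp op,
                                       RankedTensorType collapsedType,
                                       RankedTensorType expandedType) {
  // Illustrative check only; the real helper verifies reassociation maps.
  if (collapsedType.getRank() > expandedType.getRank())
    return op.emitOpError("expected collapsed rank to not exceed expanded rank");
  return success();
}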
 namespace {
@@ -1052,14 +1050,12 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
 }
 /// Verifier for ExtractSliceOp.
-static LogicalResult verify(ExtractSliceOp op) {
+LogicalResult ExtractSliceOp::verify() {
   // Verify result type against inferred type.
-  auto expectedType =
-      ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(),
-                                      op.getMixedSizes(), op.getMixedStrides());
-  auto result =
-      isRankReducedType(expectedType.cast<ShapedType>(), op.getType());
-  return produceSliceErrorMsg(result, op, expectedType);
+  auto expectedType = ExtractSliceOp::inferResultType(
+      getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
+  auto result = isRankReducedType(expectedType.cast<ShapedType>(), getType());
+  return produceSliceErrorMsg(result, *this, expectedType);
 }
 /// Infer the canonical type of the result of an extract_slice op. Returns a
@@ -1308,16 +1304,16 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
 }
 /// Verifier for InsertSliceOp.
-static LogicalResult verify(InsertSliceOp op) {
+LogicalResult InsertSliceOp::verify() {
   // insert_slice is the inverse of extract_slice, use the same type inference.
   auto expectedType = ExtractSliceOp::inferRankReducedResultType(
-      op.getSourceType().getRank(), op.getType(),
-      extractFromI64ArrayAttr(op.static_offsets()),
-      extractFromI64ArrayAttr(op.static_sizes()),
-      extractFromI64ArrayAttr(op.static_strides()));
+      getSourceType().getRank(), getType(),
+      extractFromI64ArrayAttr(static_offsets()),
+      extractFromI64ArrayAttr(static_sizes()),
+      extractFromI64ArrayAttr(static_strides()));
   auto result =
-      isRankReducedType(expectedType.cast<ShapedType>(), op.getSourceType());
-  return produceSliceErrorMsg(result, op, expectedType);
+      isRankReducedType(expectedType.cast<ShapedType>(), getSourceType());
+  return produceSliceErrorMsg(result, *this, expectedType);
 }
 /// If we have two consecutive InsertSliceOp writing to the same slice, we
@@ -1569,40 +1565,40 @@ ParseResult parseInferType(OpAsmParser &parser,
   return success();
 }
-static LogicalResult verify(PadOp op) {
-  auto sourceType = op.source().getType().cast<RankedTensorType>();
-  auto resultType = op.result().getType().cast<RankedTensorType>();
-  auto expectedType = PadOp::inferResultType(
-      sourceType, extractFromI64ArrayAttr(op.static_low()),
-      extractFromI64ArrayAttr(op.static_high()));
+LogicalResult PadOp::verify() {
+  auto sourceType = source().getType().cast<RankedTensorType>();
+  auto resultType = result().getType().cast<RankedTensorType>();
+  auto expectedType =
+      PadOp::inferResultType(sourceType, extractFromI64ArrayAttr(static_low()),
+                             extractFromI64ArrayAttr(static_high()));
   for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
     if (resultType.getDimSize(i) == expectedType.getDimSize(i))
       continue;
     if (expectedType.isDynamicDim(i))
       continue;
-    return op.emitError("specified type ")
+    return emitError("specified type ")
            << resultType << " does not match the inferred type "
           << expectedType;
   }
-  auto &region = op.region();
+  auto &region = getRegion();
   unsigned rank = resultType.getRank();
   Block &block = region.front();
   if (block.getNumArguments() != rank)
-    return op.emitError("expected the block to have ") << rank << " arguments";
+    return emitError("expected the block to have ") << rank << " arguments";
   // Note: the number and type of yield values are checked in the YieldOp.
   for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
     if (!en.value().isIndex())
-      return op.emitOpError("expected block argument ")
+      return emitOpError("expected block argument ")
              << (en.index() + 1) << " to be an index";
   }
   // Ensure that the region yields an element of the right type.
   auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
   if (yieldOp.value().getType() !=
-      op.getType().cast<ShapedType>().getElementType())
-    return op.emitOpError("expected yield type to match shape element type");
+      getType().cast<ShapedType>().getElementType())
+    return emitOpError("expected yield type to match shape element type");
   return success();
 }
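PadOp, ExtractSliceOp, and InsertSliceOp all verify by the same "infer, then compare" discipline: recompute the type the op ought to produce from its operands and attributes, then diagnose any divergence from the declared type. Schematically (`inferResultTypeFor` is a hypothetical stand-in for `PadOp::inferResultType` and friends):

// Sketch of the infer-then-compare discipline used by the slice/pad verifiers.
LogicalResult MyOp::verify() {
  Type expected = inferResultTypeFor(*this); // hypothetical inference helper
  if (expected != getType())
    return emitOpError("specified type ")
           << getType() << " does not match the inferred type " << expected;
  return success();
}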