[mlir][NFC] Update MemRef/Tensor operations to use `hasVerifier` instead of `verifier`

The `verifier` field is deprecated and slated for removal.

Differential Revision: https://reviews.llvm.org/D118821
parent bdc7ce975a
commit b98dc0351a
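The change is one mechanical pattern applied throughout. In ODS, the deprecated inline hook `let verifier = [{ return ::verify(*this); }];` is dropped from each dialect's base op class, and every op that needs custom verification instead sets `let hasVerifier = 1;`, which declares a `verify()` member on the generated op class. On the C++ side, the former `static LogicalResult verify(SomeOp op)` free functions become member functions, so the `op.` qualifiers on accessors and diagnostics disappear. A minimal sketch of the before/after, lifted from the `AssumeAlignmentOp` hunks below:

```
// Before: a static free function, dispatched through the deprecated
// `verifier` field in ODS.
static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

// After: `let hasVerifier = 1;` in ODS declares this hook; the body is the
// same logic, minus the `op.` prefixes.
LogicalResult AssumeAlignmentOp::verify() {
  if (!llvm::isPowerOf2_32(alignment()))
    return emitOpError("alignment must be power of 2");
  return success();
}
```

Ops that are fully verified by their traits (e.g. `GetGlobalOp`, `TensorStoreOp`, the rank ops) simply lose the explicit opt-out `let verifier = ?;` and gain no `verify()` member.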
mlir/include/mlir/Dialect/MemRef/IR/MemRefOps.td

@@ -28,7 +28,6 @@ def MemRefTypeAttr
 class MemRef_Op<string mnemonic, list<Trait> traits = []>
     : Op<MemRef_Dialect, mnemonic, traits> {
   let printer = [{ return ::print(p, *this); }];
-  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
@@ -93,6 +92,7 @@ class AllocLikeOp<string mnemonic,
   }];
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -115,6 +115,7 @@ def AssumeAlignmentOp : MemRef_Op<"assume_alignment"> {
   let results = (outs);
 
   let assemblyFormat = "$memref `,` $alignment attr-dict `:` type($memref)";
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -162,6 +163,7 @@ def MemRef_AllocOp : AllocLikeOp<"alloc", DefaultResource, []> {
     memref<8x64xf32, affine_map<(d0, d1)[s0] -> ((d0 + s0), d1)>, 1>
     ```
   }];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -205,6 +207,7 @@ def MemRef_AllocaOp : AllocLikeOp<"alloca", AutomaticAllocationScopeResource> {
     an alignment on any convenient boundary compatible with the type will be
     chosen.
   }];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -253,6 +256,7 @@ def MemRef_AllocaScopeOp : MemRef_Op<"alloca_scope",
 
   let results = (outs Variadic<AnyType>:$results);
   let regions = (region SizedRegion<1>:$bodyRegion);
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -279,11 +283,7 @@ def MemRef_AllocaScopeReturnOp : MemRef_Op<"alloca_scope.return",
   let arguments = (ins Variadic<AnyType>:$results);
   let builders = [OpBuilder<(ins), [{ /*nothing to do */ }]>];
 
-  let assemblyFormat =
-    [{ attr-dict ($results^ `:` type($results))? }];
-
-  // No custom verification needed.
-  let verifier = ?;
+  let assemblyFormat = "attr-dict ($results^ `:` type($results))?";
 }
 
 //===----------------------------------------------------------------------===//
@@ -355,7 +355,6 @@ def MemRef_CastOp : MemRef_Op<"cast", [
   let arguments = (ins AnyRankedOrUnrankedMemRef:$source);
   let results = (outs AnyRankedOrUnrankedMemRef:$dest);
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
-  let verifier = "return impl::verifyCastOp(*this, areCastCompatible);";
   let builders = [
     OpBuilder<(ins "Value":$source, "Type":$destType), [{
       impl::buildCastOp($_builder, $_state, source, destType);
@@ -370,6 +369,7 @@ def MemRef_CastOp : MemRef_Op<"cast", [
   }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -408,7 +408,6 @@ def CopyOp : MemRef_Op<"copy",
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
-  let verifier = ?;
 }
 
 //===----------------------------------------------------------------------===//
@@ -434,7 +433,6 @@ def MemRef_DeallocOp : MemRef_Op<"dealloc", [MemRefsNormalizable]> {
   let arguments = (ins Arg<AnyRankedOrUnrankedMemRef, "", [MemFree]>:$memref);
 
   let hasFolder = 1;
-  let verifier = ?;
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
 }
 
@@ -488,6 +486,7 @@ def MemRef_DimOp : MemRef_Op<"dim", [NoSideEffect, MemRefsNormalizable]> {
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -646,6 +645,7 @@ def MemRef_DmaStartOp : MemRef_Op<"dma_start"> {
     }
   }];
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -697,6 +697,7 @@ def MemRef_DmaWaitOp : MemRef_Op<"dma_wait"> {
     Value getNumElements() { return numElements(); }
   }];
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -757,6 +758,7 @@ def GenericAtomicRMWOp : MemRef_Op<"generic_atomic_rmw", [
       return memref().getType().cast<MemRefType>();
     }
   }];
+  let hasVerifier = 1;
 }
 
 def AtomicYieldOp : MemRef_Op<"atomic_yield", [
@@ -772,6 +774,7 @@ def AtomicYieldOp : MemRef_Op<"atomic_yield", [
 
   let arguments = (ins AnyType:$result);
   let assemblyFormat = "$result attr-dict `:` type($result)";
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -797,9 +800,6 @@ def MemRef_GetGlobalOp : MemRef_Op<"get_global",
   let arguments = (ins FlatSymbolRefAttr:$name);
   let results = (outs AnyStaticShapeMemRef:$result);
   let assemblyFormat = "$name `:` type($result) attr-dict";
-
-  // `GetGlobalOp` is fully verified by its traits.
-  let verifier = ?;
 }
 
 //===----------------------------------------------------------------------===//
@@ -866,6 +866,7 @@ def MemRef_GlobalOp : MemRef_Op<"global", [Symbol]> {
       return !isExternal() && initial_value().getValue().isa<UnitAttr>();
     }
   }];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -939,6 +940,7 @@ def LoadOp : MemRef_Op<"load",
   }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat = "$memref `[` $indices `]` attr-dict `:` type($memref)";
 }
@@ -982,6 +984,7 @@ def MemRef_PrefetchOp : MemRef_Op<"prefetch"> {
   }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1034,6 +1037,7 @@ def MemRef_ReinterpretCastOp:
 
   let parser = ?;
   let printer = ?;
+  let hasVerifier = 1;
 
   let builders = [
     // Build a ReinterpretCastOp with mixed static and dynamic entries.
@@ -1096,7 +1100,6 @@ def MemRef_RankOp : MemRef_Op<"rank", [NoSideEffect]> {
   let arguments = (ins AnyRankedOrUnrankedMemRef:$memref);
   let results = (outs Index);
 
-  let verifier = ?;
   let hasFolder = 1;
   let assemblyFormat = "$memref attr-dict `:` type($memref)";
 }
@@ -1161,6 +1164,7 @@ def MemRef_ReshapeOp: MemRef_Op<"reshape", [
   let assemblyFormat = [{
     $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
   }];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1226,6 +1230,7 @@ class MemRef_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
 }
@@ -1265,6 +1270,7 @@ def MemRef_ExpandShapeOp : MemRef_ReassociativeReshapeOp<"expand_shape"> {
     ```
   }];
   let extraClassDeclaration = commonExtraClassDeclaration;
+  let hasVerifier = 1;
 }
 
 def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
@@ -1302,6 +1308,7 @@ def MemRef_CollapseShapeOp : MemRef_ReassociativeReshapeOp<"collapse_shape"> {
     ```
   }];
   let extraClassDeclaration = commonExtraClassDeclaration;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1369,6 +1376,7 @@ def MemRef_StoreOp : MemRef_Op<"store",
   }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 
   let assemblyFormat = [{
     $value `,` $memref `[` $indices `]` attr-dict `:` type($memref)
@@ -1617,6 +1625,7 @@ def SubViewOp : BaseOpWithOffsetSizesAndStrides<
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1645,8 +1654,6 @@ def TensorStoreOp : MemRef_Op<"tensor_store",
 
   let arguments = (ins AnyTensor:$tensor, Arg<AnyRankedOrUnrankedMemRef,
                    "the reference to store to", [MemWrite]>:$memref);
-  // TensorStoreOp is fully verified by traits.
-  let verifier = ?;
 
   let assemblyFormat = "$tensor `,` $memref attr-dict `:` type($memref)";
 }
@@ -1681,6 +1688,7 @@ def MemRef_TransposeOp : MemRef_Op<"transpose", [NoSideEffect]>,
   }];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1749,6 +1757,7 @@ def MemRef_ViewOp : MemRef_Op<"view", [
   }];
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -1796,6 +1805,7 @@ def AtomicRMWOp : MemRef_Op<"atomic_rmw", [
     }
   }];
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 #endif // MEMREF_OPS
mlir/include/mlir/Dialect/SparseTensor/IR/SparseTensorOps.td

@@ -21,7 +21,6 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
 class SparseTensor_Op<string mnemonic, list<Trait> traits = []>
     : Op<SparseTensor_Dialect, mnemonic, traits> {
   let printer = [{ return ::print(p, *this); }];
-  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
@@ -50,6 +49,7 @@ def SparseTensor_NewOp : SparseTensor_Op<"new", [NoSideEffect]>,
     ```
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($result)";
+  let hasVerifier = 1;
 }
 
 def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>,
@@ -72,6 +72,7 @@ def SparseTensor_InitOp : SparseTensor_Op<"init", [NoSideEffect]>,
     ```
   }];
   let assemblyFormat = "`[` $sizes `]` attr-dict `:` type($result)";
+  let hasVerifier = 1;
 }
 
 def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
@@ -113,6 +114,7 @@ def SparseTensor_ConvertOp : SparseTensor_Op<"convert",
   }];
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
@@ -137,6 +139,7 @@ def SparseTensor_ToPointersOp : SparseTensor_Op<"pointers", [NoSideEffect]>,
   }];
   let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)"
                        " `to` type($result)";
+  let hasVerifier = 1;
 }
 
 def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
@@ -161,6 +164,7 @@ def SparseTensor_ToIndicesOp : SparseTensor_Op<"indices", [NoSideEffect]>,
   }];
   let assemblyFormat = "$tensor `,` $dim attr-dict `:` type($tensor)"
                        " `to` type($result)";
+  let hasVerifier = 1;
 }
 
 def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
@@ -183,6 +187,7 @@ def SparseTensor_ToValuesOp : SparseTensor_Op<"values", [NoSideEffect]>,
     ```
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($result)";
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -217,6 +222,7 @@ def SparseTensor_LexInsertOp : SparseTensor_Op<"lex_insert", []>,
   }];
   let assemblyFormat = "$tensor `,` $indices `,` $value attr-dict `:`"
                        " type($tensor) `,` type($indices) `,` type($value)";
+  let hasVerifier = 1;
 }
 
 def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
@@ -258,6 +264,7 @@ def SparseTensor_ExpandOp : SparseTensor_Op<"expand", []>,
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor) `to` type($values)"
                        " `,` type($filled) `,` type($added) `,` type($count)";
+  let hasVerifier = 1;
 }
 
 def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
@@ -292,6 +299,7 @@ def SparseTensor_CompressOp : SparseTensor_Op<"compress", []>,
                        " $added `,` $count attr-dict `:` type($tensor) `,`"
                        " type($indices) `,` type($values) `,` type($filled) `,`"
                        " type($added) `,` type($count)";
+  let hasVerifier = 1;
 }
 
 def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
@@ -324,6 +332,7 @@ def SparseTensor_LoadOp : SparseTensor_Op<"load", [SameOperandsAndResultType]>,
     ```
   }];
   let assemblyFormat = "$tensor (`hasInserts` $hasInserts^)? attr-dict `:` type($tensor)";
+  let hasVerifier = 1;
 }
 
 def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
@@ -349,6 +358,7 @@ def SparseTensor_ReleaseOp : SparseTensor_Op<"release", []>,
     ```
   }];
   let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
+  let hasVerifier = 1;
 }
 
 def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
@@ -369,6 +379,7 @@ def SparseTensor_OutOp : SparseTensor_Op<"out", []>,
     ```
   }];
   let assemblyFormat = "$tensor `,` $dest attr-dict `:` type($tensor) `,` type($dest)";
+  let hasVerifier = 1;
 }
 
 #endif // SPARSETENSOR_OPS
mlir/include/mlir/Dialect/Tensor/IR/TensorOps.td

@@ -20,7 +20,6 @@ include "mlir/Interfaces/ViewLikeInterface.td"
 class Tensor_Op<string mnemonic, list<Trait> traits = []>
     : Op<Tensor_Dialect, mnemonic, traits> {
   let printer = [{ return ::print(p, *this); }];
-  let verifier = [{ return ::verify(*this); }];
   let parser = [{ return ::parse$cppClass(parser, result); }];
 }
 
@@ -59,7 +58,6 @@ def Tensor_CastOp : Tensor_Op<"cast", [
   let assemblyFormat = "$source attr-dict `:` type($source) `to` type($dest)";
 
   let hasCanonicalizer = 1;
-  let verifier = ?;
 }
 
 //===----------------------------------------------------------------------===//
@@ -111,6 +109,7 @@ def Tensor_DimOp : Tensor_Op<"dim", [NoSideEffect]> {
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -151,6 +150,7 @@ def Tensor_ExtractOp : Tensor_Op<"extract",
   }]>];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 
@@ -303,6 +303,7 @@ def Tensor_ExtractSliceOp : BaseOpWithOffsetSizesAndStrides<
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -339,9 +340,6 @@ def Tensor_FromElementsOp : Tensor_Op<"from_elements", [
 
   let assemblyFormat = "$elements attr-dict `:` type($result)";
 
-  // This op is fully verified by its traits.
-  let verifier = ?;
-
   let skipDefaultBuilders = 1;
   let builders = [
     OpBuilder<(ins "Type":$resultType, "ValueRange":$elements)>,
@@ -394,6 +392,7 @@ def Tensor_GenerateOp : Tensor_Op<"generate",
   ];
 
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -445,6 +444,7 @@ def Tensor_InsertOp : Tensor_Op<"insert",
   }]>];
 
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -564,6 +564,7 @@ def Tensor_InsertSliceOp : BaseOpWithOffsetSizesAndStrides<
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -586,7 +587,6 @@ def Tensor_RankOp : Tensor_Op<"rank", [NoSideEffect]> {
   let arguments = (ins AnyTensor:$tensor);
   let results = (outs Index);
 
-  let verifier = ?;
   let hasFolder = 1;
   let assemblyFormat = "$tensor attr-dict `:` type($tensor)";
 }
@@ -650,6 +650,7 @@ def Tensor_ReshapeOp: Tensor_Op<"reshape", [NoSideEffect]> {
   let assemblyFormat = [{
     $source `(` $shape `)` attr-dict `:` functional-type(operands, results)
   }];
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -718,6 +719,7 @@ class Tensor_ReassociativeReshapeOp<string mnemonic, list<Trait> traits = []> :
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
+  let hasVerifier = 1;
   let printer = [{ return ::print(p, *this); }];
   let parser = [{ return ::parseReshapeLikeOp(parser, result); }];
 }
@@ -748,6 +750,7 @@ def Tensor_ExpandShapeOp : Tensor_ReassociativeReshapeOp<"expand_shape"> {
     ```
   }];
   let extraClassDeclaration = commonExtraClassDeclaration;
+  let hasVerifier = 1;
 }
 
 def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
@@ -776,6 +779,7 @@ def Tensor_CollapseShapeOp : Tensor_ReassociativeReshapeOp<"collapse_shape"> {
     ```
   }];
   let extraClassDeclaration = commonExtraClassDeclaration;
+  let hasVerifier = 1;
 }
 
 //===----------------------------------------------------------------------===//
@@ -961,6 +965,7 @@ def Tensor_PadOp : Tensor_Op<"pad", [AttrSizedOperandSegments, NoSideEffect,
 
   let hasCanonicalizer = 1;
   let hasFolder = 1;
+  let hasVerifier = 1;
 }
 
 
@@ -984,7 +989,6 @@ def Tensor_YieldOp : Tensor_Op<"yield",
   // Dummy builder to appease code in templated ensureTerminator that
   // GenerateOp's auto-generated parser calls.
   let builders = [OpBuilder<(ins), [{ /* nothing to do */ }]>];
-  let verifier = ?;
 }
 
 #endif // TENSOR_OPS
mlir/lib/Dialect/MemRef/IR/MemRefOps.cpp

@@ -67,6 +67,10 @@ Type mlir::memref::getTensorTypeFromMemRefType(Type type) {
   return NoneType::get(type.getContext());
 }
 
+LogicalResult memref::CastOp::verify() {
+  return impl::verifyCastOp(*this, areCastCompatible);
+}
+
 //===----------------------------------------------------------------------===//
 // AllocOp / AllocaOp
 //===----------------------------------------------------------------------===//
@@ -95,15 +99,15 @@ static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
   return success();
 }
 
-static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
+LogicalResult AllocOp::verify() { return verifyAllocLikeOp(*this); }
 
-static LogicalResult verify(AllocaOp op) {
+LogicalResult AllocaOp::verify() {
   // An alloca op needs to have an ancestor with an allocation scope trait.
-  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
-    return op.emitOpError(
+  if (!(*this)->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
+    return emitOpError(
         "requires an ancestor op with AutomaticAllocationScope trait");
 
-  return verifyAllocLikeOp(op);
+  return verifyAllocLikeOp(*this);
 }
 
 namespace {
@@ -246,11 +250,8 @@ static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
   return success();
 }
 
-static LogicalResult verify(AllocaScopeOp op) {
-  if (failed(RegionBranchOpInterface::verifyTypes(op)))
-    return failure();
-
-  return success();
+LogicalResult AllocaScopeOp::verify() {
+  return RegionBranchOpInterface::verifyTypes(*this);
 }
 
 void AllocaScopeOp::getSuccessorRegions(
@@ -268,10 +269,9 @@ void AllocaScopeOp::getSuccessorRegions(
 // AssumeAlignmentOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(AssumeAlignmentOp op) {
-  unsigned alignment = op.alignment();
-  if (!llvm::isPowerOf2_32(alignment))
-    return op.emitOpError("alignment must be power of 2");
+LogicalResult AssumeAlignmentOp::verify() {
+  if (!llvm::isPowerOf2_32(alignment()))
+    return emitOpError("alignment must be power of 2");
   return success();
 }
 
@@ -556,17 +556,17 @@ Optional<int64_t> DimOp::getConstantIndex() {
   return {};
 }
 
-static LogicalResult verify(DimOp op) {
+LogicalResult DimOp::verify() {
   // Assume unknown index to be in range.
-  Optional<int64_t> index = op.getConstantIndex();
+  Optional<int64_t> index = getConstantIndex();
   if (!index.hasValue())
     return success();
 
   // Check that constant index is not knowingly out of range.
-  auto type = op.source().getType();
+  auto type = source().getType();
   if (auto memrefType = type.dyn_cast<MemRefType>()) {
     if (index.getValue() >= memrefType.getRank())
-      return op.emitOpError("index is out of range");
+      return emitOpError("index is out of range");
   } else if (type.isa<UnrankedMemRefType>()) {
     // Assume index to be in range.
   } else {
@@ -866,67 +866,66 @@ static ParseResult parseDmaStartOp(OpAsmParser &parser,
   return success();
 }
 
-static LogicalResult verify(DmaStartOp op) {
-  unsigned numOperands = op.getNumOperands();
+LogicalResult DmaStartOp::verify() {
+  unsigned numOperands = getNumOperands();
 
   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
   // the number of elements.
   if (numOperands < 4)
-    return op.emitOpError("expected at least 4 operands");
+    return emitOpError("expected at least 4 operands");
 
   // Check types of operands. The order of these calls is important: the later
   // calls rely on some type properties to compute the operand position.
   // 1. Source memref.
-  if (!op.getSrcMemRef().getType().isa<MemRefType>())
-    return op.emitOpError("expected source to be of memref type");
-  if (numOperands < op.getSrcMemRefRank() + 4)
-    return op.emitOpError()
-           << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
-  if (!op.getSrcIndices().empty() &&
-      !llvm::all_of(op.getSrcIndices().getTypes(),
+  if (!getSrcMemRef().getType().isa<MemRefType>())
+    return emitOpError("expected source to be of memref type");
+  if (numOperands < getSrcMemRefRank() + 4)
+    return emitOpError() << "expected at least " << getSrcMemRefRank() + 4
+                         << " operands";
+  if (!getSrcIndices().empty() &&
+      !llvm::all_of(getSrcIndices().getTypes(),
                     [](Type t) { return t.isIndex(); }))
-    return op.emitOpError("expected source indices to be of index type");
+    return emitOpError("expected source indices to be of index type");
 
   // 2. Destination memref.
-  if (!op.getDstMemRef().getType().isa<MemRefType>())
-    return op.emitOpError("expected destination to be of memref type");
-  unsigned numExpectedOperands =
-      op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
+  if (!getDstMemRef().getType().isa<MemRefType>())
+    return emitOpError("expected destination to be of memref type");
+  unsigned numExpectedOperands = getSrcMemRefRank() + getDstMemRefRank() + 4;
   if (numOperands < numExpectedOperands)
-    return op.emitOpError()
-           << "expected at least " << numExpectedOperands << " operands";
-  if (!op.getDstIndices().empty() &&
-      !llvm::all_of(op.getDstIndices().getTypes(),
+    return emitOpError() << "expected at least " << numExpectedOperands
+                         << " operands";
+  if (!getDstIndices().empty() &&
+      !llvm::all_of(getDstIndices().getTypes(),
                     [](Type t) { return t.isIndex(); }))
-    return op.emitOpError("expected destination indices to be of index type");
+    return emitOpError("expected destination indices to be of index type");
 
   // 3. Number of elements.
-  if (!op.getNumElements().getType().isIndex())
-    return op.emitOpError("expected num elements to be of index type");
+  if (!getNumElements().getType().isIndex())
+    return emitOpError("expected num elements to be of index type");
 
   // 4. Tag memref.
-  if (!op.getTagMemRef().getType().isa<MemRefType>())
-    return op.emitOpError("expected tag to be of memref type");
-  numExpectedOperands += op.getTagMemRefRank();
+  if (!getTagMemRef().getType().isa<MemRefType>())
+    return emitOpError("expected tag to be of memref type");
+  numExpectedOperands += getTagMemRefRank();
   if (numOperands < numExpectedOperands)
-    return op.emitOpError()
-           << "expected at least " << numExpectedOperands << " operands";
-  if (!op.getTagIndices().empty() &&
-      !llvm::all_of(op.getTagIndices().getTypes(),
+    return emitOpError() << "expected at least " << numExpectedOperands
+                         << " operands";
+  if (!getTagIndices().empty() &&
+      !llvm::all_of(getTagIndices().getTypes(),
                     [](Type t) { return t.isIndex(); }))
-    return op.emitOpError("expected tag indices to be of index type");
+    return emitOpError("expected tag indices to be of index type");
 
   // Optional stride-related operands must be either both present or both
   // absent.
   if (numOperands != numExpectedOperands &&
       numOperands != numExpectedOperands + 2)
-    return op.emitOpError("incorrect number of operands");
+    return emitOpError("incorrect number of operands");
 
   // 5. Strides.
-  if (op.isStrided()) {
-    if (!op.getStride().getType().isIndex() ||
-        !op.getNumElementsPerStride().getType().isIndex())
-      return op.emitOpError(
+  if (isStrided()) {
+    if (!getStride().getType().isIndex() ||
+        !getNumElementsPerStride().getType().isIndex())
+      return emitOpError(
           "expected stride and num elements per stride to be of type index");
   }
 
@@ -949,14 +948,14 @@ LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
   return foldMemRefCast(*this);
 }
 
-static LogicalResult verify(DmaWaitOp op) {
+LogicalResult DmaWaitOp::verify() {
   // Check that the number of tag indices matches the tagMemRef rank.
-  unsigned numTagIndices = op.tagIndices().size();
-  unsigned tagMemRefRank = op.getTagMemRefRank();
+  unsigned numTagIndices = tagIndices().size();
+  unsigned tagMemRefRank = getTagMemRefRank();
   if (numTagIndices != tagMemRefRank)
-    return op.emitOpError() << "expected tagIndices to have the same number of "
-                               "elements as the tagMemRef rank, expected "
-                            << tagMemRefRank << ", but got " << numTagIndices;
+    return emitOpError() << "expected tagIndices to have the same number of "
+                            "elements as the tagMemRef rank, expected "
+                         << tagMemRefRank << ", but got " << numTagIndices;
   return success();
 }
 
@@ -979,14 +978,13 @@ void GenericAtomicRMWOp::build(OpBuilder &builder, OperationState &result,
   }
 }
 
-static LogicalResult verify(GenericAtomicRMWOp op) {
-  auto &body = op.getRegion();
+LogicalResult GenericAtomicRMWOp::verify() {
+  auto &body = getRegion();
   if (body.getNumArguments() != 1)
-    return op.emitOpError("expected single number of entry block arguments");
+    return emitOpError("expected single number of entry block arguments");
 
-  if (op.getResult().getType() != body.getArgument(0).getType())
-    return op.emitOpError(
-        "expected block argument of the same type result type");
+  if (getResult().getType() != body.getArgument(0).getType())
+    return emitOpError("expected block argument of the same type result type");
 
   bool hasSideEffects =
       body.walk([&](Operation *nestedOp) {
@@ -1034,12 +1032,12 @@ static void print(OpAsmPrinter &p, GenericAtomicRMWOp op) {
 // AtomicYieldOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(AtomicYieldOp op) {
-  Type parentType = op->getParentOp()->getResultTypes().front();
-  Type resultType = op.result().getType();
+LogicalResult AtomicYieldOp::verify() {
+  Type parentType = (*this)->getParentOp()->getResultTypes().front();
+  Type resultType = result().getType();
   if (parentType != resultType)
-    return op.emitOpError() << "types mismatch between yield op: " << resultType
-                            << " and its parent: " << parentType;
+    return emitOpError() << "types mismatch between yield op: " << resultType
+                         << " and its parent: " << parentType;
   return success();
 }
 
@@ -1090,19 +1088,19 @@ parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
   return success();
 }
 
-static LogicalResult verify(GlobalOp op) {
-  auto memrefType = op.type().dyn_cast<MemRefType>();
+LogicalResult GlobalOp::verify() {
+  auto memrefType = type().dyn_cast<MemRefType>();
   if (!memrefType || !memrefType.hasStaticShape())
-    return op.emitOpError("type should be static shaped memref, but got ")
-           << op.type();
+    return emitOpError("type should be static shaped memref, but got ")
+           << type();
 
   // Verify that the initial value, if present, is either a unit attribute or
   // an elements attribute.
-  if (op.initial_value().hasValue()) {
-    Attribute initValue = op.initial_value().getValue();
+  if (initial_value().hasValue()) {
+    Attribute initValue = initial_value().getValue();
     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
-      return op.emitOpError("initial value should be a unit or elements "
-                            "attribute, but got ")
+      return emitOpError("initial value should be a unit or elements "
+                         "attribute, but got ")
             << initValue;
 
   // Check that the type of the initial value is compatible with the type of
@@ -1111,17 +1109,17 @@ static LogicalResult verify(GlobalOp op) {
     Type initType = initValue.getType();
     Type tensorType = getTensorTypeFromMemRefType(memrefType);
     if (initType != tensorType)
-      return op.emitOpError("initial value expected to be of type ")
+      return emitOpError("initial value expected to be of type ")
             << tensorType << ", but was of type " << initType;
   }
  }
 
-  if (Optional<uint64_t> alignAttr = op.alignment()) {
+  if (Optional<uint64_t> alignAttr = alignment()) {
    uint64_t alignment = alignAttr.getValue();
 
    if (!llvm::isPowerOf2_64(alignment))
-      return op->emitError() << "alignment attribute value " << alignment
-                             << " is not a power of 2";
+      return emitError() << "alignment attribute value " << alignment
+                         << " is not a power of 2";
  }
 
  // TODO: verify visibility for declarations.
@@ -1154,9 +1152,9 @@ GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
 // LoadOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(LoadOp op) {
-  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
-    return op.emitOpError("incorrect number of indices for load");
+LogicalResult LoadOp::verify() {
+  if (getNumOperands() != 1 + getMemRefType().getRank())
+    return emitOpError("incorrect number of indices for load");
   return success();
 }
 
@@ -1224,9 +1222,9 @@ static ParseResult parsePrefetchOp(OpAsmParser &parser,
   return success();
 }
 
-static LogicalResult verify(PrefetchOp op) {
-  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
-    return op.emitOpError("too few indices");
+LogicalResult PrefetchOp::verify() {
  if (getNumOperands() != 1 + getMemRefType().getRank())
+    return emitOpError("too few indices");
 
   return success();
 }
@@ -1306,26 +1304,25 @@ void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
 
 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
 // completed automatically, like we have for subview and extract_slice.
-static LogicalResult verify(ReinterpretCastOp op) {
+LogicalResult ReinterpretCastOp::verify() {
   // The source and result memrefs should be in the same memory space.
-  auto srcType = op.source().getType().cast<BaseMemRefType>();
-  auto resultType = op.getType().cast<MemRefType>();
+  auto srcType = source().getType().cast<BaseMemRefType>();
+  auto resultType = getType().cast<MemRefType>();
   if (srcType.getMemorySpace() != resultType.getMemorySpace())
-    return op.emitError("different memory spaces specified for source type ")
+    return emitError("different memory spaces specified for source type ")
           << srcType << " and result memref type " << resultType;
   if (srcType.getElementType() != resultType.getElementType())
-    return op.emitError("different element types specified for source type ")
+    return emitError("different element types specified for source type ")
           << srcType << " and result memref type " << resultType;
 
   // Match sizes in result memref type and in static_sizes attribute.
-  for (auto &en :
-       llvm::enumerate(llvm::zip(resultType.getShape(),
-                                 extractFromI64ArrayAttr(op.static_sizes())))) {
+  for (auto &en : llvm::enumerate(llvm::zip(
+           resultType.getShape(), extractFromI64ArrayAttr(static_sizes())))) {
    int64_t resultSize = std::get<0>(en.value());
    int64_t expectedSize = std::get<1>(en.value());
    if (!ShapedType::isDynamic(resultSize) &&
        !ShapedType::isDynamic(expectedSize) && resultSize != expectedSize)
-      return op.emitError("expected result type with size = ")
+      return emitError("expected result type with size = ")
            << expectedSize << " instead of " << resultSize
            << " in dim = " << en.index();
  }
 
@@ -1336,27 +1333,26 @@ static LogicalResult verify(ReinterpretCastOp op) {
   int64_t resultOffset;
   SmallVector<int64_t, 4> resultStrides;
   if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
-    return op.emitError(
-        "expected result type to have strided layout but found ")
+    return emitError("expected result type to have strided layout but found ")
           << resultType;
 
   // Match offset in result memref type and in static_offsets attribute.
-  int64_t expectedOffset = extractFromI64ArrayAttr(op.static_offsets()).front();
+  int64_t expectedOffset = extractFromI64ArrayAttr(static_offsets()).front();
   if (!ShapedType::isDynamicStrideOrOffset(resultOffset) &&
       !ShapedType::isDynamicStrideOrOffset(expectedOffset) &&
       resultOffset != expectedOffset)
-    return op.emitError("expected result type with offset = ")
+    return emitError("expected result type with offset = ")
          << resultOffset << " instead of " << expectedOffset;
 
   // Match strides in result memref type and in static_strides attribute.
   for (auto &en : llvm::enumerate(llvm::zip(
-          resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
+          resultStrides, extractFromI64ArrayAttr(static_strides())))) {
    int64_t resultStride = std::get<0>(en.value());
    int64_t expectedStride = std::get<1>(en.value());
    if (!ShapedType::isDynamicStrideOrOffset(resultStride) &&
        !ShapedType::isDynamicStrideOrOffset(expectedStride) &&
        resultStride != expectedStride)
-      return op.emitError("expected result type with stride = ")
+      return emitError("expected result type with stride = ")
            << expectedStride << " instead of " << resultStride
            << " in dim = " << en.index();
  }
 
@@ -1532,8 +1528,8 @@ static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
   return success();
 }
 
-static LogicalResult verify(ExpandShapeOp op) {
-  return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
+LogicalResult ExpandShapeOp::verify() {
+  return verifyReshapeOp(*this, getResultType(), getSrcType());
 }
 
 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
@@ -1542,8 +1538,8 @@ void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
               CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
 }
 
-static LogicalResult verify(CollapseShapeOp op) {
-  return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
+LogicalResult CollapseShapeOp::verify() {
+  return verifyReshapeOp(*this, getSrcType(), getResultType());
 }
 
 struct CollapseShapeOpMemRefCastFolder
@@ -1593,32 +1589,30 @@ OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
 // ReshapeOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(ReshapeOp op) {
-  Type operandType = op.source().getType();
-  Type resultType = op.result().getType();
+LogicalResult ReshapeOp::verify() {
+  Type operandType = source().getType();
+  Type resultType = result().getType();
 
   Type operandElementType = operandType.cast<ShapedType>().getElementType();
   Type resultElementType = resultType.cast<ShapedType>().getElementType();
   if (operandElementType != resultElementType)
-    return op.emitOpError("element types of source and destination memref "
-                          "types should be the same");
+    return emitOpError("element types of source and destination memref "
+                       "types should be the same");
 
   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
    if (!operandMemRefType.getLayout().isIdentity())
-      return op.emitOpError(
-          "source memref type should have identity affine map");
+      return emitOpError("source memref type should have identity affine map");
 
-  int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
+  int64_t shapeSize = shape().getType().cast<MemRefType>().getDimSize(0);
   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
   if (resultMemRefType) {
    if (!resultMemRefType.getLayout().isIdentity())
-      return op.emitOpError(
-          "result memref type should have identity affine map");
+      return emitOpError("result memref type should have identity affine map");
    if (shapeSize == ShapedType::kDynamicSize)
-      return op.emitOpError("cannot use shape operand with dynamic length to "
-                            "reshape to statically-ranked memref type");
+      return emitOpError("cannot use shape operand with dynamic length to "
+                         "reshape to statically-ranked memref type");
    if (shapeSize != resultMemRefType.getRank())
-      return op.emitOpError(
+      return emitOpError(
          "length of shape operand differs from the result's memref rank");
  }
  return success();
@@ -1628,9 +1622,9 @@ static LogicalResult verify(ReshapeOp op) {
 // StoreOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(StoreOp op) {
-  if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
-    return op.emitOpError("store index operand count not equal to memref rank");
+LogicalResult StoreOp::verify() {
+  if (getNumOperands() != 2 + getMemRefType().getRank())
+    return emitOpError("store index operand count not equal to memref rank");
 
   return success();
 }
@@ -1951,29 +1945,29 @@ static LogicalResult produceSubViewErrorMsg(SliceVerificationResult result,
 }
 
 /// Verifier for SubViewOp.
-static LogicalResult verify(SubViewOp op) {
-  MemRefType baseType = op.getSourceType();
-  MemRefType subViewType = op.getType();
+LogicalResult SubViewOp::verify() {
+  MemRefType baseType = getSourceType();
+  MemRefType subViewType = getType();
 
   // The base memref and the view memref should be in the same memory space.
   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
-    return op.emitError("different memory spaces specified for base memref "
-                        "type ")
+    return emitError("different memory spaces specified for base memref "
+                     "type ")
          << baseType << " and subview memref type " << subViewType;
 
   // Verify that the base memref type has a strided layout map.
   if (!isStrided(baseType))
-    return op.emitError("base type ") << baseType << " is not strided";
+    return emitError("base type ") << baseType << " is not strided";
 
   // Verify result type against inferred type.
   auto expectedType = SubViewOp::inferResultType(
-      baseType, extractFromI64ArrayAttr(op.static_offsets()),
-      extractFromI64ArrayAttr(op.static_sizes()),
-      extractFromI64ArrayAttr(op.static_strides()));
+      baseType, extractFromI64ArrayAttr(static_offsets()),
+      extractFromI64ArrayAttr(static_sizes()),
+      extractFromI64ArrayAttr(static_strides()));
 
   auto result = isRankReducedMemRefType(expectedType.cast<MemRefType>(),
-                                        subViewType, op.getMixedSizes());
-  return produceSubViewErrorMsg(result, op, expectedType);
+                                        subViewType, getMixedSizes());
+  return produceSubViewErrorMsg(result, *this, expectedType);
 }
 
 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
@@ -2278,18 +2272,17 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
   return success();
 }
 
-static LogicalResult verify(TransposeOp op) {
-  if (!op.permutation().isPermutation())
-    return op.emitOpError("expected a permutation map");
-  if (op.permutation().getNumDims() != op.getShapedType().getRank())
-    return op.emitOpError(
-        "expected a permutation map of same rank as the input");
+LogicalResult TransposeOp::verify() {
+  if (!permutation().isPermutation())
+    return emitOpError("expected a permutation map");
+  if (permutation().getNumDims() != getShapedType().getRank())
+    return emitOpError("expected a permutation map of same rank as the input");
 
-  auto srcType = op.in().getType().cast<MemRefType>();
-  auto dstType = op.getType().cast<MemRefType>();
-  auto transposedType = inferTransposeResultType(srcType, op.permutation());
+  auto srcType = in().getType().cast<MemRefType>();
+  auto dstType = getType().cast<MemRefType>();
+  auto transposedType = inferTransposeResultType(srcType, permutation());
   if (dstType != transposedType)
-    return op.emitOpError("output type ")
+    return emitOpError("output type ")
          << dstType << " does not match transposed input type " << srcType
          << ", " << transposedType;
   return success();
@@ -2338,29 +2331,28 @@ static void print(OpAsmPrinter &p, ViewOp op) {
   p << " : " << op.getOperand(0).getType() << " to " << op.getType();
 }
 
-static LogicalResult verify(ViewOp op) {
-  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
-  auto viewType = op.getType();
+LogicalResult ViewOp::verify() {
+  auto baseType = getOperand(0).getType().cast<MemRefType>();
+  auto viewType = getType();
 
   // The base memref should have identity layout map (or none).
   if (!baseType.getLayout().isIdentity())
-    return op.emitError("unsupported map for base memref type ") << baseType;
+    return emitError("unsupported map for base memref type ") << baseType;
 
   // The result memref should have identity layout map (or none).
   if (!viewType.getLayout().isIdentity())
-    return op.emitError("unsupported map for result memref type ") << viewType;
+    return emitError("unsupported map for result memref type ") << viewType;
 
   // The base memref and the view memref should be in the same memory space.
   if (baseType.getMemorySpace() != viewType.getMemorySpace())
-    return op.emitError("different memory spaces specified for base memref "
-                        "type ")
+    return emitError("different memory spaces specified for base memref "
+                     "type ")
          << baseType << " and view memref type " << viewType;
 
   // Verify that we have the correct number of sizes for the result type.
   unsigned numDynamicDims = viewType.getNumDynamicDims();
-  if (op.sizes().size() != numDynamicDims)
-    return op.emitError("incorrect number of size operands for type ")
-           << viewType;
+  if (sizes().size() != numDynamicDims)
+    return emitError("incorrect number of size operands for type ") << viewType;
 
   return success();
 }
@@ -2467,19 +2459,19 @@ void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // AtomicRMWOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(AtomicRMWOp op) {
-  if (op.getMemRefType().getRank() != op.getNumOperands() - 2)
-    return op.emitOpError(
+LogicalResult AtomicRMWOp::verify() {
+  if (getMemRefType().getRank() != getNumOperands() - 2)
+    return emitOpError(
        "expects the number of subscripts to be equal to memref rank");
-  switch (op.kind()) {
+  switch (kind()) {
   case arith::AtomicRMWKind::addf:
   case arith::AtomicRMWKind::maxf:
   case arith::AtomicRMWKind::minf:
   case arith::AtomicRMWKind::mulf:
-    if (!op.value().getType().isa<FloatType>())
-      return op.emitOpError()
-             << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
-             << "' expects a floating-point type";
+    if (!value().getType().isa<FloatType>())
+      return emitOpError() << "with kind '"
+                           << arith::stringifyAtomicRMWKind(kind())
+                           << "' expects a floating-point type";
    break;
   case arith::AtomicRMWKind::addi:
   case arith::AtomicRMWKind::maxs:
@@ -2489,10 +2481,10 @@ static LogicalResult verify(AtomicRMWOp op) {
   case arith::AtomicRMWKind::muli:
   case arith::AtomicRMWKind::ori:
   case arith::AtomicRMWKind::andi:
-    if (!op.value().getType().isa<IntegerType>())
-      return op.emitOpError()
-             << "with kind '" << arith::stringifyAtomicRMWKind(op.kind())
-             << "' expects an integer type";
+    if (!value().getType().isa<IntegerType>())
+      return emitOpError() << "with kind '"
+                           << arith::stringifyAtomicRMWKind(kind())
+                           << "' expects an integer type";
    break;
   default:
    break;
@ -209,53 +209,51 @@ static LogicalResult isMatchingWidth(Value result, unsigned width) {
|
|||
return failure();
|
||||
}
|
||||
|
||||
static LogicalResult verify(NewOp op) {
|
||||
if (!getSparseTensorEncoding(op.result().getType()))
|
||||
return op.emitError("expected a sparse tensor result");
|
||||
LogicalResult NewOp::verify() {
|
||||
if (!getSparseTensorEncoding(result().getType()))
|
||||
return emitError("expected a sparse tensor result");
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(InitOp op) {
|
||||
if (!getSparseTensorEncoding(op.result().getType()))
|
||||
return op.emitError("expected a sparse tensor result");
|
||||
RankedTensorType ttp = op.getType().cast<RankedTensorType>();
|
||||
LogicalResult InitOp::verify() {
|
||||
if (!getSparseTensorEncoding(result().getType()))
|
||||
return emitError("expected a sparse tensor result");
|
||||
RankedTensorType ttp = getType().cast<RankedTensorType>();
|
||||
unsigned rank = ttp.getRank();
|
||||
if (rank != op.sizes().size())
|
||||
return op.emitError("unexpected mismatch between tensor rank and sizes: ")
|
||||
<< rank << " vs. " << op.sizes().size();
|
||||
if (rank != sizes().size())
|
||||
return emitError("unexpected mismatch between tensor rank and sizes: ")
|
||||
<< rank << " vs. " << sizes().size();
|
||||
auto shape = ttp.getShape();
|
||||
for (unsigned i = 0; i < rank; i++) {
|
||||
if (shape[i] == ShapedType::kDynamicSize)
|
||||
continue;
|
||||
IntegerAttr constantAttr;
|
||||
if (!matchPattern(op.sizes()[i], m_Constant(&constantAttr)) ||
|
||||
if (!matchPattern(sizes()[i], m_Constant(&constantAttr)) ||
|
||||
constantAttr.getInt() != shape[i]) {
|
||||
return op.emitError("unexpected mismatch with static dimension size ")
|
||||
return emitError("unexpected mismatch with static dimension size ")
|
||||
<< shape[i];
|
||||
}
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
static LogicalResult verify(ConvertOp op) {
|
||||
if (auto tp1 = op.source().getType().dyn_cast<RankedTensorType>()) {
|
||||
if (auto tp2 = op.dest().getType().dyn_cast<RankedTensorType>()) {
|
||||
LogicalResult ConvertOp::verify() {
|
||||
if (auto tp1 = source().getType().dyn_cast<RankedTensorType>()) {
|
||||
if (auto tp2 = dest().getType().dyn_cast<RankedTensorType>()) {
|
||||
if (tp1.getRank() != tp2.getRank())
|
||||
return op.emitError("unexpected conversion mismatch in rank");
|
||||
return emitError("unexpected conversion mismatch in rank");
|
||||
auto shape1 = tp1.getShape();
|
||||
auto shape2 = tp2.getShape();
|
||||
// Accept size matches between the source and the destination type
|
||||
// (e.g. 10 vs. 10, 10 vs. ?, or ? vs. ?), but reject direct mismatches or
|
||||
// matches that would need a runtime assert (e.g. 10 vs. 20 or ? vs. 10).
|
||||
for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++) {
|
||||
for (unsigned d = 0, rank = tp1.getRank(); d < rank; d++)
|
||||
if (shape1[d] != shape2[d] && shape2[d] != ShapedType::kDynamicSize)
|
||||
return op.emitError("unexpected conversion mismatch in dimension ")
|
||||
<< d;
|
||||
}
|
||||
return emitError("unexpected conversion mismatch in dimension ") << d;
|
||||
return success();
|
||||
}
|
||||
}
|
||||
return op.emitError("unexpected type in convert");
|
||||
return emitError("unexpected type in convert");
|
||||
}
|
||||
|
||||
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
|
||||
|
@ -264,35 +262,35 @@ OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
|
|||
return {};
|
||||
}
|
||||
|
||||
static LogicalResult verify(ToPointersOp op) {
|
||||
if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
|
||||
if (failed(isInBounds(op.dim(), op.tensor())))
|
||||
return op.emitError("requested pointers dimension out of bounds");
|
||||
if (failed(isMatchingWidth(op.result(), e.getPointerBitWidth())))
|
||||
return op.emitError("unexpected type for pointers");
|
||||
LogicalResult ToPointersOp::verify() {
|
||||
if (auto e = getSparseTensorEncoding(tensor().getType())) {
|
||||
if (failed(isInBounds(dim(), tensor())))
|
||||
return emitError("requested pointers dimension out of bounds");
|
||||
if (failed(isMatchingWidth(result(), e.getPointerBitWidth())))
|
||||
return emitError("unexpected type for pointers");
|
||||
return success();
|
||||
}
|
||||
return op.emitError("expected a sparse tensor to get pointers");
|
||||
return emitError("expected a sparse tensor to get pointers");
|
||||
}
|
||||
|
||||
static LogicalResult verify(ToIndicesOp op) {
|
||||
if (auto e = getSparseTensorEncoding(op.tensor().getType())) {
|
||||
if (failed(isInBounds(op.dim(), op.tensor())))
|
||||
return op.emitError("requested indices dimension out of bounds");
|
||||
if (failed(isMatchingWidth(op.result(), e.getIndexBitWidth())))
|
||||
return op.emitError("unexpected type for indices");
|
||||
LogicalResult ToIndicesOp::verify() {
|
||||
if (auto e = getSparseTensorEncoding(tensor().getType())) {
|
||||
if (failed(isInBounds(dim(), tensor())))
|
||||
return emitError("requested indices dimension out of bounds");
|
||||
if (failed(isMatchingWidth(result(), e.getIndexBitWidth())))
|
||||
return emitError("unexpected type for indices");
|
||||
return success();
|
||||
}
|
||||
return op.emitError("expected a sparse tensor to get indices");
|
||||
return emitError("expected a sparse tensor to get indices");
|
||||
}
|
||||
|
||||
static LogicalResult verify(ToValuesOp op) {
|
||||
if (!getSparseTensorEncoding(op.tensor().getType()))
|
||||
return op.emitError("expected a sparse tensor to get values");
|
||||
RankedTensorType ttp = op.tensor().getType().cast<RankedTensorType>();
|
||||
MemRefType mtp = op.result().getType().cast<MemRefType>();
|
||||
LogicalResult ToValuesOp::verify() {
|
||||
if (!getSparseTensorEncoding(tensor().getType()))
|
||||
return emitError("expected a sparse tensor to get values");
|
||||
RankedTensorType ttp = tensor().getType().cast<RankedTensorType>();
|
||||
MemRefType mtp = result().getType().cast<MemRefType>();
|
||||
if (ttp.getElementType() != mtp.getElementType())
|
||||
return op.emitError("unexpected mismatch in element types");
|
||||
return emitError("unexpected mismatch in element types");
|
||||
return success();
|
||||
}
|
||||
|
||||
|
@ -300,39 +298,39 @@ static LogicalResult verify(ToValuesOp op) {
 // TensorDialect Management Operations.
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(LexInsertOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor for insertion");
+LogicalResult LexInsertOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor for insertion");
   return success();
 }
 
-static LogicalResult verify(ExpandOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor for expansion");
+LogicalResult ExpandOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor for expansion");
   return success();
 }
 
-static LogicalResult verify(CompressOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor for compression");
+LogicalResult CompressOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor for compression");
   return success();
 }
 
-static LogicalResult verify(LoadOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor to materialize");
+LogicalResult LoadOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor to materialize");
   return success();
 }
 
-static LogicalResult verify(ReleaseOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor to release");
+LogicalResult ReleaseOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor to release");
   return success();
 }
 
-static LogicalResult verify(OutOp op) {
-  if (!getSparseTensorEncoding(op.tensor().getType()))
-    return op.emitError("expected a sparse tensor for output");
+LogicalResult OutOp::verify() {
+  if (!getSparseTensorEncoding(tensor().getType()))
+    return emitError("expected a sparse tensor for output");
   return success();
 }
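
Every hunk above applies the same mechanical rewrite: a free function `static LogicalResult verify(XOp op)`, previously dispatched through the deprecated `verifier` field, becomes the member `XOp::verify()` that ODS declares when the op definition sets `let hasVerifier = 1;`, and each `op.` accessor call now resolves against the implicit `this`. A minimal sketch of the before/after pattern, using a hypothetical `FooOp` whose name, accessor, and diagnostic text are placeholders, not part of this commit:

```
// Before: free function registered via the deprecated `verifier` field.
static LogicalResult verify(FooOp op) {
  if (!getSparseTensorEncoding(op.tensor().getType()))
    return op.emitError("expected a sparse tensor");
  return success();
}

// After: member function; ODS declares `LogicalResult verify();` on FooOp
// when the TableGen definition sets `let hasVerifier = 1;`.
LogicalResult FooOp::verify() {
  if (!getSparseTensorEncoding(tensor().getType()))
    return emitError("expected a sparse tensor");
  return success();
}
```
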
@ -228,17 +228,17 @@ Optional<int64_t> DimOp::getConstantIndex() {
   return {};
 }
 
-static LogicalResult verify(DimOp op) {
+LogicalResult DimOp::verify() {
   // Assume unknown index to be in range.
-  Optional<int64_t> index = op.getConstantIndex();
+  Optional<int64_t> index = getConstantIndex();
   if (!index.hasValue())
     return success();
 
   // Check that constant index is not knowingly out of range.
-  auto type = op.source().getType();
+  auto type = source().getType();
   if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
     if (index.getValue() >= tensorType.getRank())
-      return op.emitOpError("index is out of range");
+      return emitOpError("index is out of range");
   } else if (type.isa<UnrankedTensorType>()) {
     // Assume index to be in range.
   } else {
@ -328,11 +328,11 @@ void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // ExtractOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(ExtractOp op) {
+LogicalResult ExtractOp::verify() {
   // Verify the # indices match if we have a ranked type.
-  if (auto tensorType = op.tensor().getType().dyn_cast<RankedTensorType>())
-    if (tensorType.getRank() != static_cast<int64_t>(op.indices().size()))
-      return op.emitOpError("incorrect number of indices for extract_element");
+  if (auto tensorType = tensor().getType().dyn_cast<RankedTensorType>())
+    if (tensorType.getRank() != static_cast<int64_t>(indices().size()))
+      return emitOpError("incorrect number of indices for extract_element");
 
   return success();
 }
@ -480,11 +480,11 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results,
 // InsertOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(InsertOp op) {
+LogicalResult InsertOp::verify() {
   // Verify the # indices match if we have a ranked type.
-  if (auto destType = op.dest().getType().dyn_cast<RankedTensorType>())
-    if (destType.getRank() != static_cast<int64_t>(op.indices().size()))
-      return op.emitOpError("incorrect number of indices");
+  if (auto destType = dest().getType().dyn_cast<RankedTensorType>())
+    if (destType.getRank() != static_cast<int64_t>(indices().size()))
+      return emitOpError("incorrect number of indices");
   return success();
 }
@ -502,27 +502,26 @@ OpFoldResult InsertOp::fold(ArrayRef<Attribute> operands) {
 // GenerateOp
 //===----------------------------------------------------------------------===//
 
-static LogicalResult verify(GenerateOp op) {
+LogicalResult GenerateOp::verify() {
   // Ensure that the tensor type has as many dynamic dimensions as are specified
   // by the operands.
-  RankedTensorType resultTy = op.getType().cast<RankedTensorType>();
-  if (op.getNumOperands() != resultTy.getNumDynamicDims())
-    return op.emitError("must have as many index operands as dynamic extents "
-                        "in the result type");
+  RankedTensorType resultTy = getType().cast<RankedTensorType>();
+  if (getNumOperands() != resultTy.getNumDynamicDims())
+    return emitError("must have as many index operands as dynamic extents "
+                     "in the result type");
 
   // Ensure that region arguments span the index space.
-  if (!llvm::all_of(op.body().getArgumentTypes(),
+  if (!llvm::all_of(body().getArgumentTypes(),
                     [](Type ty) { return ty.isIndex(); }))
-    return op.emitError("all body arguments must be index");
-  if (op.body().getNumArguments() != resultTy.getRank())
-    return op.emitError("must have one body argument per input dimension");
+    return emitError("all body arguments must be index");
+  if (body().getNumArguments() != resultTy.getRank())
+    return emitError("must have one body argument per input dimension");
 
   // Ensure that the region yields an element of the right type.
-  auto yieldOp =
-      llvm::cast<YieldOp>(op.body().getBlocks().front().getTerminator());
+  auto yieldOp = cast<YieldOp>(body().getBlocks().front().getTerminator());
 
   if (yieldOp.value().getType() != resultTy.getElementType())
-    return op.emitOpError(
+    return emitOpError(
         "body must be terminated with a `yield` operation of the tensor "
         "element type");
@ -686,16 +685,15 @@ static int64_t getNumElements(ShapedType type) {
   return numElements;
 }
 
-static LogicalResult verify(ReshapeOp op) {
-  TensorType operandType = op.source().getType().cast<TensorType>();
-  TensorType resultType = op.result().getType().cast<TensorType>();
+LogicalResult ReshapeOp::verify() {
+  TensorType operandType = source().getType().cast<TensorType>();
+  TensorType resultType = result().getType().cast<TensorType>();
 
   if (operandType.getElementType() != resultType.getElementType())
-    return op.emitOpError("element types of source and destination tensor "
-                          "types should be the same");
+    return emitOpError("element types of source and destination tensor "
+                       "types should be the same");
 
-  int64_t shapeSize =
-      op.shape().getType().cast<RankedTensorType>().getDimSize(0);
+  int64_t shapeSize = shape().getType().cast<RankedTensorType>().getDimSize(0);
   auto resultRankedType = resultType.dyn_cast<RankedTensorType>();
   auto operandRankedType = operandType.dyn_cast<RankedTensorType>();
@ -703,14 +701,14 @@ static LogicalResult verify(ReshapeOp op) {
     if (operandRankedType && resultRankedType.hasStaticShape() &&
         operandRankedType.hasStaticShape()) {
       if (getNumElements(operandRankedType) != getNumElements(resultRankedType))
-        return op.emitOpError("source and destination tensor should have the "
-                              "same number of elements");
+        return emitOpError("source and destination tensor should have the "
+                           "same number of elements");
     }
     if (ShapedType::isDynamic(shapeSize))
-      return op.emitOpError("cannot use shape operand with dynamic length to "
-                            "reshape to statically-ranked tensor type");
+      return emitOpError("cannot use shape operand with dynamic length to "
+                         "reshape to statically-ranked tensor type");
     if (shapeSize != resultRankedType.getRank())
-      return op.emitOpError(
+      return emitOpError(
           "length of shape operand differs from the result's tensor rank");
   }
   return success();
@ -814,12 +812,12 @@ static LogicalResult verifyTensorReshapeOp(TensorReshapeOp op,
   return success();
 }
 
-static LogicalResult verify(ExpandShapeOp op) {
-  return verifyTensorReshapeOp(op, op.getResultType(), op.getSrcType());
+LogicalResult ExpandShapeOp::verify() {
+  return verifyTensorReshapeOp(*this, getResultType(), getSrcType());
 }
 
-static LogicalResult verify(CollapseShapeOp op) {
-  return verifyTensorReshapeOp(op, op.getSrcType(), op.getResultType());
+LogicalResult CollapseShapeOp::verify() {
+  return verifyTensorReshapeOp(*this, getSrcType(), getResultType());
 }
 
 namespace {
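
The ExpandShapeOp/CollapseShapeOp hunk above also shows what happens when a verifier delegates to a shared helper: because MLIR op classes are cheap value wrappers around `Operation *`, the member form simply passes `*this` where the free function used to pass its `op` parameter. A hedged sketch of the same shape (the `FooOp` name and helper signature are illustrative only):

```
// Shared helper, unchanged by the migration: it still takes the op by value.
static LogicalResult verifySharedInvariants(FooOp op);

// The member verifier forwards itself; `*this` yields the same value
// wrapper the free function used to receive as `op`.
LogicalResult FooOp::verify() {
  return verifySharedInvariants(*this);
}
```
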
@ -1052,14 +1050,12 @@ static LogicalResult produceSliceErrorMsg(SliceVerificationResult result,
 }
 
 /// Verifier for ExtractSliceOp.
-static LogicalResult verify(ExtractSliceOp op) {
+LogicalResult ExtractSliceOp::verify() {
   // Verify result type against inferred type.
-  auto expectedType =
-      ExtractSliceOp::inferResultType(op.getSourceType(), op.getMixedOffsets(),
-                                      op.getMixedSizes(), op.getMixedStrides());
-  auto result =
-      isRankReducedType(expectedType.cast<ShapedType>(), op.getType());
-  return produceSliceErrorMsg(result, op, expectedType);
+  auto expectedType = ExtractSliceOp::inferResultType(
+      getSourceType(), getMixedOffsets(), getMixedSizes(), getMixedStrides());
+  auto result = isRankReducedType(expectedType.cast<ShapedType>(), getType());
+  return produceSliceErrorMsg(result, *this, expectedType);
 }
 
 /// Infer the canonical type of the result of an extract_slice op. Returns a
@ -1308,16 +1304,16 @@ void InsertSliceOp::build(OpBuilder &b, OperationState &result, Value source,
 }
 
 /// Verifier for InsertSliceOp.
-static LogicalResult verify(InsertSliceOp op) {
+LogicalResult InsertSliceOp::verify() {
   // insert_slice is the inverse of extract_slice, use the same type inference.
   auto expectedType = ExtractSliceOp::inferRankReducedResultType(
-      op.getSourceType().getRank(), op.getType(),
-      extractFromI64ArrayAttr(op.static_offsets()),
-      extractFromI64ArrayAttr(op.static_sizes()),
-      extractFromI64ArrayAttr(op.static_strides()));
+      getSourceType().getRank(), getType(),
+      extractFromI64ArrayAttr(static_offsets()),
+      extractFromI64ArrayAttr(static_sizes()),
+      extractFromI64ArrayAttr(static_strides()));
   auto result =
-      isRankReducedType(expectedType.cast<ShapedType>(), op.getSourceType());
-  return produceSliceErrorMsg(result, op, expectedType);
+      isRankReducedType(expectedType.cast<ShapedType>(), getSourceType());
+  return produceSliceErrorMsg(result, *this, expectedType);
 }
 
 /// If we have two consecutive InsertSliceOp writing to the same slice, we
@ -1569,40 +1565,40 @@ ParseResult parseInferType(OpAsmParser &parser,
   return success();
 }
 
-static LogicalResult verify(PadOp op) {
-  auto sourceType = op.source().getType().cast<RankedTensorType>();
-  auto resultType = op.result().getType().cast<RankedTensorType>();
-  auto expectedType = PadOp::inferResultType(
-      sourceType, extractFromI64ArrayAttr(op.static_low()),
-      extractFromI64ArrayAttr(op.static_high()));
+LogicalResult PadOp::verify() {
+  auto sourceType = source().getType().cast<RankedTensorType>();
+  auto resultType = result().getType().cast<RankedTensorType>();
+  auto expectedType =
+      PadOp::inferResultType(sourceType, extractFromI64ArrayAttr(static_low()),
+                             extractFromI64ArrayAttr(static_high()));
   for (int i = 0, e = sourceType.getRank(); i < e; ++i) {
     if (resultType.getDimSize(i) == expectedType.getDimSize(i))
       continue;
     if (expectedType.isDynamicDim(i))
       continue;
-    return op.emitError("specified type ")
+    return emitError("specified type ")
           << resultType << " does not match the inferred type "
           << expectedType;
   }
 
-  auto &region = op.region();
+  auto &region = getRegion();
   unsigned rank = resultType.getRank();
   Block &block = region.front();
   if (block.getNumArguments() != rank)
-    return op.emitError("expected the block to have ") << rank << " arguments";
+    return emitError("expected the block to have ") << rank << " arguments";
 
   // Note: the number and type of yield values are checked in the YieldOp.
   for (const auto &en : llvm::enumerate(block.getArgumentTypes())) {
     if (!en.value().isIndex())
-      return op.emitOpError("expected block argument ")
+      return emitOpError("expected block argument ")
             << (en.index() + 1) << " to be an index";
   }
 
   // Ensure that the region yields an element of the right type.
   auto yieldOp = llvm::cast<YieldOp>(block.getTerminator());
   if (yieldOp.value().getType() !=
-      op.getType().cast<ShapedType>().getElementType())
-    return op.emitOpError("expected yield type to match shape element type");
+      getType().cast<ShapedType>().getElementType())
+    return emitOpError("expected yield type to match shape element type");
 
   return success();
 }
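
One last detail visible in the PadOp hunk: `emitError` and `emitOpError` return an `InFlightDiagnostic`, so the member verifiers can keep streaming extra context with `<<` and still convert implicitly to `LogicalResult`. A small hedged sketch (hypothetical `FooOp`, placeholder message):

```
// The diagnostic is built up with operator<<; returning it as a
// LogicalResult implicitly reports failure.
LogicalResult FooOp::verify() {
  if (getNumOperands() != 2)
    return emitOpError("expected 2 operands, got ") << getNumOperands();
  return success();
}
```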