[mlir] Support verification order (2/3)

This change gives explicit order of verifier execution and adds
    `hasRegionVerifier` and `verifyWithRegions` to increase the granularity
    of verifier classification. The orders are as below,

    1. InternalOpTrait will be verified first, they can be run independently.
    2. `verifyInvariants` which is constructed by ODS, it verifies the type,
       attributes, .etc.
    3. Other Traits/Interfaces that have marked their verifier as
       `verifyTrait` or `verifyWithRegions=0`.
    4. Custom verifier which is defined in the op and has marked
       `hasVerifier=1`

    If an operation has regions, then it may have the second phase,

    5. Traits/Interfaces that have marked their verifier as
       `verifyRegionTrait` or
       `verifyWithRegions=1`. This implies the verifier needs to access the
       operations in its regions.
    6. Custom verifier which is defined in the op and has marked
       `hasRegionVerifier=1`

    Note that the second phase will be run after the operations in the
    region are verified. Based on the verification order, you will be able to
    avoid verifying duplicate things.

Reviewed By: Mogball

Differential Revision: https://reviews.llvm.org/D116789
This commit is contained in:
Chia-hung Duan 2022-02-25 18:17:30 +00:00
parent d04d9220e1
commit 9445b39673
28 changed files with 333 additions and 115 deletions

View File

@ -567,10 +567,39 @@ _additional_ verification, you can use
let hasVerifier = 1;
```
This will generate a `LogicalResult verify()` method declaration on the op class
that can be defined with any additional verification constraints. This method
will be invoked after the auto-generated verification code. The order of trait
verification excluding those of `hasVerifier` should not be relied upon.
or
```tablegen
let hasRegionVerifier = 1;
```
This will generate either `LogicalResult verify()` or
`LogicalResult verifyRegions()` method declaration on the op class
that can be defined with any additional verification constraints. These method
will be invoked on its verification order.
#### Verification Ordering
The verification of an operation involves several steps,
1. StructuralOpTrait will be verified first, they can be run independently.
1. `verifyInvariants` which is constructed by ODS, it verifies the type,
attributes, .etc.
1. Other Traits/Interfaces that have marked their verifier as `verifyTrait` or
`verifyWithRegions=0`.
1. Custom verifier which is defined in the op and has marked `hasVerifier=1`
If an operation has regions, then it may have the second phase,
1. Traits/Interfaces that have marked their verifier as `verifyRegionTrait` or
`verifyWithRegions=1`. This implies the verifier needs to access the
operations in its regions.
1. Custom verifier which is defined in the op and has marked
`hasRegionVerifier=1`
Note that the second phase will be run after the operations in the region are
verified. Verifiers further down the order can rely on certain invariants being
verified by a previous verifier and do not need to re-verify them.
### Declarative Assembly Format

View File

@ -36,9 +36,12 @@ class MyTrait : public TraitBase<ConcreteType, MyTrait> {
};
```
Operation traits may also provide a `verifyTrait` hook, that is called when
verifying the concrete operation. The trait verifiers will currently always be
invoked before the main `Op::verify`.
Operation traits may also provide a `verifyTrait` or `verifyRegionTrait` hook
that is called when verifying the concrete operation. The difference between
these two is that whether the verifier needs to access the regions, if so, the
operations in the regions will be verified before the verification of this
trait. The [verification order](OpDefinitions.md/#verification-ordering)
determines when a verifier will be invoked.
```c++
template <typename ConcreteType>
@ -53,8 +56,9 @@ public:
```
Note: It is generally good practice to define the implementation of the
`verifyTrait` hook out-of-line as a free function when possible to avoid
instantiating the implementation for every concrete operation type.
`verifyTrait` or `verifyRegionTrait` hook out-of-line as a free function when
possible to avoid instantiating the implementation for every concrete operation
type.
Operation traits may also provide a `foldTrait` hook that is called when folding
the concrete operation. The trait folders will only be invoked if the concrete

View File

@ -76,7 +76,7 @@ bool isTopLevelValue(Value value);
class AffineDmaStartOp
: public Op<AffineDmaStartOp, OpTrait::MemRefsNormalizable,
OpTrait::VariadicOperands, OpTrait::ZeroResult,
AffineMapAccessInterface::Trait> {
OpTrait::OpInvariants, AffineMapAccessInterface::Trait> {
public:
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
@ -227,7 +227,8 @@ public:
static StringRef getOperationName() { return "affine.dma_start"; }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult verifyInvariants();
LogicalResult verifyInvariantsImpl();
LogicalResult verifyInvariants() { return verifyInvariantsImpl(); }
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);
@ -268,7 +269,7 @@ public:
class AffineDmaWaitOp
: public Op<AffineDmaWaitOp, OpTrait::MemRefsNormalizable,
OpTrait::VariadicOperands, OpTrait::ZeroResult,
AffineMapAccessInterface::Trait> {
OpTrait::OpInvariants, AffineMapAccessInterface::Trait> {
public:
using Op::Op;
static ArrayRef<StringRef> getAttributeNames() { return {}; }
@ -315,7 +316,8 @@ public:
static StringRef getTagMapAttrName() { return "tag_map"; }
static ParseResult parse(OpAsmParser &parser, OperationState &result);
void print(OpAsmPrinter &p);
LogicalResult verifyInvariants();
LogicalResult verifyInvariantsImpl();
LogicalResult verifyInvariants() { return verifyInvariantsImpl(); }
LogicalResult fold(ArrayRef<Attribute> cstOperands,
SmallVectorImpl<OpFoldResult> &results);
};

View File

@ -2023,6 +2023,10 @@ class PredAttrTrait<string descr, Pred pred> : PredTrait<descr, pred>;
// OpTrait definitions
//===----------------------------------------------------------------------===//
// A trait that describes the structure of operation will be marked with
// `StructuralOpTrait` and they will be verified first.
class StructuralOpTrait;
// These classes are used to define operation specific traits.
class NativeOpTrait<string name, list<Trait> traits = []>
: NativeTrait<name, "Op"> {
@ -2053,7 +2057,8 @@ class PredOpTrait<string descr, Pred pred, list<Trait> traits = []>
// Op defines an affine scope.
def AffineScope : NativeOpTrait<"AffineScope">;
// Op defines an automatic allocation scope.
def AutomaticAllocationScope : NativeOpTrait<"AutomaticAllocationScope">;
def AutomaticAllocationScope :
NativeOpTrait<"AutomaticAllocationScope">;
// Op supports operand broadcast behavior.
def ResultsBroadcastableShape :
NativeOpTrait<"ResultsBroadcastableShape">;
@ -2074,9 +2079,11 @@ def SameTypeOperands : NativeOpTrait<"SameTypeOperands">;
// Op has same shape for all operands.
def SameOperandsShape : NativeOpTrait<"SameOperandsShape">;
// Op has same operand and result shape.
def SameOperandsAndResultShape : NativeOpTrait<"SameOperandsAndResultShape">;
def SameOperandsAndResultShape :
NativeOpTrait<"SameOperandsAndResultShape">;
// Op has the same element type (or type itself, if scalar) for all operands.
def SameOperandsElementType : NativeOpTrait<"SameOperandsElementType">;
def SameOperandsElementType :
NativeOpTrait<"SameOperandsElementType">;
// Op has the same operand and result element type (or type itself, if scalar).
def SameOperandsAndResultElementType :
NativeOpTrait<"SameOperandsAndResultElementType">;
@ -2104,21 +2111,23 @@ def ElementwiseMappable : TraitList<[
]>;
// Op's regions have a single block.
def SingleBlock : NativeOpTrait<"SingleBlock">;
def SingleBlock : NativeOpTrait<"SingleBlock">, StructuralOpTrait;
// Op's regions have a single block with the specified terminator.
class SingleBlockImplicitTerminator<string op>
: ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>;
: ParamNativeOpTrait<"SingleBlockImplicitTerminator", op>,
StructuralOpTrait;
// Op's regions don't have terminator.
def NoTerminator : NativeOpTrait<"NoTerminator">;
def NoTerminator : NativeOpTrait<"NoTerminator">, StructuralOpTrait;
// Op's parent operation is the provided one.
class HasParent<string op>
: ParamNativeOpTrait<"HasParent", op>;
: ParamNativeOpTrait<"HasParent", op>, StructuralOpTrait;
class ParentOneOf<list<string> ops>
: ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>;
: ParamNativeOpTrait<"HasParent", !interleave(ops, ", ")>,
StructuralOpTrait;
// Op result type is derived from the first attribute. If the attribute is an
// subclass of `TypeAttrBase`, its value is used, otherwise, the type of the
@ -2147,13 +2156,15 @@ def SameVariadicResultSize : GenInternalOpTrait<"SameVariadicResultSize">;
// vector that has the same number of elements as the number of ODS declared
// operands. That means even if some operands are non-variadic, the attribute
// still need to have an element for its size, which is always 1.
def AttrSizedOperandSegments : NativeOpTrait<"AttrSizedOperandSegments">;
def AttrSizedOperandSegments :
NativeOpTrait<"AttrSizedOperandSegments">, StructuralOpTrait;
// Similar to AttrSizedOperandSegments, but used for results. The attribute
// should be named as `result_segment_sizes`.
def AttrSizedResultSegments : NativeOpTrait<"AttrSizedResultSegments">;
def AttrSizedResultSegments :
NativeOpTrait<"AttrSizedResultSegments">, StructuralOpTrait;
// Op attached regions have no arguments
def NoRegionArguments : NativeOpTrait<"NoRegionArguments">;
def NoRegionArguments : NativeOpTrait<"NoRegionArguments">, StructuralOpTrait;
//===----------------------------------------------------------------------===//
// OpInterface definitions
@ -2191,6 +2202,11 @@ class OpInterfaceTrait<string name, code verifyBody = [{}],
// the operation being verified.
code verify = verifyBody;
// A bit indicating if the verifier needs to access the ops in the regions. If
// it set to `1`, the region ops will be verified before invoking this
// verifier.
bit verifyWithRegions = 0;
// Specify the list of traits that need to be verified before the verification
// of this OpInterfaceTrait.
list<Trait> dependentTraits = traits;
@ -2467,6 +2483,16 @@ class Op<Dialect dialect, string mnemonic, list<Trait> props = []> {
// operation class. The operation should implement this method and verify the
// additional necessary invariants.
bit hasVerifier = 0;
// A bit indicating if the operation has additional invariants that need to
// verified and which associate with regions (aside from those verified by the
// traits). If set to `1`, an additional `LogicalResult verifyRegions()`
// declaration will be generated on the operation class. The operation should
// implement this method and verify the additional necessary invariants
// associated with regions. Note that this method is invoked after all the
// region ops are verified.
bit hasRegionVerifier = 0;
// A custom code block corresponding to the extra verification code of the
// operation.
// NOTE: This field is deprecated in favor of `hasVerifier` and is slated for

View File

@ -200,7 +200,8 @@ public:
protected:
/// If the concrete type didn't implement a custom verifier hook, just fall
/// back to this one which accepts everything.
LogicalResult verifyInvariants() { return success(); }
LogicalResult verify() { return success(); }
LogicalResult verifyRegions() { return success(); }
/// Parse the custom form of an operation. Unless overridden, this method will
/// first try to get an operation parser from the op's dialect. Otherwise the
@ -376,6 +377,18 @@ struct MultiOperandTraitBase : public TraitBase<ConcreteType, TraitType> {
};
} // namespace detail
/// `verifyInvariantsImpl` verifies the invariants like the types, attrs, .etc.
/// It should be run after core traits and before any other user defined traits.
/// In order to run it in the correct order, wrap it with OpInvariants trait so
/// that tblgen will be able to put it in the right order.
template <typename ConcreteType>
class OpInvariants : public TraitBase<ConcreteType, OpInvariants> {
public:
static LogicalResult verifyTrait(Operation *op) {
return cast<ConcreteType>(op).verifyInvariantsImpl();
}
};
/// This class provides the API for ops that are known to have no
/// SSA operand.
template <typename ConcreteType>
@ -1572,6 +1585,14 @@ using has_verify_trait = decltype(T::verifyTrait(std::declval<Operation *>()));
template <typename T>
using detect_has_verify_trait = llvm::is_detected<has_verify_trait, T>;
/// Trait to check if T provides a `verifyTrait` method.
template <typename T, typename... Args>
using has_verify_region_trait =
decltype(T::verifyRegionTrait(std::declval<Operation *>()));
template <typename T>
using detect_has_verify_region_trait =
llvm::is_detected<has_verify_region_trait, T>;
/// The internal implementation of `verifyTraits` below that returns the result
/// of verifying the current operation with all of the provided trait types
/// `Ts`.
@ -1589,6 +1610,26 @@ template <typename TraitTupleT>
static LogicalResult verifyTraits(Operation *op) {
return verifyTraitsImpl(op, (TraitTupleT *)nullptr);
}
/// The internal implementation of `verifyRegionTraits` below that returns the
/// result of verifying the current operation with all of the provided trait
/// types `Ts`.
template <typename... Ts>
static LogicalResult verifyRegionTraitsImpl(Operation *op,
std::tuple<Ts...> *) {
LogicalResult result = success();
(void)std::initializer_list<int>{
(result = succeeded(result) ? Ts::verifyRegionTrait(op) : failure(),
0)...};
return result;
}
/// Given a tuple type containing a set of traits that contain a
/// `verifyTrait` method, return the result of verifying the given operation.
template <typename TraitTupleT>
static LogicalResult verifyRegionTraits(Operation *op) {
return verifyRegionTraitsImpl(op, (TraitTupleT *)nullptr);
}
} // namespace op_definition_impl
//===----------------------------------------------------------------------===//
@ -1603,7 +1644,8 @@ class Op : public OpState, public Traits<ConcreteType>... {
public:
/// Inherit getOperation from `OpState`.
using OpState::getOperation;
using OpState::verifyInvariants;
using OpState::verify;
using OpState::verifyRegions;
/// Return if this operation contains the provided trait.
template <template <typename T> class Trait>
@ -1704,6 +1746,10 @@ private:
using VerifiableTraitsTupleT =
typename detail::FilterTypes<op_definition_impl::detect_has_verify_trait,
Traits<ConcreteType>...>::type;
/// A tuple type containing the region traits that have a verify function.
using VerifiableRegionTraitsTupleT = typename detail::FilterTypes<
op_definition_impl::detect_has_verify_region_trait,
Traits<ConcreteType>...>::type;
/// Returns an interface map containing the interfaces registered to this
/// operation.
@ -1839,11 +1885,22 @@ private:
"Op class shouldn't define new data members");
return failure(
failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
failed(cast<ConcreteType>(op).verifyInvariants()));
failed(cast<ConcreteType>(op).verify()));
}
static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
}
/// Implementation of `VerifyRegionInvariantsFn` OperationName hook.
static LogicalResult verifyRegionInvariants(Operation *op) {
static_assert(hasNoDataMembers(),
"Op class shouldn't define new data members");
return failure(failed(op_definition_impl::verifyRegionTraits<
VerifiableRegionTraitsTupleT>(op)) ||
failed(cast<ConcreteType>(op).verifyRegions()));
}
static OperationName::VerifyRegionInvariantsFn getVerifyRegionInvariantsFn() {
return static_cast<LogicalResult (*)(Operation *)>(&verifyRegionInvariants);
}
static constexpr bool hasNoDataMembers() {
// Checking that the derived class does not define any member by comparing

View File

@ -73,6 +73,8 @@ public:
llvm::unique_function<void(Operation *, OpAsmPrinter &, StringRef) const>;
using VerifyInvariantsFn =
llvm::unique_function<LogicalResult(Operation *) const>;
using VerifyRegionInvariantsFn =
llvm::unique_function<LogicalResult(Operation *) const>;
protected:
/// This class represents a type erased version of an operation. It contains
@ -112,6 +114,7 @@ protected:
ParseAssemblyFn parseAssemblyFn;
PrintAssemblyFn printAssemblyFn;
VerifyInvariantsFn verifyInvariantsFn;
VerifyRegionInvariantsFn verifyRegionInvariantsFn;
/// A list of attribute names registered to this operation in StringAttr
/// form. This allows for operation classes to use StringAttr for attribute
@ -238,16 +241,18 @@ public:
static void insert(Dialect &dialect) {
insert(T::getOperationName(), dialect, TypeID::get<T>(),
T::getParseAssemblyFn(), T::getPrintAssemblyFn(),
T::getVerifyInvariantsFn(), T::getFoldHookFn(),
T::getGetCanonicalizationPatternsFn(), T::getInterfaceMap(),
T::getHasTraitFn(), T::getAttributeNames());
T::getVerifyInvariantsFn(), T::getVerifyRegionInvariantsFn(),
T::getFoldHookFn(), T::getGetCanonicalizationPatternsFn(),
T::getInterfaceMap(), T::getHasTraitFn(), T::getAttributeNames());
}
/// The use of this method is in general discouraged in favor of
/// 'insert<CustomOp>(dialect)'.
static void
insert(StringRef name, Dialect &dialect, TypeID typeID,
ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
VerifyInvariantsFn &&verifyInvariants,
VerifyRegionInvariantsFn &&verifyRegionInvariants,
FoldHookFn &&foldHook,
GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
ArrayRef<StringRef> attrNames);
@ -272,12 +277,15 @@ public:
return impl->printAssemblyFn(op, p, defaultDialect);
}
/// This hook implements the verifier for this operation. It should emits an
/// error message and returns failure if a problem is detected, or returns
/// These hooks implement the verifiers for this operation. It should emits
/// an error message and returns failure if a problem is detected, or returns
/// success if everything is ok.
LogicalResult verifyInvariants(Operation *op) const {
return impl->verifyInvariantsFn(op);
}
LogicalResult verifyRegionInvariants(Operation *op) const {
return impl->verifyRegionInvariantsFn(op);
}
/// This hook implements a generalized folder for this operation. Operations
/// can implement this to provide simplifications rules that are applied by

View File

@ -98,6 +98,10 @@ public:
// Return the verify method body if it has one.
llvm::Optional<StringRef> getVerify() const;
// If there's a verify method, return if it needs to access the ops in the
// regions.
bool verifyWithRegions() const;
// Returns the Tablegen definition this interface was constructed from.
const llvm::Record &getDef() const { return *def; }

View File

@ -65,6 +65,9 @@ public:
// Returns the trait corresponding to a C++ trait class.
std::string getFullyQualifiedTraitName() const;
// Returns if this is a structural op trait.
bool isStructuralOpTrait() const;
static bool classof(const Trait *t) { return t->getKind() == Kind::Native; }
};

View File

@ -1117,7 +1117,7 @@ ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
return success();
}
LogicalResult AffineDmaStartOp::verifyInvariants() {
LogicalResult AffineDmaStartOp::verifyInvariantsImpl() {
if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
return emitOpError("expected DMA source to be of memref type");
if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
@ -1219,7 +1219,7 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
return success();
}
LogicalResult AffineDmaWaitOp::verifyInvariants() {
LogicalResult AffineDmaWaitOp::verifyInvariantsImpl() {
if (!getOperand(0).getType().isa<MemRefType>())
return emitOpError("expected DMA tag to be of memref type");
Region *scope = getAffineScope(*this);

View File

@ -693,7 +693,8 @@ RegisteredOperationName::parseAssembly(OpAsmParser &parser,
void RegisteredOperationName::insert(
StringRef name, Dialect &dialect, TypeID typeID,
ParseAssemblyFn &&parseAssembly, PrintAssemblyFn &&printAssembly,
VerifyInvariantsFn &&verifyInvariants, FoldHookFn &&foldHook,
VerifyInvariantsFn &&verifyInvariants,
VerifyRegionInvariantsFn &&verifyRegionInvariants, FoldHookFn &&foldHook,
GetCanonicalizationPatternsFn &&getCanonicalizationPatterns,
detail::InterfaceMap &&interfaceMap, HasTraitFn &&hasTrait,
ArrayRef<StringRef> attrNames) {
@ -749,6 +750,7 @@ void RegisteredOperationName::insert(
impl.parseAssemblyFn = std::move(parseAssembly);
impl.printAssemblyFn = std::move(printAssembly);
impl.verifyInvariantsFn = std::move(verifyInvariants);
impl.verifyRegionInvariantsFn = std::move(verifyRegionInvariants);
impl.attributeNames = cachedAttrNames;
}

View File

@ -217,6 +217,11 @@ LogicalResult OperationVerifier::verifyOperation(
}
}
// After the region ops are verified, run the verifiers that have additional
// region invariants need to veirfy.
if (registeredInfo && failed(registeredInfo->verifyRegionInvariants(&op)))
return failure();
// If this is a registered operation, there is nothing left to do.
if (registeredInfo)
return success();

View File

@ -125,6 +125,10 @@ llvm::Optional<StringRef> Interface::getVerify() const {
return value.empty() ? llvm::Optional<StringRef>() : value;
}
bool Interface::verifyWithRegions() const {
return def->getValueAsBit("verifyWithRegions");
}
//===----------------------------------------------------------------------===//
// AttrInterface
//===----------------------------------------------------------------------===//

View File

@ -50,6 +50,10 @@ std::string NativeTrait::getFullyQualifiedTraitName() const {
: (cppNamespace + "::" + trait).str();
}
bool NativeTrait::isStructuralOpTrait() const {
return def->isSubClassOf("StructuralOpTrait");
}
//===----------------------------------------------------------------------===//
// InternalTrait
//===----------------------------------------------------------------------===//

View File

@ -168,7 +168,7 @@ func @func_with_ops(i32, i32) {
func @func_with_ops() {
^bb0:
%c = arith.constant dense<0> : vector<42 x i32>
// expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
// expected-error@+1 {{op failed to verify that result type has i1 element type and same shape as operands}}
%r = "arith.cmpi"(%c, %c) {predicate = 0} : (vector<42 x i32>, vector<42 x i32>) -> vector<41 x i1>
}
@ -249,7 +249,7 @@ func @cmpf_canonical_wrong_result_type(%a : f32, %b : f32) -> f32 {
// -----
func @cmpf_result_shape_mismatch(%a : vector<42xf32>) {
// expected-error@+1 {{all non-scalar operands/results must have the same shape and base type}}
// expected-error@+1 {{op failed to verify that result type has i1 element type and same shape as operands}}
%r = "arith.cmpf"(%a, %a) {predicate = 0} : (vector<42 x f32>, vector<42 x f32>) -> vector<41 x i1>
}
@ -285,7 +285,7 @@ func @index_cast_index_to_index(%arg0: index) {
// -----
func @index_cast_float(%arg0: index, %arg1: f32) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op result #0 must be signless-integer-like or memref of signless-integer, but got 'f32'}}
%0 = arith.index_cast %arg0 : index to f32
return
}
@ -293,7 +293,7 @@ func @index_cast_float(%arg0: index, %arg1: f32) {
// -----
func @index_cast_float_to_index(%arg0: f32) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op operand #0 must be signless-integer-like or memref of signless-integer, but got 'f32'}}
%0 = arith.index_cast %arg0 : f32 to index
return
}
@ -301,7 +301,7 @@ func @index_cast_float_to_index(%arg0: f32) {
// -----
func @sitofp_i32_to_i64(%arg0 : i32) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op result #0 must be floating-point-like, but got 'i64'}}
%0 = arith.sitofp %arg0 : i32 to i64
return
}
@ -309,7 +309,7 @@ func @sitofp_i32_to_i64(%arg0 : i32) {
// -----
func @sitofp_f32_to_i32(%arg0 : f32) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'f32'}}
%0 = arith.sitofp %arg0 : f32 to i32
return
}
@ -333,7 +333,7 @@ func @fpext_f16_to_f16(%arg0 : f16) {
// -----
func @fpext_i32_to_f32(%arg0 : i32) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op operand #0 must be floating-point-like, but got 'i32'}}
%0 = arith.extf %arg0 : i32 to f32
return
}
@ -341,7 +341,7 @@ func @fpext_i32_to_f32(%arg0 : i32) {
// -----
func @fpext_f32_to_i32(%arg0 : f32) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op result #0 must be floating-point-like, but got 'i32'}}
%0 = arith.extf %arg0 : f32 to i32
return
}
@ -373,7 +373,7 @@ func @fpext_vec_f16_to_f16(%arg0 : vector<2xf16>) {
// -----
func @fpext_vec_i32_to_f32(%arg0 : vector<2xi32>) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op operand #0 must be floating-point-like, but got 'vector<2xi32>'}}
%0 = arith.extf %arg0 : vector<2xi32> to vector<2xf32>
return
}
@ -381,7 +381,7 @@ func @fpext_vec_i32_to_f32(%arg0 : vector<2xi32>) {
// -----
func @fpext_vec_f32_to_i32(%arg0 : vector<2xf32>) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op result #0 must be floating-point-like, but got 'vector<2xi32>'}}
%0 = arith.extf %arg0 : vector<2xf32> to vector<2xi32>
return
}
@ -405,7 +405,7 @@ func @fptrunc_f32_to_f32(%arg0 : f32) {
// -----
func @fptrunc_i32_to_f32(%arg0 : i32) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op operand #0 must be floating-point-like, but got 'i32'}}
%0 = arith.truncf %arg0 : i32 to f32
return
}
@ -413,7 +413,7 @@ func @fptrunc_i32_to_f32(%arg0 : i32) {
// -----
func @fptrunc_f32_to_i32(%arg0 : f32) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op result #0 must be floating-point-like, but got 'i32'}}
%0 = arith.truncf %arg0 : f32 to i32
return
}
@ -445,7 +445,7 @@ func @fptrunc_vec_f32_to_f32(%arg0 : vector<2xf32>) {
// -----
func @fptrunc_vec_i32_to_f32(%arg0 : vector<2xi32>) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op operand #0 must be floating-point-like, but got 'vector<2xi32>'}}
%0 = arith.truncf %arg0 : vector<2xi32> to vector<2xf32>
return
}
@ -453,7 +453,7 @@ func @fptrunc_vec_i32_to_f32(%arg0 : vector<2xi32>) {
// -----
func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op result #0 must be floating-point-like, but got 'vector<2xi32>'}}
%0 = arith.truncf %arg0 : vector<2xf32> to vector<2xi32>
return
}
@ -461,7 +461,7 @@ func @fptrunc_vec_f32_to_i32(%arg0 : vector<2xf32>) {
// -----
func @sexti_index_as_operand(%arg0 : index) {
// expected-error@+1 {{are cast incompatible}}
// expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}}
%0 = arith.extsi %arg0 : index to i128
return
}
@ -469,7 +469,7 @@ func @sexti_index_as_operand(%arg0 : index) {
// -----
func @zexti_index_as_operand(%arg0 : index) {
// expected-error@+1 {{operand type 'index' and result type}}
// expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}}
%0 = arith.extui %arg0 : index to i128
return
}
@ -477,7 +477,7 @@ func @zexti_index_as_operand(%arg0 : index) {
// -----
func @trunci_index_as_operand(%arg0 : index) {
// expected-error@+1 {{operand type 'index' and result type}}
// expected-error@+1 {{op operand #0 must be signless-fixed-width-integer-like, but got 'index'}}
%2 = arith.trunci %arg0 : index to i128
return
}
@ -485,7 +485,7 @@ func @trunci_index_as_operand(%arg0 : index) {
// -----
func @sexti_index_as_result(%arg0 : i1) {
// expected-error@+1 {{result type 'index' are cast incompatible}}
// expected-error@+1 {{op result #0 must be signless-fixed-width-integer-like, but got 'index'}}
%0 = arith.extsi %arg0 : i1 to index
return
}
@ -493,7 +493,7 @@ func @sexti_index_as_result(%arg0 : i1) {
// -----
func @zexti_index_as_operand(%arg0 : i1) {
// expected-error@+1 {{result type 'index' are cast incompatible}}
// expected-error@+1 {{op result #0 must be signless-fixed-width-integer-like, but got 'index'}}
%0 = arith.extui %arg0 : i1 to index
return
}
@ -501,7 +501,7 @@ func @zexti_index_as_operand(%arg0 : i1) {
// -----
func @trunci_index_as_result(%arg0 : i128) {
// expected-error@+1 {{result type 'index' are cast incompatible}}
// expected-error@+1 {{op result #0 must be signless-fixed-width-integer-like, but got 'index'}}
%2 = arith.trunci %arg0 : i128 to index
return
}

View File

@ -301,7 +301,7 @@ func @reduce_incorrect_yield(%arg0 : f32) {
// -----
func @shuffle_mismatching_type(%arg0 : f32, %arg1 : i32, %arg2 : i32) {
// expected-error@+1 {{inferred type(s) 'f32', 'i1' are incompatible with return type(s) of operation 'i32', 'i1'}}
// expected-error@+1 {{op failed to verify that all of {value, result} have same type}}
%shfl, %pred = "gpu.shuffle"(%arg0, %arg1, %arg2) { mode = #gpu<"shuffle_mode xor"> } : (f32, i32, i32) -> (i32, i1)
return
}

View File

@ -80,8 +80,8 @@ llvm.mlir.global internal constant @sectionvar("teststring") {section = ".mysec
// -----
// expected-error @+1 {{requires string attribute 'sym_name'}}
"llvm.mlir.global"() ({}) {type = i64, constant, value = 42 : i64} : () -> ()
// expected-error @+1 {{op requires attribute 'sym_name'}}
"llvm.mlir.global"() ({}) {type = i64, constant, global_type = i64, value = 42 : i64} : () -> ()
// -----

View File

@ -214,15 +214,15 @@ func @generic_shaped_operand_block_arg_type(%arg0: memref<f32>) {
// -----
func @generic_scalar_operand_block_arg_type(%arg0: f32) {
func @generic_scalar_operand_block_arg_type(%arg0: tensor<f32>) {
// expected-error @+1 {{expected type of bb argument #0 ('i1') to match element or self type of the corresponding operand ('f32')}}
linalg.generic {
indexing_maps = [ affine_map<() -> ()> ],
iterator_types = []}
outs(%arg0 : f32) {
outs(%arg0 : tensor<f32>) {
^bb(%i: i1):
linalg.yield %i : i1
}
} -> tensor<f32>
}
// -----
@ -243,7 +243,7 @@ func @generic_result_0_element_type(%arg0: memref<?xf32, affine_map<(i)[off]->(o
func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off + i)>>,
%arg1: tensor<?xf32>) {
// expected-error @+1 {{expected type of operand #1 ('tensor<?xf32>') to match type of corresponding result ('f32')}}
// expected-error @+1 {{expected type of operand #1 ('tensor<?xf32>') to match type of corresponding result ('tensor<f32>')}}
%0 = linalg.generic {
indexing_maps = [ affine_map<(i) -> (i)> , affine_map<(i) -> (i)> ],
iterator_types = ["parallel"]}
@ -251,7 +251,7 @@ func @generic_result_tensor_type(%arg0: memref<?xf32, affine_map<(i)[off]->(off
outs(%arg1 : tensor<?xf32>) {
^bb(%i: f32, %j: f32):
linalg.yield %i: f32
} -> f32
} -> tensor<f32>
}
// -----
@ -362,11 +362,11 @@ func @illegal_fill_tensor_no_return(%arg0 : index, %arg1 : index, %arg2 : f32)
// -----
func @illegal_fill_memref_with_return(%arg0 : memref<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
func @illegal_fill_memref_with_return(%arg0 : memref<?x?xf32>, %arg1 : f32) -> tensor<?x?xf32>
{
// expected-error @+1 {{expected the number of results (1) to be equal to the number of output tensors (0)}}
%0 = linalg.fill(%arg1, %arg0) : f32, memref<?x?xf32> -> memref<?x?xf32>
return %0 : memref<?x?xf32>
// expected-error @+1 {{op expected the number of results (1) to be equal to the number of output tensors (0)}}
%0 = linalg.fill(%arg1, %arg0) : f32, memref<?x?xf32> -> tensor<?x?xf32>
return %0 : tensor<?x?xf32>
}
// -----
@ -384,7 +384,7 @@ func @illegal_fill_memref_with_tensor_return
func @illegal_fill_tensor_with_memref_return
(%arg0 : tensor<?x?xf32>, %arg1 : f32) -> memref<?x?xf32>
{
// expected-error @+1 {{expected type of operand #1 ('tensor<?x?xf32>') to match type of corresponding result ('memref<?x?xf32>')}}
// expected-error @+1 {{op result #0 must be ranked tensor of any type values, but got 'memref<?x?xf32>'}}
%0 = linalg.fill(%arg1, %arg0) : f32, tensor<?x?xf32> -> memref<?x?xf32>
return %0 : memref<?x?xf32>
}
@ -477,7 +477,7 @@ func @tiled_loop_incorrent_iterator_types_count(%A: memref<192x192xf32>,
%c0 = arith.constant 0 : index
%c192 = arith.constant 192 : index
// expected-error @+1 {{expected iterator types array attribute size = 1 to match the number of loops = 2}}
%0 = "linalg.tiled_loop"(%c0, %c0, %c192, %c192, %c24, %c24, %A, %B, %C_tensor, %C) ({
%0 = "linalg.tiled_loop"(%c0, %c0, %c192, %c192, %c24, %c24, %A, %B, %C_tensor, %C) ( {
^bb0(%arg4: index, %arg5: index, %A_: memref<192x192xf32>,
%B_: memref<192x192xf32>, %CT_: tensor<192x192xf32>,
%C_: memref<192x192xf32>):
@ -502,7 +502,7 @@ func @tiled_loop_incorrent_block_arg_type(%A: memref<192xf32>) {
%c192 = arith.constant 192 : index
%c24 = arith.constant 24 : index
// expected-error @+1 {{expected output arg 0 with type = 'memref<192xf32>' to match region arg 1 type = 'memref<100xf32>'}}
"linalg.tiled_loop"(%c0, %c192, %c24, %A) ({
"linalg.tiled_loop"(%c0, %c192, %c24, %A) ( {
^bb0(%arg4: index, %A_: memref<100xf32>):
call @foo(%A_) : (memref<100xf32>)-> ()
linalg.yield

View File

@ -111,7 +111,7 @@ func @depthwise_conv_2d_input_nhwc_filter_default_attributes(%input: memref<1x11
// -----
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{incorrect element type for index attribute 'strides'}}
// expected-error @+1 {{op attribute 'strides' failed to satisfy constraint: 64-bit signless int elements attribute of shape [2]}}
linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2.0> : vector<2xf32>}
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)
@ -121,7 +121,7 @@ func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_element_type(%input: memr
// -----
func @depthwise_conv_2d_input_nhwc_filter_wrong_stride_size(%input: memref<1x113x113x96xf32>, %filter: memref<3x3x96xf32>, %output: memref<1x56x56x96xf32>) {
// expected-error @+1 {{incorrect shape for index attribute 'strides'}}
// expected-error @+1 {{op attribute 'strides' failed to satisfy constraint: 64-bit signless int elements attribute of shape [2]}}
linalg.depthwise_conv_2d_nhwc_hwc {dilations = dense<1> : vector<2xi64>, strides = dense<2> : vector<3xi64> }
ins(%input, %filter: memref<1x113x113x96xf32>, memref<3x3x96xf32>)
outs(%output: memref<1x56x56x96xf32>)

View File

@ -59,7 +59,7 @@ func @bit_field_u_extract_vec(%base: vector<3xi32>, %offset: i8, %count: i8) ->
// -----
func @bit_field_u_extract_invalid_result_type(%base: vector<3xi32>, %offset: i32, %count: i16) -> vector<4xi32> {
// expected-error @+1 {{inferred type(s) 'vector<3xi32>' are incompatible with return type(s) of operation 'vector<4xi32>'}}
// expected-error @+1 {{failed to verify that all of {base, result} have same type}}
%0 = "spv.BitFieldUExtract" (%base, %offset, %count) : (vector<3xi32>, i32, i16) -> vector<4xi32>
spv.ReturnValue %0 : vector<4xi32>
}
@ -181,7 +181,7 @@ func @shift_left_logical(%arg0: i32, %arg1 : i16) -> i32 {
// -----
func @shift_left_logical_invalid_result_type(%arg0: i32, %arg1 : i16) -> i16 {
// expected-error @+1 {{op inferred type(s) 'i32' are incompatible with return type(s) of operation 'i16'}}
// expected-error @+1 {{op failed to verify that all of {operand1, result} have same type}}
%0 = "spv.ShiftLeftLogical" (%arg0, %arg1) : (i32, i16) -> (i16)
spv.ReturnValue %0 : i16
}

View File

@ -98,8 +98,8 @@ func @shape_of(%value_arg : !shape.value_shape,
// -----
func @shape_of_incompatible_return_types(%value_arg : tensor<1x2xindex>) {
// expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xf32>'}}
%0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xf32>
// expected-error@+1 {{'shape.shape_of' op inferred type(s) 'tensor<2xindex>' are incompatible with return type(s) of operation 'tensor<3xindex>'}}
%0 = shape.shape_of %value_arg : tensor<1x2xindex> -> tensor<3xindex>
return
}

View File

@ -58,7 +58,7 @@ func @broadcast_tensor_tensor_tensor(tensor<8x1x?x1xi32>, tensor<7x1x5xi32>) ->
// Check incompatible vector and tensor result type
func @broadcast_scalar_vector_vector(tensor<4xf32>, tensor<4xf32>) -> vector<4xf32> {
^bb0(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>):
// expected-error @+1 {{cannot broadcast vector with tensor}}
// expected-error @+1 {{op result #0 must be tensor of any type values, but got 'vector<4xf32>'}}
%0 = "test.broadcastable"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> vector<4xf32>
return %0 : vector<4xf32>
}

View File

@ -3,7 +3,7 @@
// -----
func @module_op() {
// expected-error@+1 {{Operations with a 'SymbolTable' must have exactly one block}}
// expected-error@+1 {{'builtin.module' op expects region #0 to have 0 or 1 blocks}}
builtin.module {
^bb1:
"test.dummy"() : () -> ()

View File

@ -332,12 +332,12 @@ func @failedSingleBlockImplicitTerminator_missing_terminator() {
// Test the invariants of operations with the Symbol Trait.
// expected-error@+1 {{requires string attribute 'sym_name'}}
// expected-error@+1 {{op requires attribute 'sym_name'}}
"test.symbol"() {} : () -> ()
// -----
// expected-error@+1 {{requires visibility attribute 'sym_visibility' to be a string attribute}}
// expected-error@+1 {{op attribute 'sym_visibility' failed to satisfy constraint: string attribute}}
"test.symbol"() {sym_name = "foo_2", sym_visibility} : () -> ()
// -----
@ -364,7 +364,7 @@ func private @foo()
// -----
// Test that operation with the SymbolTable Trait fails with too many blocks.
// expected-error@+1 {{Operations with a 'SymbolTable' must have exactly one block}}
// expected-error@+1 {{op expects region #0 to have 0 or 1 blocks}}
"test.symbol_scope"() ({
^entry:
"test.finish" () : () -> ()
@ -668,4 +668,4 @@ func @failed_attr_traits() {
// expected-error@+1 {{'attr' attribute should have trait 'TestAttrTrait'}}
"test.attr_with_trait"() {attr = 42 : i32} : () -> ()
return
}
}

View File

@ -68,7 +68,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
// CHECK: ::mlir::ValueRange odsOperands;
// CHECK: };
// CHECK: class AOp : public ::mlir::Op<AOp, ::mlir::OpTrait::AtLeastNRegions<1>::Impl, ::mlir::OpTrait::AtLeastNResults<1>::Impl, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::AtLeastNOperands<1>::Impl, ::mlir::OpTrait::IsIsolatedFromAbove
// CHECK: class AOp : public ::mlir::Op<AOp, ::mlir::OpTrait::AtLeastNRegions<1>::Impl, ::mlir::OpTrait::AtLeastNResults<1>::Impl, ::mlir::OpTrait::ZeroSuccessor, ::mlir::OpTrait::AtLeastNOperands<1>::Impl, ::mlir::OpTrait::OpInvariants, ::mlir::OpTrait::IsIsolatedFromAbove
// CHECK-NOT: ::mlir::OpTrait::IsIsolatedFromAbove
// CHECK: public:
// CHECK: using Op::Op;

View File

@ -42,6 +42,19 @@ def TestOpInterface : OpInterface<"TestOpInterface"> {
];
}
def TestOpInterfaceVerify : OpInterface<"TestOpInterfaceVerify"> {
let verify = [{
return foo();
}];
}
def TestOpInterfaceVerifyRegion : OpInterface<"TestOpInterfaceVerifyRegion"> {
let verify = [{
return foo();
}];
let verifyWithRegions = 1;
}
// Define Ops with TestOpInterface and
// DeclareOpInterfaceMethods<TestOpInterface> traits to check that there
// are not duplicated C++ classes generated.
@ -65,6 +78,12 @@ def DeclareMethodsWithDefaultOp : Op<TestDialect, "declare_methods_op",
// DECL: template<typename ConcreteOp>
// DECL: int detail::TestOpInterfaceInterfaceTraits::Model<ConcreteOp>::foo
// DECL-LABEL: struct TestOpInterfaceVerifyTrait
// DECL: verifyTrait
// DECL-LABEL: struct TestOpInterfaceVerifyRegionTrait
// DECL: verifyRegionTrait
// OP_DECL-LABEL: class DeclareMethodsOp : public
// OP_DECL: int foo(int input);
// OP_DECL-NOT: int default_foo(int input);

View File

@ -58,7 +58,7 @@ func @complex_f64_tensor_success() {
// -----
func @complex_f64_failure() {
// expected-error@+1 {{op inferred type(s) 'complex<f64>' are incompatible with return type(s) of operation 'f64'}}
// expected-error@+1 {{op result #0 must be complex type with 64-bit float elements, but got 'f64'}}
"test.complex_f64"() : () -> (f64)
return
}
@ -438,7 +438,7 @@ func @operand_rank_equals_result_size_failure(%arg : tensor<1x2x3x4xi32>) {
// -----
func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>) {
// expected-error@+1 {{op inferred type(s) 'tensor<*xi32>' are incompatible with return type(s) of operation 'tensor<*xf32>'}}
// expected-error@+1 {{op failed to verify that all of {x, res} have same type}}
"test.operand0_and_result_have_same_type"(%arg0, %arg1) : (tensor<* x i32>, tensor<* x f32>) -> tensor<* x f32>
return
}
@ -446,7 +446,7 @@ func @same_types_element_mismatch(%arg0: tensor<* x i32>, %arg1: tensor<* x f32>
// -----
func @same_types_shape_mismatch(%arg0: tensor<1x2xi32>, %arg1: tensor<2x1xi32>) {
// expected-error@+1 {{op inferred type(s) 'tensor<1x2xi32>' are incompatible with return type(s) of operation 'tensor<2x1xi32>'}}
// expected-error@+1 {{op failed to verify that all of {x, res} have same type}}
"test.operand0_and_result_have_same_type"(%arg0, %arg1) : (tensor<1x2xi32>, tensor<2x1xi32>) -> tensor<2x1xi32>
return
}

View File

@ -394,6 +394,9 @@ private:
// Generates verify method for the operation.
void genVerifier();
// Generates custom verify methods for the operation.
void genCustomVerifier();
// Generates verify statements for operands and results in the operation.
// The generated code will be attached to `body`.
void genOperandResultVerifier(MethodBody &body,
@ -593,6 +596,7 @@ OpEmitter::OpEmitter(const Operator &op,
genParser();
genPrinter();
genVerifier();
genCustomVerifier();
genCanonicalizerDecls();
genFolderDecls();
genTypeInterfaceMethods();
@ -2236,47 +2240,76 @@ static void genNativeTraitAttrVerifier(MethodBody &body,
}
void OpEmitter::genVerifier() {
auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants");
ERROR_IF_PRUNED(method, "verifyInvariants", op);
auto &body = method->body();
auto *implMethod =
opClass.addMethod("::mlir::LogicalResult", "verifyInvariantsImpl");
ERROR_IF_PRUNED(implMethod, "verifyInvariantsImpl", op);
auto &implBody = implMethod->body();
OpOrAdaptorHelper emitHelper(op, /*isOp=*/true);
genNativeTraitAttrVerifier(body, emitHelper);
genNativeTraitAttrVerifier(implBody, emitHelper);
auto *valueInit = def.getValueInit("verifier");
StringInit *stringInit = dyn_cast<StringInit>(valueInit);
bool hasCustomVerifyCodeBlock = stringInit && !stringInit->getValue().empty();
populateSubstitutions(emitHelper, verifyCtx);
genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
genOperandResultVerifier(body, op.getOperands(), "operand");
genOperandResultVerifier(body, op.getResults(), "result");
genAttributeVerifier(emitHelper, verifyCtx, implBody, staticVerifierEmitter);
genOperandResultVerifier(implBody, op.getOperands(), "operand");
genOperandResultVerifier(implBody, op.getResults(), "result");
for (auto &trait : op.getTraits()) {
if (auto *t = dyn_cast<tblgen::PredTrait>(&trait)) {
body << tgfmt(" if (!($0))\n "
"return emitOpError(\"failed to verify that $1\");\n",
&verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
t->getSummary());
implBody << tgfmt(" if (!($0))\n "
"return emitOpError(\"failed to verify that $1\");\n",
&verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
t->getSummary());
}
}
genRegionVerifier(body);
genSuccessorVerifier(body);
genRegionVerifier(implBody);
genSuccessorVerifier(implBody);
implBody << " return ::mlir::success();\n";
// TODO: Some places use the `verifyInvariants` to do operation verification.
// This may not act as their expectation because this doesn't call any
// verifiers of native/interface traits. Needs to review those use cases and
// see if we should use the mlir::verify() instead.
auto *valueInit = def.getValueInit("verifier");
StringInit *stringInit = dyn_cast<StringInit>(valueInit);
bool hasCustomVerifyCodeBlock = stringInit && !stringInit->getValue().empty();
auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants");
ERROR_IF_PRUNED(method, "verifyInvariants", op);
auto &body = method->body();
if (hasCustomVerifyCodeBlock || def.getValueAsBit("hasVerifier")) {
body << " if(::mlir::succeeded(verifyInvariantsImpl()) && "
"::mlir::succeeded(verify()))\n";
body << " return ::mlir::success();\n";
body << " return ::mlir::failure();";
} else {
body << " return verifyInvariantsImpl();";
}
}
void OpEmitter::genCustomVerifier() {
auto *valueInit = def.getValueInit("verifier");
StringInit *stringInit = dyn_cast<StringInit>(valueInit);
bool hasCustomVerifyCodeBlock = stringInit && !stringInit->getValue().empty();
if (def.getValueAsBit("hasVerifier")) {
auto *method = opClass.declareMethod<Method::Private>(
"::mlir::LogicalResult", "verify");
auto *method = opClass.declareMethod("::mlir::LogicalResult", "verify");
ERROR_IF_PRUNED(method, "verify", op);
body << " return verify();\n";
} else if (def.getValueAsBit("hasRegionVerifier")) {
auto *method =
opClass.declareMethod("::mlir::LogicalResult", "verifyRegions");
ERROR_IF_PRUNED(method, "verifyRegions", op);
} else if (hasCustomVerifyCodeBlock) {
auto *method = opClass.addMethod("::mlir::LogicalResult", "verify");
ERROR_IF_PRUNED(method, "verify", op);
auto &body = method->body();
FmtContext fctx;
fctx.addSubst("cppClass", opClass.getClassName());
auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
body << " " << tgfmt(printer, &fctx);
} else {
body << " return ::mlir::success();\n";
}
}
@ -2508,12 +2541,27 @@ void OpEmitter::genTraits() {
}
}
// The op traits defined internal are ensured that they can be verified
// earlier.
for (const auto &trait : op.getTraits()) {
if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
if (opTrait->isStructuralOpTrait())
opClass.addTrait(opTrait->getFullyQualifiedTraitName());
}
}
// OpInvariants wrapps the verifyInvariants which needs to be run before
// native/interface traits and after all the traits with `StructuralOpTrait`.
opClass.addTrait("::mlir::OpTrait::OpInvariants");
// Add the native and interface traits.
for (const auto &trait : op.getTraits()) {
if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait))
opClass.addTrait(opTrait->getFullyQualifiedTraitName());
else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait))
if (auto *opTrait = dyn_cast<tblgen::NativeTrait>(&trait)) {
if (!opTrait->isStructuralOpTrait())
opClass.addTrait(opTrait->getFullyQualifiedTraitName());
} else if (auto *opTrait = dyn_cast<tblgen::InterfaceTrait>(&trait)) {
opClass.addTrait(opTrait->getFullyQualifiedTraitName());
}
}
}

View File

@ -413,9 +413,12 @@ void InterfaceGenerator::emitTraitDecl(const Interface &interface,
tblgen::FmtContext verifyCtx;
verifyCtx.withOp("op");
os << " static ::mlir::LogicalResult verifyTrait(::mlir::Operation *op) "
"{\n "
<< tblgen::tgfmt(verify->trim(), &verifyCtx) << "\n }\n";
os << llvm::formatv(
" static ::mlir::LogicalResult {0}(::mlir::Operation *op) ",
(interface.verifyWithRegions() ? "verifyRegionTrait"
: "verifyTrait"))
<< "{\n " << tblgen::tgfmt(verify->trim(), &verifyCtx)
<< "\n }\n";
}
if (auto extraTraitDecls = interface.getExtraTraitClassDeclaration())
os << tblgen::tgfmt(*extraTraitDecls, &traitMethodFmt) << "\n";