forked from OSchip/llvm-project
[mlir] Refactor how additional verification is specified in ODS
Currently if an operation requires additional verification, it specifies an inline code block (`let verifier = "blah"`). This is quite problematic for various reasons, e.g. it requires defining C++ inside of Tablegen which is discouraged when possible, but mainly because nearly all usages simply forward to a static function `static LogicalResult verify(SomeOp op)`. This commit adds support for a `hasVerifier` bit field that specifies if an additional verifier is needed, and when set to `1` declares a `LogicalResult verify()` method for operations to override. For migration purposes, the existing behavior is untouched. Upstream usages will be replaced in a followup to keep this patch focused on the hasVerifier implementation. One main user facing change is that what was one `MyOp::verify` is now `MyOp::verifyInvariants`. This better matches the name this method is called everywhere else, and also frees up `verify` for the user defined additional verification. The `verify` function when generated now (for additional verification) is private to the operation class, which should also help avoid accidental usages after this switch. Differential Revision: https://reviews.llvm.org/D118742
This commit is contained in:
parent
7e9d19016e
commit
42e5f1d97b
|
@ -86,7 +86,7 @@ compileFIR(const mlir::PassPipelineCLParser &passPipeline) {
|
|||
errs() << "Error can't load file " << inputFilename << '\n';
|
||||
return mlir::failure();
|
||||
}
|
||||
if (mlir::failed(owningRef->verify())) {
|
||||
if (mlir::failed(owningRef->verifyInvariants())) {
|
||||
errs() << "Error verifying FIR module\n";
|
||||
return mlir::failure();
|
||||
}
|
||||
|
|
|
@ -564,14 +564,13 @@ Verification code will be automatically generated for
|
|||
_additional_ verification, you can use
|
||||
|
||||
```tablegen
|
||||
let verifier = [{
|
||||
...
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
```
|
||||
|
||||
Code placed in `verifier` will be called after the auto-generated verification
|
||||
code. The order of trait verification excluding those of `verifier` should not
|
||||
be relied upon.
|
||||
This will generate a `LogicalResult verify()` method declaration on the op class
|
||||
that can be defined with any additional verification constraints. This method
|
||||
will be invoked after the auto-generated verification code. The order of trait
|
||||
verification excluding those of `hasVerifier` should not be relied upon.
|
||||
|
||||
### Declarative Assembly Format
|
||||
|
||||
|
|
|
@ -225,7 +225,7 @@ public:
|
|||
static StringRef getOperationName() { return "affine.dma_start"; }
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
LogicalResult verifyInvariants();
|
||||
LogicalResult fold(ArrayRef<Attribute> cstOperands,
|
||||
SmallVectorImpl<OpFoldResult> &results);
|
||||
|
||||
|
@ -313,7 +313,7 @@ public:
|
|||
static StringRef getTagMapAttrName() { return "tag_map"; }
|
||||
static ParseResult parse(OpAsmParser &parser, OperationState &result);
|
||||
void print(OpAsmPrinter &p);
|
||||
LogicalResult verify();
|
||||
LogicalResult verifyInvariants();
|
||||
LogicalResult fold(ArrayRef<Attribute> cstOperands,
|
||||
SmallVectorImpl<OpFoldResult> &results);
|
||||
};
|
||||
|
|
|
@ -2451,7 +2451,16 @@ class Op<Dialect dialect, string mnemonic, list<Trait> props = []> {
|
|||
// Custom assembly format.
|
||||
string assemblyFormat = ?;
|
||||
|
||||
// Custom verifier.
|
||||
// A bit indicating if the operation has additional invariants that need to
|
||||
// verified (aside from those verified by other ODS constructs). If set to `1`,
|
||||
// an additional `LogicalResult verify()` declaration will be generated on the
|
||||
// operation class. The operation should implement this method and verify the
|
||||
// additional necessary invariants.
|
||||
bit hasVerifier = 0;
|
||||
// A custom code block corresponding to the extra verification code of the
|
||||
// operation.
|
||||
// NOTE: This field is deprecated in favor of `hasVerifier` and is slated for
|
||||
// deletion.
|
||||
code verifier = ?;
|
||||
|
||||
// Whether this op has associated canonicalization patterns.
|
||||
|
|
|
@ -201,7 +201,7 @@ public:
|
|||
protected:
|
||||
/// If the concrete type didn't implement a custom verifier hook, just fall
|
||||
/// back to this one which accepts everything.
|
||||
LogicalResult verify() { return success(); }
|
||||
LogicalResult verifyInvariants() { return success(); }
|
||||
|
||||
/// Parse the custom form of an operation. Unless overridden, this method will
|
||||
/// first try to get an operation parser from the op's dialect. Otherwise the
|
||||
|
@ -1604,6 +1604,7 @@ class Op : public OpState, public Traits<ConcreteType>... {
|
|||
public:
|
||||
/// Inherit getOperation from `OpState`.
|
||||
using OpState::getOperation;
|
||||
using OpState::verifyInvariants;
|
||||
|
||||
/// Return if this operation contains the provided trait.
|
||||
template <template <typename T> class Trait>
|
||||
|
@ -1834,8 +1835,15 @@ private:
|
|||
return cast<ConcreteType>(op).print(p);
|
||||
}
|
||||
/// Implementation of `VerifyInvariantsFn` OperationName hook.
|
||||
static LogicalResult verifyInvariants(Operation *op) {
|
||||
static_assert(hasNoDataMembers(),
|
||||
"Op class shouldn't define new data members");
|
||||
return failure(
|
||||
failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
|
||||
failed(cast<ConcreteType>(op).verifyInvariants()));
|
||||
}
|
||||
static OperationName::VerifyInvariantsFn getVerifyInvariantsFn() {
|
||||
return &verifyInvariants;
|
||||
return static_cast<LogicalResult (*)(Operation *)>(&verifyInvariants);
|
||||
}
|
||||
|
||||
static constexpr bool hasNoDataMembers() {
|
||||
|
@ -1845,14 +1853,6 @@ private:
|
|||
return sizeof(ConcreteType) == sizeof(EmptyOp);
|
||||
}
|
||||
|
||||
static LogicalResult verifyInvariants(Operation *op) {
|
||||
static_assert(hasNoDataMembers(),
|
||||
"Op class shouldn't define new data members");
|
||||
return failure(
|
||||
failed(op_definition_impl::verifyTraits<VerifiableTraitsTupleT>(op)) ||
|
||||
failed(cast<ConcreteType>(op).verify()));
|
||||
}
|
||||
|
||||
/// Allow access to internal implementation methods.
|
||||
friend RegisteredOperationName;
|
||||
};
|
||||
|
|
|
@ -67,7 +67,7 @@ inline OwningOpRef<ContainerOpT> constructContainerOpForParserIfNecessary(
|
|||
|
||||
// After splicing, verify just this operation to ensure it can properly
|
||||
// contain the operations inside of it.
|
||||
if (failed(op.verify()))
|
||||
if (failed(op.verifyInvariants()))
|
||||
return OwningOpRef<ContainerOpT>();
|
||||
return opRef;
|
||||
}
|
||||
|
|
|
@ -1119,7 +1119,7 @@ ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult AffineDmaStartOp::verify() {
|
||||
LogicalResult AffineDmaStartOp::verifyInvariants() {
|
||||
if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
|
||||
return emitOpError("expected DMA source to be of memref type");
|
||||
if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
|
||||
|
@ -1221,7 +1221,7 @@ ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
|
|||
return success();
|
||||
}
|
||||
|
||||
LogicalResult AffineDmaWaitOp::verify() {
|
||||
LogicalResult AffineDmaWaitOp::verifyInvariants() {
|
||||
if (!getOperand(0).getType().isa<MemRefType>())
|
||||
return emitOpError("expected DMA tag to be of memref type");
|
||||
Region *scope = getAffineScope(*this);
|
||||
|
|
|
@ -86,7 +86,7 @@ Serializer::Serializer(spirv::ModuleOp module,
|
|||
LogicalResult Serializer::serialize() {
|
||||
LLVM_DEBUG(llvm::dbgs() << "+++ starting serialization +++\n");
|
||||
|
||||
if (failed(module.verify()))
|
||||
if (failed(module.verifyInvariants()))
|
||||
return failure();
|
||||
|
||||
// TODO: handle the other sections
|
||||
|
|
|
@ -1118,6 +1118,26 @@ void StringAttrPrettyNameOp::getAsmResultNames(
|
|||
setNameFn(getResult(i), str.getValue());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ResultTypeWithTraitOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ResultTypeWithTraitOp::verify() {
|
||||
if ((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
|
||||
return success();
|
||||
return emitError("result type should have trait 'TestTypeTrait'");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AttrWithTraitOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult AttrWithTraitOp::verify() {
|
||||
if (getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
|
||||
return success();
|
||||
return emitError("'attr' attribute should have trait 'TestAttrTrait'");
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// RegionIfOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -666,27 +666,16 @@ def DefaultDialectOp : TEST_Op<"default_dialect", [OpAsmOpInterface]> {
|
|||
// This operation requires its return type to have the trait 'TestTypeTrait'.
|
||||
def ResultTypeWithTraitOp : TEST_Op<"result_type_with_trait", []> {
|
||||
let results = (outs AnyType);
|
||||
|
||||
let verifier = [{
|
||||
if((*this)->getResultTypes()[0].hasTrait<TypeTrait::TestTypeTrait>())
|
||||
return success();
|
||||
return this->emitError("result type should have trait 'TestTypeTrait'");
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
// This operation requires its "attr" attribute to have the
|
||||
// trait 'TestAttrTrait'.
|
||||
def AttrWithTraitOp : TEST_Op<"attr_with_trait", []> {
|
||||
let arguments = (ins AnyAttr:$attr);
|
||||
|
||||
let verifier = [{
|
||||
if (this->getAttr().hasTrait<AttributeTrait::TestAttrTrait>())
|
||||
return success();
|
||||
return this->emitError("'attr' attribute should have trait 'TestAttrTrait'");
|
||||
}];
|
||||
let hasVerifier = 1;
|
||||
}
|
||||
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test Locations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -98,7 +98,7 @@ def NS_AOp : NS_Op<"a_op", [IsolatedFromAbove, IsolatedFromAbove]> {
|
|||
// CHECK: static void build(::mlir::OpBuilder &, ::mlir::OperationState &odsState, ::mlir::TypeRange resultTypes, ::mlir::ValueRange operands, ::llvm::ArrayRef<::mlir::NamedAttribute> attributes, unsigned numRegions)
|
||||
// CHECK: static ::mlir::ParseResult parse(::mlir::OpAsmParser &parser, ::mlir::OperationState &result);
|
||||
// CHECK: void print(::mlir::OpAsmPrinter &p);
|
||||
// CHECK: ::mlir::LogicalResult verify();
|
||||
// CHECK: ::mlir::LogicalResult verifyInvariants();
|
||||
// CHECK: static void getCanonicalizationPatterns(::mlir::RewritePatternSet &results, ::mlir::MLIRContext *context);
|
||||
// CHECK: ::mlir::LogicalResult fold(::llvm::ArrayRef<::mlir::Attribute> operands, ::llvm::SmallVectorImpl<::mlir::OpFoldResult> &results);
|
||||
// CHECK: // Display a graph for debugging purposes.
|
||||
|
|
|
@ -2208,8 +2208,8 @@ static void genNativeTraitAttrVerifier(MethodBody &body,
|
|||
}
|
||||
|
||||
void OpEmitter::genVerifier() {
|
||||
auto *method = opClass.addMethod("::mlir::LogicalResult", "verify");
|
||||
ERROR_IF_PRUNED(method, "verify", op);
|
||||
auto *method = opClass.addMethod("::mlir::LogicalResult", "verifyInvariants");
|
||||
ERROR_IF_PRUNED(method, "verifyInvariants", op);
|
||||
auto &body = method->body();
|
||||
|
||||
OpOrAdaptorHelper emitHelper(op, /*isOp=*/true);
|
||||
|
@ -2217,7 +2217,7 @@ void OpEmitter::genVerifier() {
|
|||
|
||||
auto *valueInit = def.getValueInit("verifier");
|
||||
StringInit *stringInit = dyn_cast<StringInit>(valueInit);
|
||||
bool hasCustomVerify = stringInit && !stringInit->getValue().empty();
|
||||
bool hasCustomVerifyCodeBlock = stringInit && !stringInit->getValue().empty();
|
||||
populateSubstitutions(emitHelper, verifyCtx);
|
||||
|
||||
genAttributeVerifier(emitHelper, verifyCtx, body, staticVerifierEmitter);
|
||||
|
@ -2236,7 +2236,13 @@ void OpEmitter::genVerifier() {
|
|||
genRegionVerifier(body);
|
||||
genSuccessorVerifier(body);
|
||||
|
||||
if (hasCustomVerify) {
|
||||
if (def.getValueAsBit("hasVerifier")) {
|
||||
auto *method = opClass.declareMethod<Method::Private>(
|
||||
"::mlir::LogicalResult", "verify");
|
||||
ERROR_IF_PRUNED(method, "verify", op);
|
||||
body << " return verify();\n";
|
||||
|
||||
} else if (hasCustomVerifyCodeBlock) {
|
||||
FmtContext fctx;
|
||||
fctx.addSubst("cppClass", opClass.getClassName());
|
||||
auto printer = stringInit->getValue().ltrim().rtrim(" \t\v\f\r");
|
||||
|
|
Loading…
Reference in New Issue