forked from OSchip/llvm-project
[mlir-tblgen] Minor Refactor for StaticVerifierFunctionEmitter.
Move StaticVerifierFunctionEmitter to CodeGenHelper.h so that it can be used for both ODS and DRR. Reviewed By: jpienaar Differential Revision: https://reviews.llvm.org/D106636
This commit is contained in:
parent
15497e62f6
commit
62df4df41c
|
@ -13,13 +13,21 @@
|
|||
#ifndef MLIR_TABLEGEN_CODEGENHELPERS_H
|
||||
#define MLIR_TABLEGEN_CODEGENHELPERS_H
|
||||
|
||||
#include "mlir/Support/IndentedOstream.h"
|
||||
#include "mlir/TableGen/Dialect.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
namespace llvm {
|
||||
class RecordKeeper;
|
||||
} // namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
|
||||
class Constraint;
|
||||
|
||||
// Simple RAII helper for defining ifdef-undef-endif scopes.
|
||||
class IfDefScope {
|
||||
public:
|
||||
|
@ -62,6 +70,82 @@ private:
|
|||
SmallVector<StringRef, 2> namespaces;
|
||||
};
|
||||
|
||||
/// This class deduplicates shared operation verification code by emitting
|
||||
/// static functions alongside the op definitions. These methods are local to
|
||||
/// the definition file, and are invoked within the operation verify methods.
|
||||
/// An example is shown below:
|
||||
///
|
||||
/// static LogicalResult localVerify(...)
|
||||
///
|
||||
/// LogicalResult OpA::verify(...) {
|
||||
/// if (failed(localVerify(...)))
|
||||
/// return failure();
|
||||
/// ...
|
||||
/// }
|
||||
///
|
||||
/// LogicalResult OpB::verify(...) {
|
||||
/// if (failed(localVerify(...)))
|
||||
/// return failure();
|
||||
/// ...
|
||||
/// }
|
||||
///
|
||||
class StaticVerifierFunctionEmitter {
|
||||
public:
|
||||
StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
|
||||
raw_ostream &os);
|
||||
|
||||
/// Emit the static verifier functions for `llvm::Record`s. The
|
||||
/// `signatureFormat` describes the required arguments and it must have a
|
||||
/// placeholder for function name.
|
||||
/// Example,
|
||||
/// const char *typeVerifierSignature =
|
||||
/// "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type"
|
||||
/// " type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
|
||||
///
|
||||
/// `errorHandlerFormat` describes the error message to return. It may have a
|
||||
/// placeholder for the summary of Constraint and bring more information for
|
||||
/// the error message.
|
||||
/// Example,
|
||||
/// const char *typeVerifierErrorHandler =
|
||||
/// " op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << "
|
||||
/// "\" must be {0}, but got \" << type";
|
||||
///
|
||||
/// `typeArgName` is used to identify the argument that needs to check its
|
||||
/// type. The constraint template will replace `$_self` with it.
|
||||
void emitFunctionsFor(StringRef signatureFormat, StringRef errorHandlerFormat,
|
||||
StringRef typeArgName, ArrayRef<llvm::Record *> opDefs,
|
||||
bool emitDecl);
|
||||
|
||||
/// Get the name of the local function used for the given type constraint.
|
||||
/// These functions are used for operand and result constraints and have the
|
||||
/// form:
|
||||
/// LogicalResult(Operation *op, Type type, StringRef valueKind,
|
||||
/// unsigned valueGroupStartIndex);
|
||||
StringRef getTypeConstraintFn(const Constraint &constraint) const;
|
||||
|
||||
private:
|
||||
/// Returns a unique name to use when generating local methods.
|
||||
static std::string getUniqueName(const llvm::RecordKeeper &records);
|
||||
|
||||
/// Emit local methods for the type constraints used within the provided op
|
||||
/// definitions.
|
||||
void emitTypeConstraintMethods(StringRef signatureFormat,
|
||||
StringRef errorHandlerFormat,
|
||||
StringRef typeArgName,
|
||||
ArrayRef<llvm::Record *> opDefs,
|
||||
bool emitDecl);
|
||||
|
||||
raw_indented_ostream os;
|
||||
|
||||
/// A unique label for the file currently being generated. This is used to
|
||||
/// ensure that the local functions have a unique name.
|
||||
std::string uniqueOutputLabel;
|
||||
|
||||
/// A set of functions implementing type constraints, used for operand and
|
||||
/// result verification.
|
||||
llvm::DenseMap<const void *, std::string> localTypeConstraints;
|
||||
};
|
||||
|
||||
} // namespace tblgen
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -6,6 +6,7 @@ set(LLVM_LINK_COMPONENTS
|
|||
|
||||
add_tablegen(mlir-tblgen MLIR
|
||||
AttrOrTypeDefGen.cpp
|
||||
CodeGenHelpers.cpp
|
||||
DialectGen.cpp
|
||||
DirectiveCommonGen.cpp
|
||||
EnumsGen.cpp
|
||||
|
|
|
@ -0,0 +1,139 @@
|
|||
//===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// OpDefinitionsGen uses the description of operations to generate C++
|
||||
// definitions for ops.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/CodeGenHelpers.h"
|
||||
#include "mlir/TableGen/Format.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/Support/Path.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
using namespace llvm;
|
||||
using namespace mlir;
|
||||
using namespace mlir::tblgen;
|
||||
|
||||
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
|
||||
const llvm::RecordKeeper &records, raw_ostream &os)
|
||||
: os(os), uniqueOutputLabel(getUniqueName(records)) {}
|
||||
|
||||
void StaticVerifierFunctionEmitter::emitFunctionsFor(
|
||||
StringRef signatureFormat, StringRef errorHandlerFormat,
|
||||
StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
|
||||
llvm::Optional<NamespaceEmitter> namespaceEmitter;
|
||||
if (!emitDecl)
|
||||
namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
|
||||
|
||||
emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName,
|
||||
opDefs, emitDecl);
|
||||
}
|
||||
|
||||
StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
|
||||
const Constraint &constraint) const {
|
||||
auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
|
||||
assert(it != localTypeConstraints.end() && "expected valid constraint fn");
|
||||
return it->second;
|
||||
}
|
||||
|
||||
std::string StaticVerifierFunctionEmitter::getUniqueName(
|
||||
const llvm::RecordKeeper &records) {
|
||||
// Use the input file name when generating a unique name.
|
||||
std::string inputFilename = records.getInputFilename();
|
||||
|
||||
// Drop all but the base filename.
|
||||
StringRef nameRef = llvm::sys::path::filename(inputFilename);
|
||||
nameRef.consume_back(".td");
|
||||
|
||||
// Sanitize any invalid characters.
|
||||
std::string uniqueName;
|
||||
for (char c : nameRef) {
|
||||
if (llvm::isAlnum(c) || c == '_')
|
||||
uniqueName.push_back(c);
|
||||
else
|
||||
uniqueName.append(llvm::utohexstr((unsigned char)c));
|
||||
}
|
||||
return uniqueName;
|
||||
}
|
||||
|
||||
void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
|
||||
StringRef signatureFormat, StringRef errorHandlerFormat,
|
||||
StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
|
||||
// Collect a set of all of the used type constraints within the operation
|
||||
// definitions.
|
||||
llvm::SetVector<const void *> typeConstraints;
|
||||
for (Record *def : opDefs) {
|
||||
Operator op(*def);
|
||||
for (NamedTypeConstraint &operand : op.getOperands())
|
||||
if (operand.hasPredicate())
|
||||
typeConstraints.insert(operand.constraint.getAsOpaquePointer());
|
||||
for (NamedTypeConstraint &result : op.getResults())
|
||||
if (result.hasPredicate())
|
||||
typeConstraints.insert(result.constraint.getAsOpaquePointer());
|
||||
}
|
||||
|
||||
// Record the mapping from predicate to constraint. If two constraints has the
|
||||
// same predicate and constraint summary, they can share the same verification
|
||||
// function.
|
||||
llvm::DenseMap<Pred, const void *> predToConstraint;
|
||||
FmtContext fctx;
|
||||
for (auto it : llvm::enumerate(typeConstraints)) {
|
||||
std::string name;
|
||||
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
|
||||
Pred pred = constraint.getPredicate();
|
||||
auto iter = predToConstraint.find(pred);
|
||||
if (iter != predToConstraint.end()) {
|
||||
do {
|
||||
Constraint built = Constraint::getFromOpaquePointer(iter->second);
|
||||
// We may have the different constraints but have the same predicate,
|
||||
// for example, ConstraintA and Variadic<ConstraintA>, note that
|
||||
// Variadic<> doesn't introduce new predicate. In this case, we can
|
||||
// share the same predicate function if they also have consistent
|
||||
// summary, otherwise we may report the wrong message while verification
|
||||
// fails.
|
||||
if (constraint.getSummary() == built.getSummary()) {
|
||||
name = getTypeConstraintFn(built).str();
|
||||
break;
|
||||
}
|
||||
++iter;
|
||||
} while (iter != predToConstraint.end() && iter->first == pred);
|
||||
}
|
||||
|
||||
if (!name.empty()) {
|
||||
localTypeConstraints.try_emplace(it.value(), name);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Generate an obscure and unique name for this type constraint.
|
||||
name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
|
||||
Twine(it.index()))
|
||||
.str();
|
||||
predToConstraint.insert(
|
||||
std::make_pair(constraint.getPredicate(), it.value()));
|
||||
localTypeConstraints.try_emplace(it.value(), name);
|
||||
|
||||
// Only generate the methods if we are generating definitions.
|
||||
if (emitDecl)
|
||||
continue;
|
||||
|
||||
os << formatv(signatureFormat.data(), name) << " {\n";
|
||||
os.indent() << "if (!("
|
||||
<< tgfmt(constraint.getConditionTemplate(),
|
||||
&fctx.withSelf(typeArgName))
|
||||
<< ")) {\n";
|
||||
os.indent() << "return "
|
||||
<< formatv(errorHandlerFormat.data(), constraint.getSummary())
|
||||
<< ";\n";
|
||||
os.unindent() << "}\nreturn ::mlir::success();\n";
|
||||
os.unindent() << "}\n\n";
|
||||
}
|
||||
}
|
|
@ -24,7 +24,6 @@
|
|||
#include "llvm/ADT/MapVector.h"
|
||||
#include "llvm/ADT/Sequence.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/Path.h"
|
||||
#include "llvm/Support/Signals.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
@ -101,6 +100,14 @@ const char *valueRangeReturnCode = R"(
|
|||
std::next({0}, valueRange.first + valueRange.second)};
|
||||
)";
|
||||
|
||||
const char *typeVerifierSignature =
|
||||
"static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type "
|
||||
"type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
|
||||
|
||||
const char *typeVerifierErrorHandler =
|
||||
" op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << \" must "
|
||||
"be {0}, but got \" << type";
|
||||
|
||||
static const char *const opCommentHeader = R"(
|
||||
//===----------------------------------------------------------------------===//
|
||||
// {0} {1}
|
||||
|
@ -108,175 +115,6 @@ static const char *const opCommentHeader = R"(
|
|||
|
||||
)";
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// StaticVerifierFunctionEmitter
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace {
|
||||
/// This class deduplicates shared operation verification code by emitting
|
||||
/// static functions alongside the op definitions. These methods are local to
|
||||
/// the definition file, and are invoked within the operation verify methods.
|
||||
/// An example is shown below:
|
||||
///
|
||||
/// static LogicalResult localVerify(...)
|
||||
///
|
||||
/// LogicalResult OpA::verify(...) {
|
||||
/// if (failed(localVerify(...)))
|
||||
/// return failure();
|
||||
/// ...
|
||||
/// }
|
||||
///
|
||||
/// LogicalResult OpB::verify(...) {
|
||||
/// if (failed(localVerify(...)))
|
||||
/// return failure();
|
||||
/// ...
|
||||
/// }
|
||||
///
|
||||
class StaticVerifierFunctionEmitter {
|
||||
public:
|
||||
StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
|
||||
ArrayRef<llvm::Record *> opDefs,
|
||||
raw_ostream &os, bool emitDecl);
|
||||
|
||||
/// Get the name of the local function used for the given type constraint.
|
||||
/// These functions are used for operand and result constraints and have the
|
||||
/// form:
|
||||
/// LogicalResult(Operation *op, Type type, StringRef valueKind,
|
||||
/// unsigned valueGroupStartIndex);
|
||||
StringRef getTypeConstraintFn(const Constraint &constraint) const {
|
||||
auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
|
||||
assert(it != localTypeConstraints.end() && "expected valid constraint fn");
|
||||
return it->second;
|
||||
}
|
||||
|
||||
private:
|
||||
/// Returns a unique name to use when generating local methods.
|
||||
static std::string getUniqueName(const llvm::RecordKeeper &records);
|
||||
|
||||
/// Emit local methods for the type constraints used within the provided op
|
||||
/// definitions.
|
||||
void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
|
||||
raw_ostream &os, bool emitDecl);
|
||||
|
||||
/// A unique label for the file currently being generated. This is used to
|
||||
/// ensure that the local functions have a unique name.
|
||||
std::string uniqueOutputLabel;
|
||||
|
||||
/// A set of functions implementing type constraints, used for operand and
|
||||
/// result verification.
|
||||
llvm::DenseMap<const void *, std::string> localTypeConstraints;
|
||||
};
|
||||
} // namespace
|
||||
|
||||
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
|
||||
const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
|
||||
raw_ostream &os, bool emitDecl)
|
||||
: uniqueOutputLabel(getUniqueName(records)) {
|
||||
llvm::Optional<NamespaceEmitter> namespaceEmitter;
|
||||
if (!emitDecl) {
|
||||
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
|
||||
namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
|
||||
}
|
||||
|
||||
emitTypeConstraintMethods(opDefs, os, emitDecl);
|
||||
}
|
||||
|
||||
std::string StaticVerifierFunctionEmitter::getUniqueName(
|
||||
const llvm::RecordKeeper &records) {
|
||||
// Use the input file name when generating a unique name.
|
||||
std::string inputFilename = records.getInputFilename();
|
||||
|
||||
// Drop all but the base filename.
|
||||
StringRef nameRef = llvm::sys::path::filename(inputFilename);
|
||||
nameRef.consume_back(".td");
|
||||
|
||||
// Sanitize any invalid characters.
|
||||
std::string uniqueName;
|
||||
for (char c : nameRef) {
|
||||
if (llvm::isAlnum(c) || c == '_')
|
||||
uniqueName.push_back(c);
|
||||
else
|
||||
uniqueName.append(llvm::utohexstr((unsigned char)c));
|
||||
}
|
||||
return uniqueName;
|
||||
}
|
||||
|
||||
void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
|
||||
ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
|
||||
// Collect a set of all of the used type constraints within the operation
|
||||
// definitions.
|
||||
llvm::SetVector<const void *> typeConstraints;
|
||||
for (Record *def : opDefs) {
|
||||
Operator op(*def);
|
||||
for (NamedTypeConstraint &operand : op.getOperands())
|
||||
if (operand.hasPredicate())
|
||||
typeConstraints.insert(operand.constraint.getAsOpaquePointer());
|
||||
for (NamedTypeConstraint &result : op.getResults())
|
||||
if (result.hasPredicate())
|
||||
typeConstraints.insert(result.constraint.getAsOpaquePointer());
|
||||
}
|
||||
|
||||
// Record the mapping from predicate to constraint. If two constraints has the
|
||||
// same predicate and constraint summary, they can share the same verification
|
||||
// function.
|
||||
llvm::DenseMap<Pred, const void *> predToConstraint;
|
||||
FmtContext fctx;
|
||||
for (auto it : llvm::enumerate(typeConstraints)) {
|
||||
std::string name;
|
||||
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
|
||||
Pred pred = constraint.getPredicate();
|
||||
auto iter = predToConstraint.find(pred);
|
||||
if (iter != predToConstraint.end()) {
|
||||
do {
|
||||
Constraint built = Constraint::getFromOpaquePointer(iter->second);
|
||||
// We may have the different constraints but have the same predicate,
|
||||
// for example, ConstraintA and Variadic<ConstraintA>, note that
|
||||
// Variadic<> doesn't introduce new predicate. In this case, we can
|
||||
// share the same predicate function if they also have consistent
|
||||
// summary, otherwise we may report the wrong message while verification
|
||||
// fails.
|
||||
if (constraint.getSummary() == built.getSummary()) {
|
||||
name = getTypeConstraintFn(built).str();
|
||||
break;
|
||||
}
|
||||
++iter;
|
||||
} while (iter != predToConstraint.end() && iter->first == pred);
|
||||
}
|
||||
|
||||
if (!name.empty()) {
|
||||
localTypeConstraints.try_emplace(it.value(), name);
|
||||
continue;
|
||||
}
|
||||
|
||||
// Generate an obscure and unique name for this type constraint.
|
||||
name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
|
||||
Twine(it.index()))
|
||||
.str();
|
||||
predToConstraint.insert(
|
||||
std::make_pair(constraint.getPredicate(), it.value()));
|
||||
localTypeConstraints.try_emplace(it.value(), name);
|
||||
|
||||
// Only generate the methods if we are generating definitions.
|
||||
if (emitDecl)
|
||||
continue;
|
||||
|
||||
os << "static ::mlir::LogicalResult " << name
|
||||
<< "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
|
||||
"valueKind, unsigned valueGroupStartIndex) {\n";
|
||||
|
||||
os << " if (!("
|
||||
<< tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
|
||||
<< ")) {\n"
|
||||
<< formatv(
|
||||
" return op->emitOpError(valueKind) << \" #\" << "
|
||||
"valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
|
||||
constraint.getSummary())
|
||||
<< " }\n"
|
||||
<< " return ::mlir::success();\n"
|
||||
<< "}\n\n";
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Utility structs and functions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -2560,8 +2398,12 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
|
|||
return;
|
||||
|
||||
// Generate all of the locally instantiated methods first.
|
||||
StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
|
||||
emitDecl);
|
||||
StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, os);
|
||||
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
|
||||
staticVerifierEmitter.emitFunctionsFor(
|
||||
typeVerifierSignature, typeVerifierErrorHandler, /*typeArgName=*/"type",
|
||||
defs, emitDecl);
|
||||
|
||||
for (auto *def : defs) {
|
||||
Operator op(*def);
|
||||
if (emitDecl) {
|
||||
|
|
Loading…
Reference in New Issue