[mlir-tblgen] Minor Refactor for StaticVerifierFunctionEmitter.

Move StaticVerifierFunctionEmitter to CodeGenHelper.h so that it can be
used for both ODS and DRR.

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D106636
This commit is contained in:
Chia-hung Duan 2021-08-12 17:35:00 +00:00
parent 15497e62f6
commit 62df4df41c
4 changed files with 238 additions and 172 deletions

View File

@ -13,13 +13,21 @@
#ifndef MLIR_TABLEGEN_CODEGENHELPERS_H
#define MLIR_TABLEGEN_CODEGENHELPERS_H
#include "mlir/Support/IndentedOstream.h"
#include "mlir/TableGen/Dialect.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/ADT/StringRef.h"
namespace llvm {
class RecordKeeper;
} // namespace llvm
namespace mlir {
namespace tblgen {
class Constraint;
// Simple RAII helper for defining ifdef-undef-endif scopes.
class IfDefScope {
public:
@ -62,6 +70,82 @@ private:
SmallVector<StringRef, 2> namespaces;
};
/// This class deduplicates shared operation verification code by emitting
/// static functions alongside the op definitions. These methods are local to
/// the definition file, and are invoked within the operation verify methods.
/// An example is shown below:
///
/// static LogicalResult localVerify(...)
///
/// LogicalResult OpA::verify(...) {
/// if (failed(localVerify(...)))
/// return failure();
/// ...
/// }
///
/// LogicalResult OpB::verify(...) {
/// if (failed(localVerify(...)))
/// return failure();
/// ...
/// }
///
class StaticVerifierFunctionEmitter {
public:
StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
raw_ostream &os);
/// Emit the static verifier functions for `llvm::Record`s. The
/// `signatureFormat` describes the required arguments and it must have a
/// placeholder for function name.
/// Example,
/// const char *typeVerifierSignature =
/// "static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type"
/// " type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
///
/// `errorHandlerFormat` describes the error message to return. It may have a
/// placeholder for the summary of Constraint and bring more information for
/// the error message.
/// Example,
/// const char *typeVerifierErrorHandler =
/// " op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << "
/// "\" must be {0}, but got \" << type";
///
/// `typeArgName` is used to identify the argument that needs to check its
/// type. The constraint template will replace `$_self` with it.
void emitFunctionsFor(StringRef signatureFormat, StringRef errorHandlerFormat,
StringRef typeArgName, ArrayRef<llvm::Record *> opDefs,
bool emitDecl);
/// Get the name of the local function used for the given type constraint.
/// These functions are used for operand and result constraints and have the
/// form:
/// LogicalResult(Operation *op, Type type, StringRef valueKind,
/// unsigned valueGroupStartIndex);
StringRef getTypeConstraintFn(const Constraint &constraint) const;
private:
/// Returns a unique name to use when generating local methods.
static std::string getUniqueName(const llvm::RecordKeeper &records);
/// Emit local methods for the type constraints used within the provided op
/// definitions.
void emitTypeConstraintMethods(StringRef signatureFormat,
StringRef errorHandlerFormat,
StringRef typeArgName,
ArrayRef<llvm::Record *> opDefs,
bool emitDecl);
raw_indented_ostream os;
/// A unique label for the file currently being generated. This is used to
/// ensure that the local functions have a unique name.
std::string uniqueOutputLabel;
/// A set of functions implementing type constraints, used for operand and
/// result verification.
llvm::DenseMap<const void *, std::string> localTypeConstraints;
};
} // namespace tblgen
} // namespace mlir

View File

@ -6,6 +6,7 @@ set(LLVM_LINK_COMPONENTS
add_tablegen(mlir-tblgen MLIR
AttrOrTypeDefGen.cpp
CodeGenHelpers.cpp
DialectGen.cpp
DirectiveCommonGen.cpp
EnumsGen.cpp

View File

@ -0,0 +1,139 @@
//===- CodeGenHelpers.cpp - MLIR op definitions generator ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// OpDefinitionsGen uses the description of operations to generate C++
// definitions for ops.
//
//===----------------------------------------------------------------------===//
#include "mlir/TableGen/CodeGenHelpers.h"
#include "mlir/TableGen/Format.h"
#include "mlir/TableGen/Operator.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/Path.h"
#include "llvm/TableGen/Record.h"
using namespace llvm;
using namespace mlir;
using namespace mlir::tblgen;
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
const llvm::RecordKeeper &records, raw_ostream &os)
: os(os), uniqueOutputLabel(getUniqueName(records)) {}
void StaticVerifierFunctionEmitter::emitFunctionsFor(
StringRef signatureFormat, StringRef errorHandlerFormat,
StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
llvm::Optional<NamespaceEmitter> namespaceEmitter;
if (!emitDecl)
namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
emitTypeConstraintMethods(signatureFormat, errorHandlerFormat, typeArgName,
opDefs, emitDecl);
}
StringRef StaticVerifierFunctionEmitter::getTypeConstraintFn(
const Constraint &constraint) const {
auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
assert(it != localTypeConstraints.end() && "expected valid constraint fn");
return it->second;
}
std::string StaticVerifierFunctionEmitter::getUniqueName(
const llvm::RecordKeeper &records) {
// Use the input file name when generating a unique name.
std::string inputFilename = records.getInputFilename();
// Drop all but the base filename.
StringRef nameRef = llvm::sys::path::filename(inputFilename);
nameRef.consume_back(".td");
// Sanitize any invalid characters.
std::string uniqueName;
for (char c : nameRef) {
if (llvm::isAlnum(c) || c == '_')
uniqueName.push_back(c);
else
uniqueName.append(llvm::utohexstr((unsigned char)c));
}
return uniqueName;
}
void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
StringRef signatureFormat, StringRef errorHandlerFormat,
StringRef typeArgName, ArrayRef<llvm::Record *> opDefs, bool emitDecl) {
// Collect a set of all of the used type constraints within the operation
// definitions.
llvm::SetVector<const void *> typeConstraints;
for (Record *def : opDefs) {
Operator op(*def);
for (NamedTypeConstraint &operand : op.getOperands())
if (operand.hasPredicate())
typeConstraints.insert(operand.constraint.getAsOpaquePointer());
for (NamedTypeConstraint &result : op.getResults())
if (result.hasPredicate())
typeConstraints.insert(result.constraint.getAsOpaquePointer());
}
// Record the mapping from predicate to constraint. If two constraints has the
// same predicate and constraint summary, they can share the same verification
// function.
llvm::DenseMap<Pred, const void *> predToConstraint;
FmtContext fctx;
for (auto it : llvm::enumerate(typeConstraints)) {
std::string name;
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
Pred pred = constraint.getPredicate();
auto iter = predToConstraint.find(pred);
if (iter != predToConstraint.end()) {
do {
Constraint built = Constraint::getFromOpaquePointer(iter->second);
// We may have the different constraints but have the same predicate,
// for example, ConstraintA and Variadic<ConstraintA>, note that
// Variadic<> doesn't introduce new predicate. In this case, we can
// share the same predicate function if they also have consistent
// summary, otherwise we may report the wrong message while verification
// fails.
if (constraint.getSummary() == built.getSummary()) {
name = getTypeConstraintFn(built).str();
break;
}
++iter;
} while (iter != predToConstraint.end() && iter->first == pred);
}
if (!name.empty()) {
localTypeConstraints.try_emplace(it.value(), name);
continue;
}
// Generate an obscure and unique name for this type constraint.
name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
Twine(it.index()))
.str();
predToConstraint.insert(
std::make_pair(constraint.getPredicate(), it.value()));
localTypeConstraints.try_emplace(it.value(), name);
// Only generate the methods if we are generating definitions.
if (emitDecl)
continue;
os << formatv(signatureFormat.data(), name) << " {\n";
os.indent() << "if (!("
<< tgfmt(constraint.getConditionTemplate(),
&fctx.withSelf(typeArgName))
<< ")) {\n";
os.indent() << "return "
<< formatv(errorHandlerFormat.data(), constraint.getSummary())
<< ";\n";
os.unindent() << "}\nreturn ::mlir::success();\n";
os.unindent() << "}\n\n";
}
}

View File

@ -24,7 +24,6 @@
#include "llvm/ADT/MapVector.h"
#include "llvm/ADT/Sequence.h"
#include "llvm/ADT/StringExtras.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/Signals.h"
#include "llvm/TableGen/Error.h"
#include "llvm/TableGen/Record.h"
@ -101,6 +100,14 @@ const char *valueRangeReturnCode = R"(
std::next({0}, valueRange.first + valueRange.second)};
)";
const char *typeVerifierSignature =
"static ::mlir::LogicalResult {0}(::mlir::Operation *op, ::mlir::Type "
"type, ::llvm::StringRef valueKind, unsigned valueGroupStartIndex)";
const char *typeVerifierErrorHandler =
" op->emitOpError(valueKind) << \" #\" << valueGroupStartIndex << \" must "
"be {0}, but got \" << type";
static const char *const opCommentHeader = R"(
//===----------------------------------------------------------------------===//
// {0} {1}
@ -108,175 +115,6 @@ static const char *const opCommentHeader = R"(
)";
//===----------------------------------------------------------------------===//
// StaticVerifierFunctionEmitter
//===----------------------------------------------------------------------===//
namespace {
/// This class deduplicates shared operation verification code by emitting
/// static functions alongside the op definitions. These methods are local to
/// the definition file, and are invoked within the operation verify methods.
/// An example is shown below:
///
/// static LogicalResult localVerify(...)
///
/// LogicalResult OpA::verify(...) {
/// if (failed(localVerify(...)))
/// return failure();
/// ...
/// }
///
/// LogicalResult OpB::verify(...) {
/// if (failed(localVerify(...)))
/// return failure();
/// ...
/// }
///
class StaticVerifierFunctionEmitter {
public:
StaticVerifierFunctionEmitter(const llvm::RecordKeeper &records,
ArrayRef<llvm::Record *> opDefs,
raw_ostream &os, bool emitDecl);
/// Get the name of the local function used for the given type constraint.
/// These functions are used for operand and result constraints and have the
/// form:
/// LogicalResult(Operation *op, Type type, StringRef valueKind,
/// unsigned valueGroupStartIndex);
StringRef getTypeConstraintFn(const Constraint &constraint) const {
auto it = localTypeConstraints.find(constraint.getAsOpaquePointer());
assert(it != localTypeConstraints.end() && "expected valid constraint fn");
return it->second;
}
private:
/// Returns a unique name to use when generating local methods.
static std::string getUniqueName(const llvm::RecordKeeper &records);
/// Emit local methods for the type constraints used within the provided op
/// definitions.
void emitTypeConstraintMethods(ArrayRef<llvm::Record *> opDefs,
raw_ostream &os, bool emitDecl);
/// A unique label for the file currently being generated. This is used to
/// ensure that the local functions have a unique name.
std::string uniqueOutputLabel;
/// A set of functions implementing type constraints, used for operand and
/// result verification.
llvm::DenseMap<const void *, std::string> localTypeConstraints;
};
} // namespace
StaticVerifierFunctionEmitter::StaticVerifierFunctionEmitter(
const llvm::RecordKeeper &records, ArrayRef<llvm::Record *> opDefs,
raw_ostream &os, bool emitDecl)
: uniqueOutputLabel(getUniqueName(records)) {
llvm::Optional<NamespaceEmitter> namespaceEmitter;
if (!emitDecl) {
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
namespaceEmitter.emplace(os, Operator(*opDefs[0]).getCppNamespace());
}
emitTypeConstraintMethods(opDefs, os, emitDecl);
}
std::string StaticVerifierFunctionEmitter::getUniqueName(
const llvm::RecordKeeper &records) {
// Use the input file name when generating a unique name.
std::string inputFilename = records.getInputFilename();
// Drop all but the base filename.
StringRef nameRef = llvm::sys::path::filename(inputFilename);
nameRef.consume_back(".td");
// Sanitize any invalid characters.
std::string uniqueName;
for (char c : nameRef) {
if (llvm::isAlnum(c) || c == '_')
uniqueName.push_back(c);
else
uniqueName.append(llvm::utohexstr((unsigned char)c));
}
return uniqueName;
}
void StaticVerifierFunctionEmitter::emitTypeConstraintMethods(
ArrayRef<llvm::Record *> opDefs, raw_ostream &os, bool emitDecl) {
// Collect a set of all of the used type constraints within the operation
// definitions.
llvm::SetVector<const void *> typeConstraints;
for (Record *def : opDefs) {
Operator op(*def);
for (NamedTypeConstraint &operand : op.getOperands())
if (operand.hasPredicate())
typeConstraints.insert(operand.constraint.getAsOpaquePointer());
for (NamedTypeConstraint &result : op.getResults())
if (result.hasPredicate())
typeConstraints.insert(result.constraint.getAsOpaquePointer());
}
// Record the mapping from predicate to constraint. If two constraints has the
// same predicate and constraint summary, they can share the same verification
// function.
llvm::DenseMap<Pred, const void *> predToConstraint;
FmtContext fctx;
for (auto it : llvm::enumerate(typeConstraints)) {
std::string name;
Constraint constraint = Constraint::getFromOpaquePointer(it.value());
Pred pred = constraint.getPredicate();
auto iter = predToConstraint.find(pred);
if (iter != predToConstraint.end()) {
do {
Constraint built = Constraint::getFromOpaquePointer(iter->second);
// We may have the different constraints but have the same predicate,
// for example, ConstraintA and Variadic<ConstraintA>, note that
// Variadic<> doesn't introduce new predicate. In this case, we can
// share the same predicate function if they also have consistent
// summary, otherwise we may report the wrong message while verification
// fails.
if (constraint.getSummary() == built.getSummary()) {
name = getTypeConstraintFn(built).str();
break;
}
++iter;
} while (iter != predToConstraint.end() && iter->first == pred);
}
if (!name.empty()) {
localTypeConstraints.try_emplace(it.value(), name);
continue;
}
// Generate an obscure and unique name for this type constraint.
name = (Twine("__mlir_ods_local_type_constraint_") + uniqueOutputLabel +
Twine(it.index()))
.str();
predToConstraint.insert(
std::make_pair(constraint.getPredicate(), it.value()));
localTypeConstraints.try_emplace(it.value(), name);
// Only generate the methods if we are generating definitions.
if (emitDecl)
continue;
os << "static ::mlir::LogicalResult " << name
<< "(::mlir::Operation *op, ::mlir::Type type, ::llvm::StringRef "
"valueKind, unsigned valueGroupStartIndex) {\n";
os << " if (!("
<< tgfmt(constraint.getConditionTemplate(), &fctx.withSelf("type"))
<< ")) {\n"
<< formatv(
" return op->emitOpError(valueKind) << \" #\" << "
"valueGroupStartIndex << \" must be {0}, but got \" << type;\n",
constraint.getSummary())
<< " }\n"
<< " return ::mlir::success();\n"
<< "}\n\n";
}
}
//===----------------------------------------------------------------------===//
// Utility structs and functions
//===----------------------------------------------------------------------===//
@ -2560,8 +2398,12 @@ static void emitOpClasses(const RecordKeeper &recordKeeper,
return;
// Generate all of the locally instantiated methods first.
StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, defs, os,
emitDecl);
StaticVerifierFunctionEmitter staticVerifierEmitter(recordKeeper, os);
os << formatv(opCommentHeader, "Local Utility Method", "Definitions");
staticVerifierEmitter.emitFunctionsFor(
typeVerifierSignature, typeVerifierErrorHandler, /*typeArgName=*/"type",
defs, emitDecl);
for (auto *def : defs) {
Operator op(*def);
if (emitDecl) {