forked from OSchip/llvm-project
Verify type of operands match those specifed in op registry.
Expand type to include matcher predicates. Use CNF form to allow specifying combinations of constraints for type. The matching call for the type is used to verify the construction of the operation as well as in rewrite pattern generation. The matching initially includes redundant checks (e.g., even if the operand of the op is guaranteed to satisfy some requirement, it is still checked during matcher generation for now). As well as some of the traits specified now check what the generated code already checks. Some of the traits can be removed in future as the verify method will include the relevant checks based on the op definition already. More work is needed for variadic operands. CNF form is used so that in the follow up redundant checks in the rewrite patterns could be omitted (e.g., when matching a F32Tensor, one does not need to verify that op X's operand 0 is a Tensor if that is guaranteed by op X's definition). The alternative was to have single matcher function specified, but this would not allow for reasoning about what attributes already hold (at the level of PredAtoms). Use this new operand type restrictions to rewrite BiasAdd with floating point operands as declarative pattern. PiperOrigin-RevId: 227991412
This commit is contained in:
parent
62dabbfd09
commit
8f24943826
|
@ -27,16 +27,38 @@
|
|||
// Types.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// Singular predicate condition.
|
||||
class PredAtom<code PredCall, bit Negated = 0> {
|
||||
// The function to invoke to compute the predicate.
|
||||
code predCall = PredCall;
|
||||
|
||||
// Whether the predicate result should be negated.
|
||||
bit negated = 0;
|
||||
}
|
||||
|
||||
// Predicate atoms in conjunctive normal form. The inner list consists
|
||||
// of PredAtoms, one of which in the list must hold, while all the outer
|
||||
// most conditions must hold. Conceptually
|
||||
// all_of(outer_conditions, any_of(inner_conditions)).
|
||||
class PredCNF<list<list<PredAtom>> Conditions> {
|
||||
list<list<PredAtom>> conditions = Conditions;
|
||||
}
|
||||
|
||||
// Base class for all types.
|
||||
class Type {
|
||||
// The builder call to invoke (if specified) to construct the Type.
|
||||
// Format: this will be affixed to the builder.
|
||||
code builderCall = ?;
|
||||
// The predicates that this type satisfies.
|
||||
// Format: {0} will be expanded to the type.
|
||||
PredCNF predicate = ?;
|
||||
}
|
||||
|
||||
// Integer types.
|
||||
class I<int width> : Type {
|
||||
int bitwidth = width;
|
||||
let builderCall = "getIntegerType(" # bitwidth # ")";
|
||||
let predicate = PredCNF<[[PredAtom<"{0}.isInteger(" # bitWidth # ")">]]>;
|
||||
}
|
||||
def I1 : I<1>;
|
||||
def I32 : I<32>;
|
||||
|
@ -45,24 +67,65 @@ def I32 : I<32>;
|
|||
class F<int width> : Type {
|
||||
int bitwidth = width;
|
||||
}
|
||||
def IsF32TypePred : PredAtom<"{0}.isF32()">;
|
||||
def F32 : F<32> {
|
||||
let builderCall = "getF32Type()";
|
||||
let predicate = PredCNF<[[IsF32TypePred]]>;
|
||||
}
|
||||
|
||||
// A container type is a type that has another embedded within it.
|
||||
class ContainerType<Type ElementType,
|
||||
list<list<PredAtom>> ContainerPred> : Type {
|
||||
// The type of elements in the container.
|
||||
Type elementType = ElementType;
|
||||
|
||||
// Call to retrieve.
|
||||
code getElementTypeCall = ?;
|
||||
|
||||
let predicate = PredCNF<
|
||||
!foldl(
|
||||
// Initialize with the predicate of the container.
|
||||
ContainerPred,
|
||||
elementType.predicate.conditions, a, b,
|
||||
// Add constraints of the element type.
|
||||
!listconcat(a, [!foldl([]<PredAtom>, b, c, d,
|
||||
!listconcat(c, [PredAtom<
|
||||
!subst("{0}", !cast<string>(getElementTypeCall),
|
||||
!cast<string>(d.predCall))>]
|
||||
))]
|
||||
)
|
||||
)
|
||||
>;
|
||||
}
|
||||
|
||||
def IsVectorTypePred : PredAtom<"{0}.isa<VectorType>()">;
|
||||
|
||||
// Vector types.
|
||||
class Vector<Type t, list<int> dims> : Type {
|
||||
Type elementType = t;
|
||||
class Vector<Type t, list<int> dims> : ContainerType<t, [[IsVectorTypePred]]> {
|
||||
list<int> dimensions = dims;
|
||||
let getElementTypeCall = "{0}.cast<VectorType>().getElementType()";
|
||||
// TODO: match dims in predicate.
|
||||
}
|
||||
|
||||
def IsTensorTypePred : PredAtom<"{0}.isa<TensorType>()">;
|
||||
|
||||
// Tensor type.
|
||||
class TypedTensor<Type t> : ContainerType<t, [[IsTensorTypePred]]> {
|
||||
let getElementTypeCall = "{0}.cast<TensorType>().getElementType()";
|
||||
}
|
||||
|
||||
// This represents a generic tensor without constraints on elemental type,
|
||||
// rank, size.
|
||||
def Tensor : Type;
|
||||
def Tensor : Type {
|
||||
let predicate = PredCNF<[[IsTensorTypePred]]>;
|
||||
}
|
||||
|
||||
def F32Tensor : TypedTensor<F32>;
|
||||
|
||||
// String type.
|
||||
def String : Type;
|
||||
|
||||
// Type corresponding to derived attribute.
|
||||
def DerivedAttrBody : Type;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -28,6 +28,7 @@
|
|||
#include "llvm/ADT/StringRef.h"
|
||||
|
||||
namespace llvm {
|
||||
class CodeInit;
|
||||
class DefInit;
|
||||
class Record;
|
||||
class StringInit;
|
||||
|
@ -69,6 +70,10 @@ public:
|
|||
|
||||
// Operations operand accessors.
|
||||
struct Operand {
|
||||
bool hasMatcher() const;
|
||||
// Return the matcher template for the operand type.
|
||||
std::string createTypeMatcherTemplate() const;
|
||||
|
||||
llvm::StringInit *name;
|
||||
llvm::DefInit *defInit;
|
||||
};
|
||||
|
@ -85,6 +90,7 @@ public:
|
|||
using Argument = llvm::PointerUnion<Attribute *, Operand *>;
|
||||
Argument getArg(int index);
|
||||
StringRef getArgName(int index) const;
|
||||
int getNumArgs() const { return operands.size() + attributes.size(); }
|
||||
|
||||
private:
|
||||
// Populates the operands and attributes.
|
||||
|
|
|
@ -0,0 +1,48 @@
|
|||
//===- Predicate.h - Predicate class ----------------------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Wrapper around predicates defined in TableGen.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_TABLEGEN_PREDICATE_H_
|
||||
#define MLIR_TABLEGEN_PREDICATE_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
|
||||
namespace llvm {
|
||||
class ListInit;
|
||||
} // end namespace llvm
|
||||
|
||||
namespace mlir {
|
||||
|
||||
// Predicate in conjunctive normal form.
|
||||
class PredCNF {
|
||||
public:
|
||||
PredCNF(llvm::ListInit *conditions) : conditions(conditions) {}
|
||||
|
||||
// Return template string to construct matcher corresponding to predicate in
|
||||
// CNF form with '{0}' representing the type.
|
||||
std::string createTypeMatcherTemplate() const;
|
||||
|
||||
private:
|
||||
llvm::ListInit *conditions;
|
||||
};
|
||||
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TABLEGEN_PREDICATE_H_
|
|
@ -20,6 +20,7 @@
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "mlir/TableGen/Predicate.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
|
@ -139,3 +140,15 @@ void Operator::populateOperandsAndAttributes() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
bool mlir::Operator::Operand::hasMatcher() const {
|
||||
llvm::Init *matcher = defInit->getDef()->getValue("predicate")->getValue();
|
||||
return !isa<llvm::UnsetInit>(matcher);
|
||||
}
|
||||
|
||||
std::string mlir::Operator::Operand::createTypeMatcherTemplate() const {
|
||||
auto predicate = defInit->getDef()->getValue("predicate")->getValue();
|
||||
auto predCnf = cast<llvm::DefInit>(predicate);
|
||||
PredCNF pred(predCnf->getDef()->getValueAsListInit("conditions"));
|
||||
return pred.createTypeMatcherTemplate();
|
||||
}
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
//===- Predicate.cpp - Predicate class ------------------------------------===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
//
|
||||
// Wrapper around predicates defined in TableGen.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#include "mlir/TableGen/Predicate.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "llvm/TableGen/Error.h"
|
||||
#include "llvm/TableGen/Record.h"
|
||||
|
||||
using namespace mlir;
|
||||
|
||||
std::string mlir::PredCNF::createTypeMatcherTemplate() const {
|
||||
std::string outString;
|
||||
llvm::raw_string_ostream ss(outString);
|
||||
auto conjunctiveList = conditions;
|
||||
bool firstDisjunctive = true;
|
||||
for (auto disjunctiveInit : *conjunctiveList) {
|
||||
ss << (firstDisjunctive ? "(" : " && (");
|
||||
firstDisjunctive = false;
|
||||
bool firstConjunctive = true;
|
||||
for (auto atom : *cast<llvm::ListInit>(disjunctiveInit)) {
|
||||
auto predAtom = cast<llvm::DefInit>(atom)->getDef();
|
||||
ss << (firstConjunctive ? "" : " || ")
|
||||
<< (predAtom->getValueAsBit("negated") ? "!" : "")
|
||||
<< predAtom->getValueAsString("predCall");
|
||||
firstConjunctive = false;
|
||||
}
|
||||
ss << ")";
|
||||
}
|
||||
ss.flush();
|
||||
return outString;
|
||||
}
|
|
@ -360,7 +360,7 @@ void OpEmitter::emitVerifier() {
|
|||
auto valueInit = def.getValueInit("verifier");
|
||||
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
|
||||
bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
|
||||
if (!hasCustomVerify && op.getNumAttributes() == 0)
|
||||
if (!hasCustomVerify && op.getNumArgs() == 0)
|
||||
return;
|
||||
|
||||
os << " bool verify() const {\n";
|
||||
|
@ -384,6 +384,20 @@ void OpEmitter::emitVerifier() {
|
|||
<< name << "'\");\n";
|
||||
}
|
||||
|
||||
// TODO: Handle variadic.
|
||||
int opIndex = 0;
|
||||
for (const auto &operand : op.getOperands()) {
|
||||
// TODO: Commonality between matchers could be extracted to have a more
|
||||
// concise code.
|
||||
if (operand.hasMatcher()) {
|
||||
auto pred =
|
||||
"if (!(" + operand.createTypeMatcherTemplate() + ")) return false;\n";
|
||||
os.indent(4) << formatv(pred, "this->getInstruction()->getOperand(" +
|
||||
Twine(opIndex) + ")->getType()");
|
||||
}
|
||||
++opIndex;
|
||||
}
|
||||
|
||||
if (hasCustomVerify)
|
||||
os << " " << codeInit->getValue() << "\n";
|
||||
else
|
||||
|
|
|
@ -21,6 +21,7 @@
|
|||
|
||||
#include "mlir/TableGen/GenInfo.h"
|
||||
#include "mlir/TableGen/Operator.h"
|
||||
#include "mlir/TableGen/Predicate.h"
|
||||
#include "llvm/ADT/StringExtras.h"
|
||||
#include "llvm/ADT/StringSet.h"
|
||||
#include "llvm/Support/CommandLine.h"
|
||||
|
@ -130,7 +131,8 @@ void Pattern::collectBoundArguments(DagInit *tree) {
|
|||
}
|
||||
|
||||
// Helper function to match patterns.
|
||||
static void matchOp(DagInit *tree, int depth, raw_ostream &os) {
|
||||
static void matchOp(Record *pattern, DagInit *tree, int depth,
|
||||
raw_ostream &os) {
|
||||
Operator op(cast<DefInit>(tree->getOperator())->getDef());
|
||||
int indent = 4 + 2 * depth;
|
||||
// Skip the operand matching at depth 0 as the pattern rewriter already does.
|
||||
|
@ -141,6 +143,11 @@ static void matchOp(DagInit *tree, int depth, raw_ostream &os) {
|
|||
"if (!op{0}->isa<{1}>()) return matchFailure();\n", depth,
|
||||
op.qualifiedCppClassName());
|
||||
}
|
||||
if (tree->getNumArgs() != op.getNumArgs())
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
Twine("mismatch in number of arguments to op '") +
|
||||
op.getOperationName() +
|
||||
"' in pattern and op's definition");
|
||||
for (int i = 0, e = tree->getNumArgs(); i != e; ++i) {
|
||||
auto arg = tree->getArg(i);
|
||||
if (auto argTree = dyn_cast<DagInit>(arg)) {
|
||||
|
@ -148,10 +155,40 @@ static void matchOp(DagInit *tree, int depth, raw_ostream &os) {
|
|||
os.indent(indent + 2) << formatv(
|
||||
"auto op{0} = op{1}->getOperand({2})->getDefiningInst();\n",
|
||||
depth + 1, depth, i);
|
||||
matchOp(argTree, depth + 1, os);
|
||||
matchOp(pattern, argTree, depth + 1, os);
|
||||
os.indent(indent) << "}\n";
|
||||
continue;
|
||||
}
|
||||
|
||||
// Verify arguments.
|
||||
if (auto defInit = dyn_cast<DefInit>(arg)) {
|
||||
auto opArg = op.getArg(i);
|
||||
// Verify operands.
|
||||
if (auto *operand = opArg.dyn_cast<Operator::Operand *>()) {
|
||||
// Skip verification where not needed due to definition of op.
|
||||
if (operand->defInit == defInit)
|
||||
goto SkipOperandVerification;
|
||||
|
||||
if (!defInit->getDef()->isSubClassOf("Type"))
|
||||
PrintFatalError(pattern->getLoc(),
|
||||
"type argument required for operand");
|
||||
|
||||
// TODO(jpienaar): Factor out type class and move these there.
|
||||
auto predicate = defInit->getDef()->getValue("predicate")->getValue();
|
||||
auto predCnf = cast<DefInit>(predicate);
|
||||
auto conjunctiveList =
|
||||
predCnf->getDef()->getValueAsListInit("conditions");
|
||||
PredCNF pred(conjunctiveList);
|
||||
os.indent(indent)
|
||||
<< "if (!("
|
||||
<< formatv(pred.createTypeMatcherTemplate().c_str(),
|
||||
formatv("op{0}->getOperand({1})->getType()", depth, i))
|
||||
<< ")) return matchFailure();\n";
|
||||
}
|
||||
}
|
||||
SkipOperandVerification:
|
||||
// TODO(jpienaar): Verify attributes.
|
||||
|
||||
auto name = tree->getArgNameStr(i);
|
||||
if (name.empty())
|
||||
continue;
|
||||
|
@ -168,7 +205,7 @@ void Pattern::emitMatcher(DagInit *tree) {
|
|||
if (op0->getNumResults() != 1) return matchFailure();
|
||||
auto state = std::make_unique<MatchedState>();)"
|
||||
<< "\n";
|
||||
matchOp(tree, 0, os);
|
||||
matchOp(pattern, tree, 0, os);
|
||||
os.indent(4) << "return matchSuccess(std::move(state));\n }\n";
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue