[ODS] Support region names and constraints

Similar to arguments and results, now we require region definition in ops to
    be specified as a DAG expression with the 'region' operator. This way we can
    specify the constraints for each region and optionally give the region a name.

    Two kinds of region constraints are added, one allowing any region, and the
    other requires a certain number of blocks.

--

PiperOrigin-RevId: 250790211
This commit is contained in:
Lei Zhang 2019-05-30 16:50:16 -07:00 committed by Mehdi Amini
parent 60d6249fbd
commit 3650df50dd
10 changed files with 228 additions and 40 deletions

View File

@ -135,6 +135,8 @@ class Concat<string pre, Pred child, string suf> :
// Constraint definitions
//===----------------------------------------------------------------------===//
// TODO(b/130064155): Merge Constraints into Pred.
// Base class for named constraints.
//
// An op's operands/attributes/results can have various requirements, e.g.,
@ -170,6 +172,10 @@ class TypeConstraint<Pred predicate, string description = ""> :
class AttrConstraint<Pred predicate, string description = ""> :
Constraint<predicate, description>;
// Subclass for constraints on a region.
class RegionConstraint<Pred predicate, string description = ""> :
Constraint<predicate, description>;
// How to use these constraint categories:
//
// * Use TypeConstraint to specify
@ -795,6 +801,21 @@ class IntArrayNthElemMinValue<int index, int min> : AttrConstraint<
def IsNullAttr : AttrConstraint<
CPred<"!$_self">, "empty attribute (for optional attributes)">;
//===----------------------------------------------------------------------===//
// Region definitions
//===----------------------------------------------------------------------===//
class Region<Pred condition, string descr = ""> :
RegionConstraint<condition, descr>;
// Any region.
def AnyRegion : Region<CPred<"true">, "any region">;
// A region with the given number of blocks.
class SizedRegion<int numBlocks> : Region<
CPred<"$_self.getBlocks().size() == " # numBlocks>,
"region with " # numBlocks # " blocks">;
//===----------------------------------------------------------------------===//
// OpTrait definitions
//===----------------------------------------------------------------------===//
@ -869,6 +890,9 @@ def ins;
// Marker used to identify the result list for an op.
def outs;
// Marker used to identify the region list for an op.
def region;
// Class for defining a custom builder.
//
// TableGen generates several generic builders for each op by default (see
@ -916,9 +940,8 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
// The list of results of the op. Default to 0 results.
dag results = (outs);
// How many regions this op has.
// TODO(b/133479568): Enhance to support advanced region usage cases
int numRegions = 0;
// The list of regions of the op. Default to 0 regions.
dag regions = (region);
// Attribute getters can be added to the op by adding an Attr member
// with the name and type of the attribute. E.g., adding int attribute

View File

@ -72,7 +72,7 @@ def SPV_ModuleOp : SPV_Op<"module", []> {
let results = (outs);
let numRegions = 1;
let regions = (region AnyRegion:$body);
// Custom parser and printer implemented by static functions in SPVOps.cpp
let parser = [{ return parseModule(parser, result); }];

View File

@ -57,7 +57,7 @@ public:
StringRef getDescription() const;
// Constraint kind
enum Kind { CK_Type, CK_Attr, CK_Uncategorized };
enum Kind { CK_Attr, CK_Region, CK_Type, CK_Uncategorized };
Kind getKind() const { return kind; }

View File

@ -27,6 +27,7 @@
#include "mlir/TableGen/Attribute.h"
#include "mlir/TableGen/Dialect.h"
#include "mlir/TableGen/OpTrait.h"
#include "mlir/TableGen/Region.h"
#include "mlir/TableGen/Type.h"
#include "llvm/ADT/PointerUnion.h"
#include "llvm/ADT/SmallVector.h"
@ -129,8 +130,15 @@ public:
// requiring the raw MLIR trait here.
bool hasTrait(llvm::StringRef trait) const;
using const_region_iterator = const NamedRegion *;
const_region_iterator region_begin() const;
const_region_iterator region_end() const;
llvm::iterator_range<const_region_iterator> getRegions() const;
// Returns the number of regions.
int getNumRegions() const;
unsigned getNumRegions() const;
// Returns the `index`-th region.
const NamedRegion &getRegion(unsigned index) const;
// Trait.
using const_trait_iterator = const OpTrait *;
@ -177,8 +185,8 @@ private:
// The traits of the op.
SmallVector<OpTrait, 4> traits;
// The number of regions of this op.
int numRegions = 0;
// The regions of this op.
SmallVector<NamedRegion, 1> regions;
// The number of native attributes stored in the leading positions of
// `attributes`.

View File

@ -0,0 +1,45 @@
//===- TGRegion.h - TableGen region definitions -----------------*- C++ -*-===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#ifndef MLIR_TABLEGEN_REGION_H_
#define MLIR_TABLEGEN_REGION_H_
#include "mlir/Support/LLVM.h"
#include "mlir/TableGen/Constraint.h"
namespace mlir {
namespace tblgen {
// Wrapper class providing helper methods for accessing Region defined in
// TableGen.
class Region : public Constraint {
public:
using Constraint::Constraint;
static bool classof(const Constraint *c) { return c->getKind() == CK_Region; }
};
// A struct bundling a region's constraint and its name.
struct NamedRegion {
StringRef name;
Region constraint;
};
} // end namespace tblgen
} // end namespace mlir
#endif // MLIR_TABLEGEN_REGION_H_

View File

@ -30,6 +30,8 @@ Constraint::Constraint(const llvm::Record *record)
kind = CK_Type;
} else if (record->isSubClassOf("AttrConstraint")) {
kind = CK_Attr;
} else if (record->isSubClassOf("RegionConstraint")) {
kind = CK_Region;
} else {
assert(record->isSubClassOf("Constraint"));
}

View File

@ -146,7 +146,24 @@ bool tblgen::Operator::hasTrait(StringRef trait) const {
return false;
}
int tblgen::Operator::getNumRegions() const { return numRegions; }
tblgen::Operator::const_region_iterator tblgen::Operator::region_begin() const {
return regions.begin();
}
tblgen::Operator::const_region_iterator tblgen::Operator::region_end() const {
return regions.end();
}
llvm::iterator_range<tblgen::Operator::const_region_iterator>
tblgen::Operator::getRegions() const {
return {region_begin(), region_end()};
}
unsigned tblgen::Operator::getNumRegions() const { return regions.size(); }
const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const {
return regions[index];
}
auto tblgen::Operator::trait_begin() const -> const_trait_iterator {
return traits.begin();
@ -269,9 +286,21 @@ void tblgen::Operator::populateOpStructure() {
traits.push_back(OpTrait::create(traitInit));
// Handle regions
numRegions = def.getValueAsInt("numRegions");
if (numRegions < 0)
PrintFatalError(def.getLoc(), "numRegions cannot be negative");
auto *regionsDag = def.getValueAsDag("regions");
auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
if (!regionsOp || regionsOp->getDef()->getName() != "region") {
PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
}
for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
auto name = regionsDag->getArgNameStr(i);
auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
if (!regionInit) {
PrintFatalError(def.getLoc(),
Twine("undefined kind for region #") + Twine(i));
}
regions.push_back({name, Region(regionInit->getDef())});
}
}
ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }

View File

@ -1,5 +1,9 @@
// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
//===----------------------------------------------------------------------===//
// Test the number of regions
//===----------------------------------------------------------------------===//
func @correct_number_of_regions() {
// CHECK: test.two_region_op
"test.two_region_op"()(
@ -11,7 +15,7 @@ func @correct_number_of_regions() {
// -----
func @missingk_regions() {
func @missing_regions() {
// expected-error@+1 {{op has incorrect number of regions: expected 2 but found 1}}
"test.two_region_op"()(
{"work"() : () -> ()}
@ -30,3 +34,42 @@ func @extra_regions() {
) : () -> ()
return
}
// -----
//===----------------------------------------------------------------------===//
// Test SizedRegion
//===----------------------------------------------------------------------===//
func @unnamed_region_has_wrong_number_of_blocks() {
// expected-error@+1 {{region #1 failed to verify constraint: region with 1 blocks}}
"test.sized_region_op"() (
{
"work"() : () -> ()
br ^next1
^next1:
"work"() : () -> ()
},
{
"work"() : () -> ()
br ^next2
^next2:
"work"() : () -> ()
}) : () -> ()
return
}
// -----
// Test region name in error message
func @named_region_has_wrong_number_of_blocks() {
// expected-error@+1 {{region #0 ('my_region') failed to verify constraint: region with 2 blocks}}
"test.sized_region_op"() (
{
"work"() : () -> ()
},
{
"work"() : () -> ()
}) : () -> ()
return
}

View File

@ -129,7 +129,11 @@ def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr<I32Attr, "34">:$attr)>;
//===----------------------------------------------------------------------===//
def TwoRegionOp : TEST_Op<"two_region_op", []> {
let numRegions = 2;
let regions = (region AnyRegion, AnyRegion);
}
def SizedRegionOp : TEST_Op<"sized_region_op", []> {
let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>);
}
#endif // TEST_OPS

View File

@ -368,6 +368,10 @@ private:
// Generates verify method for the operation.
void genVerifier();
// Generates verify statements for regions in the operation.
// The generated code will be attached to `body`.
void genRegionVerifier(OpMethodBody &body);
// Generates the traits used by the object.
void genTraits();
@ -388,12 +392,17 @@ private:
// The C++ code builder for this op
OpClass opClass;
// The format context for verification code generation.
FmtContext verifyCtx;
};
} // end anonymous namespace
OpEmitter::OpEmitter(const Record &def)
: def(def), op(def),
opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
verifyCtx.withOp("(*this->getOperation())");
genTraits();
// Generate C++ code for various op methods. The order here determines the
// methods in the generated file.
@ -900,12 +909,10 @@ void OpEmitter::genVerifier() {
auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
auto &body = method.body();
FmtContext fctx;
fctx.withOp("(*this->getOperation())");
// Populate substitutions for attributes and named operands and results.
for (const auto &namedAttr : op.getAttributes())
fctx.addSubst(namedAttr.name,
verifyCtx.addSubst(namedAttr.name,
formatv("(&this->getAttr(\"{0}\"))", namedAttr.name));
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
auto &value = op.getOperand(i);
@ -914,7 +921,7 @@ void OpEmitter::genVerifier() {
if (value.isVariadic())
break;
if (!value.name.empty())
fctx.addSubst(value.name,
verifyCtx.addSubst(value.name,
formatv("this->getOperation()->getOperand({0})", i));
}
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
@ -924,7 +931,7 @@ void OpEmitter::genVerifier() {
if (value.isVariadic())
break;
if (!value.name.empty())
fctx.addSubst(value.name,
verifyCtx.addSubst(value.name,
formatv("this->getOperation()->getResult({0})", i));
}
@ -955,10 +962,11 @@ void OpEmitter::genVerifier() {
auto attrPred = attr.getPredicate();
if (!attrPred.isNull()) {
body << tgfmt(" if (!($0)) return emitOpError(\"attribute '$1' "
body << tgfmt(
" if (!($0)) return emitOpError(\"attribute '$1' "
"failed to satisfy constraint: $2\");\n",
/*ctx=*/nullptr,
tgfmt(attrPred.getCondition(), &fctx.withSelf(varName)),
tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)),
attrName, attr.getDescription());
}
@ -977,8 +985,9 @@ void OpEmitter::genVerifier() {
if (value.hasPredicate()) {
auto description = value.constraint.getDescription();
body << " if (!("
<< tgfmt(value.constraint.getConditionTemplate(),
&fctx.withSelf("this->getOperation()->get" +
<< tgfmt(
value.constraint.getConditionTemplate(),
&verifyCtx.withSelf("this->getOperation()->get" +
Twine(isOperand ? "Operand" : "Result") +
"(" + Twine(index) + ")->getType()"))
<< ")) {\n";
@ -1000,19 +1009,14 @@ void OpEmitter::genVerifier() {
for (auto &trait : op.getTraits()) {
if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
body << tgfmt(" if (!($0))\n return emitOpError(\""
"failed to verify that $1\");\n",
&fctx, tgfmt(t->getPredTemplate(), &fctx),
body << tgfmt(" if (!($0)) {\n "
"return emitOpError(\"failed to verify that $1\");\n }\n",
&verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
t->getDescription());
}
}
// Verify this op has the correct number of regions
body << formatv(
" if (this->getOperation()->getNumRegions() != {0}) \n return "
"emitOpError(\"has incorrect number of regions: expected {0} but found "
"\") << this->getOperation()->getNumRegions();\n",
op.getNumRegions());
genRegionVerifier(body);
if (hasCustomVerify)
body << codeInit->getValue() << "\n";
@ -1020,6 +1024,36 @@ void OpEmitter::genVerifier() {
body << " return mlir::success();\n";
}
void OpEmitter::genRegionVerifier(OpMethodBody &body) {
unsigned numRegions = op.getNumRegions();
// Verify this op has the correct number of regions
body << formatv(
" if (this->getOperation()->getNumRegions() != {0}) {\n "
"return emitOpError(\"has incorrect number of regions: expected {0} but "
"found \") << this->getOperation()->getNumRegions();\n }\n",
numRegions);
for (unsigned i = 0; i < numRegions; ++i) {
const auto &region = op.getRegion(i);
std::string name = formatv("#{0}", i);
if (!region.name.empty()) {
name += formatv(" ('{0}')", region.name);
}
auto getRegion = formatv("this->getOperation()->getRegion({0})", i).str();
auto constraint = tgfmt(region.constraint.getConditionTemplate(),
&verifyCtx.withSelf(getRegion))
.str();
body << formatv(" if (!({0})) {\n "
"return emitOpError(\"region {1} failed to verify "
"constraint: {2}\");\n }\n",
constraint, name, region.constraint.getDescription());
}
}
void OpEmitter::genTraits() {
int numResults = op.getNumResults();
int numVariadicResults = op.getNumVariadicResults();