forked from OSchip/llvm-project
[ODS] Support region names and constraints
Similar to arguments and results, now we require region definition in ops to be specified as a DAG expression with the 'region' operator. This way we can specify the constraints for each region and optionally give the region a name. Two kinds of region constraints are added, one allowing any region, and the other requires a certain number of blocks. -- PiperOrigin-RevId: 250790211
This commit is contained in:
parent
60d6249fbd
commit
3650df50dd
|
@ -135,6 +135,8 @@ class Concat<string pre, Pred child, string suf> :
|
|||
// Constraint definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// TODO(b/130064155): Merge Constraints into Pred.
|
||||
|
||||
// Base class for named constraints.
|
||||
//
|
||||
// An op's operands/attributes/results can have various requirements, e.g.,
|
||||
|
@ -170,6 +172,10 @@ class TypeConstraint<Pred predicate, string description = ""> :
|
|||
class AttrConstraint<Pred predicate, string description = ""> :
|
||||
Constraint<predicate, description>;
|
||||
|
||||
// Subclass for constraints on a region.
|
||||
class RegionConstraint<Pred predicate, string description = ""> :
|
||||
Constraint<predicate, description>;
|
||||
|
||||
// How to use these constraint categories:
|
||||
//
|
||||
// * Use TypeConstraint to specify
|
||||
|
@ -795,6 +801,21 @@ class IntArrayNthElemMinValue<int index, int min> : AttrConstraint<
|
|||
def IsNullAttr : AttrConstraint<
|
||||
CPred<"!$_self">, "empty attribute (for optional attributes)">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Region definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
class Region<Pred condition, string descr = ""> :
|
||||
RegionConstraint<condition, descr>;
|
||||
|
||||
// Any region.
|
||||
def AnyRegion : Region<CPred<"true">, "any region">;
|
||||
|
||||
// A region with the given number of blocks.
|
||||
class SizedRegion<int numBlocks> : Region<
|
||||
CPred<"$_self.getBlocks().size() == " # numBlocks>,
|
||||
"region with " # numBlocks # " blocks">;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpTrait definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -869,6 +890,9 @@ def ins;
|
|||
// Marker used to identify the result list for an op.
|
||||
def outs;
|
||||
|
||||
// Marker used to identify the region list for an op.
|
||||
def region;
|
||||
|
||||
// Class for defining a custom builder.
|
||||
//
|
||||
// TableGen generates several generic builders for each op by default (see
|
||||
|
@ -916,9 +940,8 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
|
|||
// The list of results of the op. Default to 0 results.
|
||||
dag results = (outs);
|
||||
|
||||
// How many regions this op has.
|
||||
// TODO(b/133479568): Enhance to support advanced region usage cases
|
||||
int numRegions = 0;
|
||||
// The list of regions of the op. Default to 0 regions.
|
||||
dag regions = (region);
|
||||
|
||||
// Attribute getters can be added to the op by adding an Attr member
|
||||
// with the name and type of the attribute. E.g., adding int attribute
|
||||
|
|
|
@ -72,7 +72,7 @@ def SPV_ModuleOp : SPV_Op<"module", []> {
|
|||
|
||||
let results = (outs);
|
||||
|
||||
let numRegions = 1;
|
||||
let regions = (region AnyRegion:$body);
|
||||
|
||||
// Custom parser and printer implemented by static functions in SPVOps.cpp
|
||||
let parser = [{ return parseModule(parser, result); }];
|
||||
|
|
|
@ -57,7 +57,7 @@ public:
|
|||
StringRef getDescription() const;
|
||||
|
||||
// Constraint kind
|
||||
enum Kind { CK_Type, CK_Attr, CK_Uncategorized };
|
||||
enum Kind { CK_Attr, CK_Region, CK_Type, CK_Uncategorized };
|
||||
|
||||
Kind getKind() const { return kind; }
|
||||
|
||||
|
|
|
@ -27,6 +27,7 @@
|
|||
#include "mlir/TableGen/Attribute.h"
|
||||
#include "mlir/TableGen/Dialect.h"
|
||||
#include "mlir/TableGen/OpTrait.h"
|
||||
#include "mlir/TableGen/Region.h"
|
||||
#include "mlir/TableGen/Type.h"
|
||||
#include "llvm/ADT/PointerUnion.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
|
@ -129,8 +130,15 @@ public:
|
|||
// requiring the raw MLIR trait here.
|
||||
bool hasTrait(llvm::StringRef trait) const;
|
||||
|
||||
using const_region_iterator = const NamedRegion *;
|
||||
const_region_iterator region_begin() const;
|
||||
const_region_iterator region_end() const;
|
||||
llvm::iterator_range<const_region_iterator> getRegions() const;
|
||||
|
||||
// Returns the number of regions.
|
||||
int getNumRegions() const;
|
||||
unsigned getNumRegions() const;
|
||||
// Returns the `index`-th region.
|
||||
const NamedRegion &getRegion(unsigned index) const;
|
||||
|
||||
// Trait.
|
||||
using const_trait_iterator = const OpTrait *;
|
||||
|
@ -177,8 +185,8 @@ private:
|
|||
// The traits of the op.
|
||||
SmallVector<OpTrait, 4> traits;
|
||||
|
||||
// The number of regions of this op.
|
||||
int numRegions = 0;
|
||||
// The regions of this op.
|
||||
SmallVector<NamedRegion, 1> regions;
|
||||
|
||||
// The number of native attributes stored in the leading positions of
|
||||
// `attributes`.
|
||||
|
|
|
@ -0,0 +1,45 @@
|
|||
//===- TGRegion.h - TableGen region definitions -----------------*- C++ -*-===//
|
||||
//
|
||||
// Copyright 2019 The MLIR Authors.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
// =============================================================================
|
||||
|
||||
#ifndef MLIR_TABLEGEN_REGION_H_
|
||||
#define MLIR_TABLEGEN_REGION_H_
|
||||
|
||||
#include "mlir/Support/LLVM.h"
|
||||
#include "mlir/TableGen/Constraint.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace tblgen {
|
||||
|
||||
// Wrapper class providing helper methods for accessing Region defined in
|
||||
// TableGen.
|
||||
class Region : public Constraint {
|
||||
public:
|
||||
using Constraint::Constraint;
|
||||
|
||||
static bool classof(const Constraint *c) { return c->getKind() == CK_Region; }
|
||||
};
|
||||
|
||||
// A struct bundling a region's constraint and its name.
|
||||
struct NamedRegion {
|
||||
StringRef name;
|
||||
Region constraint;
|
||||
};
|
||||
|
||||
} // end namespace tblgen
|
||||
} // end namespace mlir
|
||||
|
||||
#endif // MLIR_TABLEGEN_REGION_H_
|
|
@ -30,6 +30,8 @@ Constraint::Constraint(const llvm::Record *record)
|
|||
kind = CK_Type;
|
||||
} else if (record->isSubClassOf("AttrConstraint")) {
|
||||
kind = CK_Attr;
|
||||
} else if (record->isSubClassOf("RegionConstraint")) {
|
||||
kind = CK_Region;
|
||||
} else {
|
||||
assert(record->isSubClassOf("Constraint"));
|
||||
}
|
||||
|
|
|
@ -146,7 +146,24 @@ bool tblgen::Operator::hasTrait(StringRef trait) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
int tblgen::Operator::getNumRegions() const { return numRegions; }
|
||||
tblgen::Operator::const_region_iterator tblgen::Operator::region_begin() const {
|
||||
return regions.begin();
|
||||
}
|
||||
|
||||
tblgen::Operator::const_region_iterator tblgen::Operator::region_end() const {
|
||||
return regions.end();
|
||||
}
|
||||
|
||||
llvm::iterator_range<tblgen::Operator::const_region_iterator>
|
||||
tblgen::Operator::getRegions() const {
|
||||
return {region_begin(), region_end()};
|
||||
}
|
||||
|
||||
unsigned tblgen::Operator::getNumRegions() const { return regions.size(); }
|
||||
|
||||
const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const {
|
||||
return regions[index];
|
||||
}
|
||||
|
||||
auto tblgen::Operator::trait_begin() const -> const_trait_iterator {
|
||||
return traits.begin();
|
||||
|
@ -269,9 +286,21 @@ void tblgen::Operator::populateOpStructure() {
|
|||
traits.push_back(OpTrait::create(traitInit));
|
||||
|
||||
// Handle regions
|
||||
numRegions = def.getValueAsInt("numRegions");
|
||||
if (numRegions < 0)
|
||||
PrintFatalError(def.getLoc(), "numRegions cannot be negative");
|
||||
auto *regionsDag = def.getValueAsDag("regions");
|
||||
auto *regionsOp = dyn_cast<DefInit>(regionsDag->getOperator());
|
||||
if (!regionsOp || regionsOp->getDef()->getName() != "region") {
|
||||
PrintFatalError(def.getLoc(), "'regions' must have 'region' directive");
|
||||
}
|
||||
|
||||
for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) {
|
||||
auto name = regionsDag->getArgNameStr(i);
|
||||
auto *regionInit = dyn_cast<DefInit>(regionsDag->getArg(i));
|
||||
if (!regionInit) {
|
||||
PrintFatalError(def.getLoc(),
|
||||
Twine("undefined kind for region #") + Twine(i));
|
||||
}
|
||||
regions.push_back({name, Region(regionInit->getDef())});
|
||||
}
|
||||
}
|
||||
|
||||
ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }
|
||||
|
|
|
@ -1,5 +1,9 @@
|
|||
// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test the number of regions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @correct_number_of_regions() {
|
||||
// CHECK: test.two_region_op
|
||||
"test.two_region_op"()(
|
||||
|
@ -11,7 +15,7 @@ func @correct_number_of_regions() {
|
|||
|
||||
// -----
|
||||
|
||||
func @missingk_regions() {
|
||||
func @missing_regions() {
|
||||
// expected-error@+1 {{op has incorrect number of regions: expected 2 but found 1}}
|
||||
"test.two_region_op"()(
|
||||
{"work"() : () -> ()}
|
||||
|
@ -30,3 +34,42 @@ func @extra_regions() {
|
|||
) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test SizedRegion
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
func @unnamed_region_has_wrong_number_of_blocks() {
|
||||
// expected-error@+1 {{region #1 failed to verify constraint: region with 1 blocks}}
|
||||
"test.sized_region_op"() (
|
||||
{
|
||||
"work"() : () -> ()
|
||||
br ^next1
|
||||
^next1:
|
||||
"work"() : () -> ()
|
||||
},
|
||||
{
|
||||
"work"() : () -> ()
|
||||
br ^next2
|
||||
^next2:
|
||||
"work"() : () -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// Test region name in error message
|
||||
func @named_region_has_wrong_number_of_blocks() {
|
||||
// expected-error@+1 {{region #0 ('my_region') failed to verify constraint: region with 2 blocks}}
|
||||
"test.sized_region_op"() (
|
||||
{
|
||||
"work"() : () -> ()
|
||||
},
|
||||
{
|
||||
"work"() : () -> ()
|
||||
}) : () -> ()
|
||||
return
|
||||
}
|
||||
|
|
|
@ -129,7 +129,11 @@ def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr<I32Attr, "34">:$attr)>;
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TwoRegionOp : TEST_Op<"two_region_op", []> {
|
||||
let numRegions = 2;
|
||||
let regions = (region AnyRegion, AnyRegion);
|
||||
}
|
||||
|
||||
def SizedRegionOp : TEST_Op<"sized_region_op", []> {
|
||||
let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>);
|
||||
}
|
||||
|
||||
#endif // TEST_OPS
|
||||
|
|
|
@ -368,6 +368,10 @@ private:
|
|||
// Generates verify method for the operation.
|
||||
void genVerifier();
|
||||
|
||||
// Generates verify statements for regions in the operation.
|
||||
// The generated code will be attached to `body`.
|
||||
void genRegionVerifier(OpMethodBody &body);
|
||||
|
||||
// Generates the traits used by the object.
|
||||
void genTraits();
|
||||
|
||||
|
@ -388,12 +392,17 @@ private:
|
|||
|
||||
// The C++ code builder for this op
|
||||
OpClass opClass;
|
||||
|
||||
// The format context for verification code generation.
|
||||
FmtContext verifyCtx;
|
||||
};
|
||||
} // end anonymous namespace
|
||||
|
||||
OpEmitter::OpEmitter(const Record &def)
|
||||
: def(def), op(def),
|
||||
opClass(op.getCppClassName(), op.getExtraClassDeclaration()) {
|
||||
verifyCtx.withOp("(*this->getOperation())");
|
||||
|
||||
genTraits();
|
||||
// Generate C++ code for various op methods. The order here determines the
|
||||
// methods in the generated file.
|
||||
|
@ -900,13 +909,11 @@ void OpEmitter::genVerifier() {
|
|||
|
||||
auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
|
||||
auto &body = method.body();
|
||||
FmtContext fctx;
|
||||
fctx.withOp("(*this->getOperation())");
|
||||
|
||||
// Populate substitutions for attributes and named operands and results.
|
||||
for (const auto &namedAttr : op.getAttributes())
|
||||
fctx.addSubst(namedAttr.name,
|
||||
formatv("(&this->getAttr(\"{0}\"))", namedAttr.name));
|
||||
verifyCtx.addSubst(namedAttr.name,
|
||||
formatv("(&this->getAttr(\"{0}\"))", namedAttr.name));
|
||||
for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
|
||||
auto &value = op.getOperand(i);
|
||||
// Skip from from first variadic operands for now. Else getOperand index
|
||||
|
@ -914,8 +921,8 @@ void OpEmitter::genVerifier() {
|
|||
if (value.isVariadic())
|
||||
break;
|
||||
if (!value.name.empty())
|
||||
fctx.addSubst(value.name,
|
||||
formatv("this->getOperation()->getOperand({0})", i));
|
||||
verifyCtx.addSubst(value.name,
|
||||
formatv("this->getOperation()->getOperand({0})", i));
|
||||
}
|
||||
for (int i = 0, e = op.getNumResults(); i < e; ++i) {
|
||||
auto &value = op.getResult(i);
|
||||
|
@ -924,8 +931,8 @@ void OpEmitter::genVerifier() {
|
|||
if (value.isVariadic())
|
||||
break;
|
||||
if (!value.name.empty())
|
||||
fctx.addSubst(value.name,
|
||||
formatv("this->getOperation()->getResult({0})", i));
|
||||
verifyCtx.addSubst(value.name,
|
||||
formatv("this->getOperation()->getResult({0})", i));
|
||||
}
|
||||
|
||||
// Verify the attributes have the correct type.
|
||||
|
@ -955,11 +962,12 @@ void OpEmitter::genVerifier() {
|
|||
|
||||
auto attrPred = attr.getPredicate();
|
||||
if (!attrPred.isNull()) {
|
||||
body << tgfmt(" if (!($0)) return emitOpError(\"attribute '$1' "
|
||||
"failed to satisfy constraint: $2\");\n",
|
||||
/*ctx=*/nullptr,
|
||||
tgfmt(attrPred.getCondition(), &fctx.withSelf(varName)),
|
||||
attrName, attr.getDescription());
|
||||
body << tgfmt(
|
||||
" if (!($0)) return emitOpError(\"attribute '$1' "
|
||||
"failed to satisfy constraint: $2\");\n",
|
||||
/*ctx=*/nullptr,
|
||||
tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)),
|
||||
attrName, attr.getDescription());
|
||||
}
|
||||
|
||||
body << " }\n";
|
||||
|
@ -977,10 +985,11 @@ void OpEmitter::genVerifier() {
|
|||
if (value.hasPredicate()) {
|
||||
auto description = value.constraint.getDescription();
|
||||
body << " if (!("
|
||||
<< tgfmt(value.constraint.getConditionTemplate(),
|
||||
&fctx.withSelf("this->getOperation()->get" +
|
||||
Twine(isOperand ? "Operand" : "Result") +
|
||||
"(" + Twine(index) + ")->getType()"))
|
||||
<< tgfmt(
|
||||
value.constraint.getConditionTemplate(),
|
||||
&verifyCtx.withSelf("this->getOperation()->get" +
|
||||
Twine(isOperand ? "Operand" : "Result") +
|
||||
"(" + Twine(index) + ")->getType()"))
|
||||
<< ")) {\n";
|
||||
body << " return emitOpError(\"" << (isOperand ? "operand" : "result")
|
||||
<< " #" << index
|
||||
|
@ -1000,19 +1009,14 @@ void OpEmitter::genVerifier() {
|
|||
|
||||
for (auto &trait : op.getTraits()) {
|
||||
if (auto t = dyn_cast<tblgen::PredOpTrait>(&trait)) {
|
||||
body << tgfmt(" if (!($0))\n return emitOpError(\""
|
||||
"failed to verify that $1\");\n",
|
||||
&fctx, tgfmt(t->getPredTemplate(), &fctx),
|
||||
body << tgfmt(" if (!($0)) {\n "
|
||||
"return emitOpError(\"failed to verify that $1\");\n }\n",
|
||||
&verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx),
|
||||
t->getDescription());
|
||||
}
|
||||
}
|
||||
|
||||
// Verify this op has the correct number of regions
|
||||
body << formatv(
|
||||
" if (this->getOperation()->getNumRegions() != {0}) \n return "
|
||||
"emitOpError(\"has incorrect number of regions: expected {0} but found "
|
||||
"\") << this->getOperation()->getNumRegions();\n",
|
||||
op.getNumRegions());
|
||||
genRegionVerifier(body);
|
||||
|
||||
if (hasCustomVerify)
|
||||
body << codeInit->getValue() << "\n";
|
||||
|
@ -1020,6 +1024,36 @@ void OpEmitter::genVerifier() {
|
|||
body << " return mlir::success();\n";
|
||||
}
|
||||
|
||||
void OpEmitter::genRegionVerifier(OpMethodBody &body) {
|
||||
unsigned numRegions = op.getNumRegions();
|
||||
|
||||
// Verify this op has the correct number of regions
|
||||
body << formatv(
|
||||
" if (this->getOperation()->getNumRegions() != {0}) {\n "
|
||||
"return emitOpError(\"has incorrect number of regions: expected {0} but "
|
||||
"found \") << this->getOperation()->getNumRegions();\n }\n",
|
||||
numRegions);
|
||||
|
||||
for (unsigned i = 0; i < numRegions; ++i) {
|
||||
const auto ®ion = op.getRegion(i);
|
||||
|
||||
std::string name = formatv("#{0}", i);
|
||||
if (!region.name.empty()) {
|
||||
name += formatv(" ('{0}')", region.name);
|
||||
}
|
||||
|
||||
auto getRegion = formatv("this->getOperation()->getRegion({0})", i).str();
|
||||
auto constraint = tgfmt(region.constraint.getConditionTemplate(),
|
||||
&verifyCtx.withSelf(getRegion))
|
||||
.str();
|
||||
|
||||
body << formatv(" if (!({0})) {\n "
|
||||
"return emitOpError(\"region {1} failed to verify "
|
||||
"constraint: {2}\");\n }\n",
|
||||
constraint, name, region.constraint.getDescription());
|
||||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genTraits() {
|
||||
int numResults = op.getNumResults();
|
||||
int numVariadicResults = op.getNumVariadicResults();
|
||||
|
|
Loading…
Reference in New Issue