From 3650df50ddf21396cc300c12948d1488baef5ae5 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Thu, 30 May 2019 16:50:16 -0700 Subject: [PATCH] [ODS] Support region names and constraints Similar to arguments and results, now we require region definition in ops to be specified as a DAG expression with the 'region' operator. This way we can specify the constraints for each region and optionally give the region a name. Two kinds of region constraints are added, one allowing any region, and the other requires a certain number of blocks. -- PiperOrigin-RevId: 250790211 --- mlir/include/mlir/IR/OpBase.td | 29 ++++++- mlir/include/mlir/SPIRV/SPIRVStructureOps.td | 2 +- mlir/include/mlir/TableGen/Constraint.h | 2 +- mlir/include/mlir/TableGen/Operator.h | 14 +++- mlir/include/mlir/TableGen/Region.h | 45 ++++++++++ mlir/lib/TableGen/Constraint.cpp | 2 + mlir/lib/TableGen/Operator.cpp | 37 ++++++++- mlir/test/IR/region.mlir | 45 +++++++++- mlir/test/TestDialect/TestOps.td | 6 +- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 86 ++++++++++++++------ 10 files changed, 228 insertions(+), 40 deletions(-) create mode 100644 mlir/include/mlir/TableGen/Region.h diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index ce1a87b05320..f046fd2d7433 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -135,6 +135,8 @@ class Concat : // Constraint definitions //===----------------------------------------------------------------------===// +// TODO(b/130064155): Merge Constraints into Pred. + // Base class for named constraints. // // An op's operands/attributes/results can have various requirements, e.g., @@ -170,6 +172,10 @@ class TypeConstraint : class AttrConstraint : Constraint; +// Subclass for constraints on a region. +class RegionConstraint : + Constraint; + // How to use these constraint categories: // // * Use TypeConstraint to specify @@ -795,6 +801,21 @@ class IntArrayNthElemMinValue : AttrConstraint< def IsNullAttr : AttrConstraint< CPred<"!$_self">, "empty attribute (for optional attributes)">; +//===----------------------------------------------------------------------===// +// Region definitions +//===----------------------------------------------------------------------===// + +class Region : + RegionConstraint; + +// Any region. +def AnyRegion : Region, "any region">; + +// A region with the given number of blocks. +class SizedRegion : Region< + CPred<"$_self.getBlocks().size() == " # numBlocks>, + "region with " # numBlocks # " blocks">; + //===----------------------------------------------------------------------===// // OpTrait definitions //===----------------------------------------------------------------------===// @@ -869,6 +890,9 @@ def ins; // Marker used to identify the result list for an op. def outs; +// Marker used to identify the region list for an op. +def region; + // Class for defining a custom builder. // // TableGen generates several generic builders for each op by default (see @@ -916,9 +940,8 @@ class Op props = []> { // The list of results of the op. Default to 0 results. dag results = (outs); - // How many regions this op has. - // TODO(b/133479568): Enhance to support advanced region usage cases - int numRegions = 0; + // The list of regions of the op. Default to 0 regions. + dag regions = (region); // Attribute getters can be added to the op by adding an Attr member // with the name and type of the attribute. E.g., adding int attribute diff --git a/mlir/include/mlir/SPIRV/SPIRVStructureOps.td b/mlir/include/mlir/SPIRV/SPIRVStructureOps.td index 7fef351d3e7c..bcb485c8ba1e 100644 --- a/mlir/include/mlir/SPIRV/SPIRVStructureOps.td +++ b/mlir/include/mlir/SPIRV/SPIRVStructureOps.td @@ -72,7 +72,7 @@ def SPV_ModuleOp : SPV_Op<"module", []> { let results = (outs); - let numRegions = 1; + let regions = (region AnyRegion:$body); // Custom parser and printer implemented by static functions in SPVOps.cpp let parser = [{ return parseModule(parser, result); }]; diff --git a/mlir/include/mlir/TableGen/Constraint.h b/mlir/include/mlir/TableGen/Constraint.h index f8b12d9f6e72..bcf207e5e937 100644 --- a/mlir/include/mlir/TableGen/Constraint.h +++ b/mlir/include/mlir/TableGen/Constraint.h @@ -57,7 +57,7 @@ public: StringRef getDescription() const; // Constraint kind - enum Kind { CK_Type, CK_Attr, CK_Uncategorized }; + enum Kind { CK_Attr, CK_Region, CK_Type, CK_Uncategorized }; Kind getKind() const { return kind; } diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 77b3a9ff410e..de2818e96f9c 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -27,6 +27,7 @@ #include "mlir/TableGen/Attribute.h" #include "mlir/TableGen/Dialect.h" #include "mlir/TableGen/OpTrait.h" +#include "mlir/TableGen/Region.h" #include "mlir/TableGen/Type.h" #include "llvm/ADT/PointerUnion.h" #include "llvm/ADT/SmallVector.h" @@ -129,8 +130,15 @@ public: // requiring the raw MLIR trait here. bool hasTrait(llvm::StringRef trait) const; + using const_region_iterator = const NamedRegion *; + const_region_iterator region_begin() const; + const_region_iterator region_end() const; + llvm::iterator_range getRegions() const; + // Returns the number of regions. - int getNumRegions() const; + unsigned getNumRegions() const; + // Returns the `index`-th region. + const NamedRegion &getRegion(unsigned index) const; // Trait. using const_trait_iterator = const OpTrait *; @@ -177,8 +185,8 @@ private: // The traits of the op. SmallVector traits; - // The number of regions of this op. - int numRegions = 0; + // The regions of this op. + SmallVector regions; // The number of native attributes stored in the leading positions of // `attributes`. diff --git a/mlir/include/mlir/TableGen/Region.h b/mlir/include/mlir/TableGen/Region.h new file mode 100644 index 000000000000..21dffe687f49 --- /dev/null +++ b/mlir/include/mlir/TableGen/Region.h @@ -0,0 +1,45 @@ +//===- TGRegion.h - TableGen region definitions -----------------*- C++ -*-===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef MLIR_TABLEGEN_REGION_H_ +#define MLIR_TABLEGEN_REGION_H_ + +#include "mlir/Support/LLVM.h" +#include "mlir/TableGen/Constraint.h" + +namespace mlir { +namespace tblgen { + +// Wrapper class providing helper methods for accessing Region defined in +// TableGen. +class Region : public Constraint { +public: + using Constraint::Constraint; + + static bool classof(const Constraint *c) { return c->getKind() == CK_Region; } +}; + +// A struct bundling a region's constraint and its name. +struct NamedRegion { + StringRef name; + Region constraint; +}; + +} // end namespace tblgen +} // end namespace mlir + +#endif // MLIR_TABLEGEN_REGION_H_ diff --git a/mlir/lib/TableGen/Constraint.cpp b/mlir/lib/TableGen/Constraint.cpp index 2656f4cfa05a..96f49bf12cae 100644 --- a/mlir/lib/TableGen/Constraint.cpp +++ b/mlir/lib/TableGen/Constraint.cpp @@ -30,6 +30,8 @@ Constraint::Constraint(const llvm::Record *record) kind = CK_Type; } else if (record->isSubClassOf("AttrConstraint")) { kind = CK_Attr; + } else if (record->isSubClassOf("RegionConstraint")) { + kind = CK_Region; } else { assert(record->isSubClassOf("Constraint")); } diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index d27db0d37f4c..cd3537dfdbfa 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -146,7 +146,24 @@ bool tblgen::Operator::hasTrait(StringRef trait) const { return false; } -int tblgen::Operator::getNumRegions() const { return numRegions; } +tblgen::Operator::const_region_iterator tblgen::Operator::region_begin() const { + return regions.begin(); +} + +tblgen::Operator::const_region_iterator tblgen::Operator::region_end() const { + return regions.end(); +} + +llvm::iterator_range +tblgen::Operator::getRegions() const { + return {region_begin(), region_end()}; +} + +unsigned tblgen::Operator::getNumRegions() const { return regions.size(); } + +const tblgen::NamedRegion &tblgen::Operator::getRegion(unsigned index) const { + return regions[index]; +} auto tblgen::Operator::trait_begin() const -> const_trait_iterator { return traits.begin(); @@ -269,9 +286,21 @@ void tblgen::Operator::populateOpStructure() { traits.push_back(OpTrait::create(traitInit)); // Handle regions - numRegions = def.getValueAsInt("numRegions"); - if (numRegions < 0) - PrintFatalError(def.getLoc(), "numRegions cannot be negative"); + auto *regionsDag = def.getValueAsDag("regions"); + auto *regionsOp = dyn_cast(regionsDag->getOperator()); + if (!regionsOp || regionsOp->getDef()->getName() != "region") { + PrintFatalError(def.getLoc(), "'regions' must have 'region' directive"); + } + + for (unsigned i = 0, e = regionsDag->getNumArgs(); i < e; ++i) { + auto name = regionsDag->getArgNameStr(i); + auto *regionInit = dyn_cast(regionsDag->getArg(i)); + if (!regionInit) { + PrintFatalError(def.getLoc(), + Twine("undefined kind for region #") + Twine(i)); + } + regions.push_back({name, Region(regionInit->getDef())}); + } } ArrayRef tblgen::Operator::getLoc() const { return def.getLoc(); } diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir index 702e56ddabe8..03b366cf32c9 100644 --- a/mlir/test/IR/region.mlir +++ b/mlir/test/IR/region.mlir @@ -1,5 +1,9 @@ // RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s +//===----------------------------------------------------------------------===// +// Test the number of regions +//===----------------------------------------------------------------------===// + func @correct_number_of_regions() { // CHECK: test.two_region_op "test.two_region_op"()( @@ -11,7 +15,7 @@ func @correct_number_of_regions() { // ----- -func @missingk_regions() { +func @missing_regions() { // expected-error@+1 {{op has incorrect number of regions: expected 2 but found 1}} "test.two_region_op"()( {"work"() : () -> ()} @@ -30,3 +34,42 @@ func @extra_regions() { ) : () -> () return } + +// ----- + +//===----------------------------------------------------------------------===// +// Test SizedRegion +//===----------------------------------------------------------------------===// + +func @unnamed_region_has_wrong_number_of_blocks() { + // expected-error@+1 {{region #1 failed to verify constraint: region with 1 blocks}} + "test.sized_region_op"() ( + { + "work"() : () -> () + br ^next1 + ^next1: + "work"() : () -> () + }, + { + "work"() : () -> () + br ^next2 + ^next2: + "work"() : () -> () + }) : () -> () + return +} + +// ----- + +// Test region name in error message +func @named_region_has_wrong_number_of_blocks() { + // expected-error@+1 {{region #0 ('my_region') failed to verify constraint: region with 2 blocks}} + "test.sized_region_op"() ( + { + "work"() : () -> () + }, + { + "work"() : () -> () + }) : () -> () + return +} diff --git a/mlir/test/TestDialect/TestOps.td b/mlir/test/TestDialect/TestOps.td index 5ffbcbcd7673..814bc72d849f 100644 --- a/mlir/test/TestDialect/TestOps.td +++ b/mlir/test/TestDialect/TestOps.td @@ -129,7 +129,11 @@ def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr:$attr)>; //===----------------------------------------------------------------------===// def TwoRegionOp : TEST_Op<"two_region_op", []> { - let numRegions = 2; + let regions = (region AnyRegion, AnyRegion); +} + +def SizedRegionOp : TEST_Op<"sized_region_op", []> { + let regions = (region SizedRegion<2>:$my_region, SizedRegion<1>); } #endif // TEST_OPS diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 1dc9a95c5ea6..a7b347e6a57a 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -368,6 +368,10 @@ private: // Generates verify method for the operation. void genVerifier(); + // Generates verify statements for regions in the operation. + // The generated code will be attached to `body`. + void genRegionVerifier(OpMethodBody &body); + // Generates the traits used by the object. void genTraits(); @@ -388,12 +392,17 @@ private: // The C++ code builder for this op OpClass opClass; + + // The format context for verification code generation. + FmtContext verifyCtx; }; } // end anonymous namespace OpEmitter::OpEmitter(const Record &def) : def(def), op(def), opClass(op.getCppClassName(), op.getExtraClassDeclaration()) { + verifyCtx.withOp("(*this->getOperation())"); + genTraits(); // Generate C++ code for various op methods. The order here determines the // methods in the generated file. @@ -900,13 +909,11 @@ void OpEmitter::genVerifier() { auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/""); auto &body = method.body(); - FmtContext fctx; - fctx.withOp("(*this->getOperation())"); // Populate substitutions for attributes and named operands and results. for (const auto &namedAttr : op.getAttributes()) - fctx.addSubst(namedAttr.name, - formatv("(&this->getAttr(\"{0}\"))", namedAttr.name)); + verifyCtx.addSubst(namedAttr.name, + formatv("(&this->getAttr(\"{0}\"))", namedAttr.name)); for (int i = 0, e = op.getNumOperands(); i < e; ++i) { auto &value = op.getOperand(i); // Skip from from first variadic operands for now. Else getOperand index @@ -914,8 +921,8 @@ void OpEmitter::genVerifier() { if (value.isVariadic()) break; if (!value.name.empty()) - fctx.addSubst(value.name, - formatv("this->getOperation()->getOperand({0})", i)); + verifyCtx.addSubst(value.name, + formatv("this->getOperation()->getOperand({0})", i)); } for (int i = 0, e = op.getNumResults(); i < e; ++i) { auto &value = op.getResult(i); @@ -924,8 +931,8 @@ void OpEmitter::genVerifier() { if (value.isVariadic()) break; if (!value.name.empty()) - fctx.addSubst(value.name, - formatv("this->getOperation()->getResult({0})", i)); + verifyCtx.addSubst(value.name, + formatv("this->getOperation()->getResult({0})", i)); } // Verify the attributes have the correct type. @@ -955,11 +962,12 @@ void OpEmitter::genVerifier() { auto attrPred = attr.getPredicate(); if (!attrPred.isNull()) { - body << tgfmt(" if (!($0)) return emitOpError(\"attribute '$1' " - "failed to satisfy constraint: $2\");\n", - /*ctx=*/nullptr, - tgfmt(attrPred.getCondition(), &fctx.withSelf(varName)), - attrName, attr.getDescription()); + body << tgfmt( + " if (!($0)) return emitOpError(\"attribute '$1' " + "failed to satisfy constraint: $2\");\n", + /*ctx=*/nullptr, + tgfmt(attrPred.getCondition(), &verifyCtx.withSelf(varName)), + attrName, attr.getDescription()); } body << " }\n"; @@ -977,10 +985,11 @@ void OpEmitter::genVerifier() { if (value.hasPredicate()) { auto description = value.constraint.getDescription(); body << " if (!(" - << tgfmt(value.constraint.getConditionTemplate(), - &fctx.withSelf("this->getOperation()->get" + - Twine(isOperand ? "Operand" : "Result") + - "(" + Twine(index) + ")->getType()")) + << tgfmt( + value.constraint.getConditionTemplate(), + &verifyCtx.withSelf("this->getOperation()->get" + + Twine(isOperand ? "Operand" : "Result") + + "(" + Twine(index) + ")->getType()")) << ")) {\n"; body << " return emitOpError(\"" << (isOperand ? "operand" : "result") << " #" << index @@ -1000,19 +1009,14 @@ void OpEmitter::genVerifier() { for (auto &trait : op.getTraits()) { if (auto t = dyn_cast(&trait)) { - body << tgfmt(" if (!($0))\n return emitOpError(\"" - "failed to verify that $1\");\n", - &fctx, tgfmt(t->getPredTemplate(), &fctx), + body << tgfmt(" if (!($0)) {\n " + "return emitOpError(\"failed to verify that $1\");\n }\n", + &verifyCtx, tgfmt(t->getPredTemplate(), &verifyCtx), t->getDescription()); } } - // Verify this op has the correct number of regions - body << formatv( - " if (this->getOperation()->getNumRegions() != {0}) \n return " - "emitOpError(\"has incorrect number of regions: expected {0} but found " - "\") << this->getOperation()->getNumRegions();\n", - op.getNumRegions()); + genRegionVerifier(body); if (hasCustomVerify) body << codeInit->getValue() << "\n"; @@ -1020,6 +1024,36 @@ void OpEmitter::genVerifier() { body << " return mlir::success();\n"; } +void OpEmitter::genRegionVerifier(OpMethodBody &body) { + unsigned numRegions = op.getNumRegions(); + + // Verify this op has the correct number of regions + body << formatv( + " if (this->getOperation()->getNumRegions() != {0}) {\n " + "return emitOpError(\"has incorrect number of regions: expected {0} but " + "found \") << this->getOperation()->getNumRegions();\n }\n", + numRegions); + + for (unsigned i = 0; i < numRegions; ++i) { + const auto ®ion = op.getRegion(i); + + std::string name = formatv("#{0}", i); + if (!region.name.empty()) { + name += formatv(" ('{0}')", region.name); + } + + auto getRegion = formatv("this->getOperation()->getRegion({0})", i).str(); + auto constraint = tgfmt(region.constraint.getConditionTemplate(), + &verifyCtx.withSelf(getRegion)) + .str(); + + body << formatv(" if (!({0})) {\n " + "return emitOpError(\"region {1} failed to verify " + "constraint: {2}\");\n }\n", + constraint, name, region.constraint.getDescription()); + } +} + void OpEmitter::genTraits() { int numResults = op.getNumResults(); int numVariadicResults = op.getNumVariadicResults();