diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td index 8ef6e0bc5930..edbd27303b0b 100644 --- a/mlir/include/mlir/IR/OpBase.td +++ b/mlir/include/mlir/IR/OpBase.td @@ -916,6 +916,10 @@ class Op props = []> { // The list of results of the op. Default to 0 results. dag results = (outs); + // How many regions this op has. + // TODO(b/133479568): Enhance to support advanced region usage cases + int numRegions = 0; + // Attribute getters can be added to the op by adding an Attr member // with the name and type of the attribute. E.g., adding int attribute // with name "value" and type "i32": diff --git a/mlir/include/mlir/TableGen/Operator.h b/mlir/include/mlir/TableGen/Operator.h index 3f5284773481..77b3a9ff410e 100644 --- a/mlir/include/mlir/TableGen/Operator.h +++ b/mlir/include/mlir/TableGen/Operator.h @@ -129,6 +129,9 @@ public: // requiring the raw MLIR trait here. bool hasTrait(llvm::StringRef trait) const; + // Returns the number of regions. + int getNumRegions() const; + // Trait. using const_trait_iterator = const OpTrait *; const_trait_iterator trait_begin() const; @@ -174,6 +177,9 @@ private: // The traits of the op. SmallVector traits; + // The number of regions of this op. + int numRegions = 0; + // The number of native attributes stored in the leading positions of // `attributes`. int numNativeAttributes; diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 222397453512..d27db0d37f4c 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -146,6 +146,8 @@ bool tblgen::Operator::hasTrait(StringRef trait) const { return false; } +int tblgen::Operator::getNumRegions() const { return numRegions; } + auto tblgen::Operator::trait_begin() const -> const_trait_iterator { return traits.begin(); } @@ -265,6 +267,11 @@ void tblgen::Operator::populateOpStructure() { traits.reserve(traitListInit->size()); for (auto traitInit : *traitListInit) traits.push_back(OpTrait::create(traitInit)); + + // Handle regions + numRegions = def.getValueAsInt("numRegions"); + if (numRegions < 0) + PrintFatalError(def.getLoc(), "numRegions cannot be negative"); } ArrayRef tblgen::Operator::getLoc() const { return def.getLoc(); } diff --git a/mlir/test/IR/region.mlir b/mlir/test/IR/region.mlir new file mode 100644 index 000000000000..702e56ddabe8 --- /dev/null +++ b/mlir/test/IR/region.mlir @@ -0,0 +1,32 @@ +// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s + +func @correct_number_of_regions() { + // CHECK: test.two_region_op + "test.two_region_op"()( + {"work"() : () -> ()}, + {"work"() : () -> ()} + ) : () -> () + return +} + +// ----- + +func @missingk_regions() { + // expected-error@+1 {{op has incorrect number of regions: expected 2 but found 1}} + "test.two_region_op"()( + {"work"() : () -> ()} + ) : () -> () + return +} + +// ----- + +func @extra_regions() { + // expected-error@+1 {{op has incorrect number of regions: expected 2 but found 3}} + "test.two_region_op"()( + {"work"() : () -> ()}, + {"work"() : () -> ()}, + {"work"() : () -> ()} + ) : () -> () + return +} diff --git a/mlir/test/TestDialect/TestOps.td b/mlir/test/TestDialect/TestOps.td index 3c0ade3c493d..915318d36dc8 100644 --- a/mlir/test/TestDialect/TestOps.td +++ b/mlir/test/TestDialect/TestOps.td @@ -113,4 +113,12 @@ def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>; def : Pat<(OpG $input), (OpB $input, ConstantAttr:$attr)>; def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr:$attr)>; +//===----------------------------------------------------------------------===// +// Test op regions +//===----------------------------------------------------------------------===// + +def TwoRegionOp : TEST_Op<"two_region_op", []> { + let numRegions = 2; +} + #endif // TEST_OPS diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 3d5a3b9446ca..ca8b27a9c27e 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -742,6 +742,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType, } } } + + // Create the correct number of regions + if (int numRegions = op.getNumRegions()) { + for (int i = 0; i < numRegions; ++i) + m.body() << " (void)" << builderOpState << "->addRegion();\n"; + } } void OpEmitter::genBuilder() { @@ -820,6 +826,12 @@ void OpEmitter::genBuilder() { << " " << builderOpState << "->addAttribute(pair.first, pair.second);\n"; + // Create the correct number of regions + if (int numRegions = op.getNumRegions()) { + for (int i = 0; i < numRegions; ++i) + m.body() << " (void)" << builderOpState << "->addRegion();\n"; + } + // 3. Deduced result types bool useOperandType = op.hasTrait("SameOperandsAndResultType"); @@ -883,9 +895,6 @@ void OpEmitter::genVerifier() { auto valueInit = def.getValueInit("verifier"); CodeInit *codeInit = dyn_cast(valueInit); bool hasCustomVerify = codeInit && !codeInit->getValue().empty(); - if (!hasCustomVerify && op.getNumArgs() == 0 && op.getNumResults() == 0 && - op.getNumPredOpTraits() == 0) - return; auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/""); auto &body = method.body(); @@ -972,6 +981,13 @@ void OpEmitter::genVerifier() { } } + // Verify this op has the correct number of regions + body << formatv( + " if (this->getOperation()->getNumRegions() != {0}) \n return " + "emitOpError(\"has incorrect number of regions: expected {0} but found " + "\") << this->getOperation()->getNumRegions();\n", + op.getNumRegions()); + if (hasCustomVerify) body << codeInit->getValue() << "\n"; else