forked from OSchip/llvm-project
[ODS] Support numRegions in Op definition
-- PiperOrigin-RevId: 250282024
This commit is contained in:
parent
c2d105811a
commit
d4c8c8de42
|
@ -916,6 +916,10 @@ class Op<Dialect dialect, string mnemonic, list<OpTrait> props = []> {
|
|||
// The list of results of the op. Default to 0 results.
|
||||
dag results = (outs);
|
||||
|
||||
// How many regions this op has.
|
||||
// TODO(b/133479568): Enhance to support advanced region usage cases
|
||||
int numRegions = 0;
|
||||
|
||||
// Attribute getters can be added to the op by adding an Attr member
|
||||
// with the name and type of the attribute. E.g., adding int attribute
|
||||
// with name "value" and type "i32":
|
||||
|
|
|
@ -129,6 +129,9 @@ public:
|
|||
// requiring the raw MLIR trait here.
|
||||
bool hasTrait(llvm::StringRef trait) const;
|
||||
|
||||
// Returns the number of regions.
|
||||
int getNumRegions() const;
|
||||
|
||||
// Trait.
|
||||
using const_trait_iterator = const OpTrait *;
|
||||
const_trait_iterator trait_begin() const;
|
||||
|
@ -174,6 +177,9 @@ private:
|
|||
// The traits of the op.
|
||||
SmallVector<OpTrait, 4> traits;
|
||||
|
||||
// The number of regions of this op.
|
||||
int numRegions = 0;
|
||||
|
||||
// The number of native attributes stored in the leading positions of
|
||||
// `attributes`.
|
||||
int numNativeAttributes;
|
||||
|
|
|
@ -146,6 +146,8 @@ bool tblgen::Operator::hasTrait(StringRef trait) const {
|
|||
return false;
|
||||
}
|
||||
|
||||
int tblgen::Operator::getNumRegions() const { return numRegions; }
|
||||
|
||||
auto tblgen::Operator::trait_begin() const -> const_trait_iterator {
|
||||
return traits.begin();
|
||||
}
|
||||
|
@ -265,6 +267,11 @@ void tblgen::Operator::populateOpStructure() {
|
|||
traits.reserve(traitListInit->size());
|
||||
for (auto traitInit : *traitListInit)
|
||||
traits.push_back(OpTrait::create(traitInit));
|
||||
|
||||
// Handle regions
|
||||
numRegions = def.getValueAsInt("numRegions");
|
||||
if (numRegions < 0)
|
||||
PrintFatalError(def.getLoc(), "numRegions cannot be negative");
|
||||
}
|
||||
|
||||
ArrayRef<llvm::SMLoc> tblgen::Operator::getLoc() const { return def.getLoc(); }
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
// RUN: mlir-test-opt %s -split-input-file -verify | FileCheck %s
|
||||
|
||||
func @correct_number_of_regions() {
|
||||
// CHECK: test.two_region_op
|
||||
"test.two_region_op"()(
|
||||
{"work"() : () -> ()},
|
||||
{"work"() : () -> ()}
|
||||
) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @missingk_regions() {
|
||||
// expected-error@+1 {{op has incorrect number of regions: expected 2 but found 1}}
|
||||
"test.two_region_op"()(
|
||||
{"work"() : () -> ()}
|
||||
) : () -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
func @extra_regions() {
|
||||
// expected-error@+1 {{op has incorrect number of regions: expected 2 but found 3}}
|
||||
"test.two_region_op"()(
|
||||
{"work"() : () -> ()},
|
||||
{"work"() : () -> ()},
|
||||
{"work"() : () -> ()}
|
||||
) : () -> ()
|
||||
return
|
||||
}
|
|
@ -113,4 +113,12 @@ def : Pat<(OpD $input), (OpF $input), [], (addBenefit 10)>;
|
|||
def : Pat<(OpG $input), (OpB $input, ConstantAttr<I32Attr, "20">:$attr)>;
|
||||
def : Pat<(OpG (OpG $input)), (OpB $input, ConstantAttr<I32Attr, "34">:$attr)>;
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Test op regions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def TwoRegionOp : TEST_Op<"two_region_op", []> {
|
||||
let numRegions = 2;
|
||||
}
|
||||
|
||||
#endif // TEST_OPS
|
||||
|
|
|
@ -742,6 +742,12 @@ void OpEmitter::genStandaloneParamBuilder(bool useOperandType,
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Create the correct number of regions
|
||||
if (int numRegions = op.getNumRegions()) {
|
||||
for (int i = 0; i < numRegions; ++i)
|
||||
m.body() << " (void)" << builderOpState << "->addRegion();\n";
|
||||
}
|
||||
}
|
||||
|
||||
void OpEmitter::genBuilder() {
|
||||
|
@ -820,6 +826,12 @@ void OpEmitter::genBuilder() {
|
|||
<< " " << builderOpState
|
||||
<< "->addAttribute(pair.first, pair.second);\n";
|
||||
|
||||
// Create the correct number of regions
|
||||
if (int numRegions = op.getNumRegions()) {
|
||||
for (int i = 0; i < numRegions; ++i)
|
||||
m.body() << " (void)" << builderOpState << "->addRegion();\n";
|
||||
}
|
||||
|
||||
// 3. Deduced result types
|
||||
|
||||
bool useOperandType = op.hasTrait("SameOperandsAndResultType");
|
||||
|
@ -883,9 +895,6 @@ void OpEmitter::genVerifier() {
|
|||
auto valueInit = def.getValueInit("verifier");
|
||||
CodeInit *codeInit = dyn_cast<CodeInit>(valueInit);
|
||||
bool hasCustomVerify = codeInit && !codeInit->getValue().empty();
|
||||
if (!hasCustomVerify && op.getNumArgs() == 0 && op.getNumResults() == 0 &&
|
||||
op.getNumPredOpTraits() == 0)
|
||||
return;
|
||||
|
||||
auto &method = opClass.newMethod("LogicalResult", "verify", /*params=*/"");
|
||||
auto &body = method.body();
|
||||
|
@ -972,6 +981,13 @@ void OpEmitter::genVerifier() {
|
|||
}
|
||||
}
|
||||
|
||||
// Verify this op has the correct number of regions
|
||||
body << formatv(
|
||||
" if (this->getOperation()->getNumRegions() != {0}) \n return "
|
||||
"emitOpError(\"has incorrect number of regions: expected {0} but found "
|
||||
"\") << this->getOperation()->getNumRegions();\n",
|
||||
op.getNumRegions());
|
||||
|
||||
if (hasCustomVerify)
|
||||
body << codeInit->getValue() << "\n";
|
||||
else
|
||||
|
|
Loading…
Reference in New Issue