From c35378003c64b87e02542187ae583b3fb6623df7 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Thu, 21 Nov 2019 14:34:03 -0800 Subject: [PATCH] Add support for using the ODS result names as the Asm result names for multi-result operations. This changes changes the OpDefinitionsGen to automatically add the OpAsmOpInterface for operations with multiple result groups using the provided ODS names. We currently just limit the generation to multi-result ops as most single result operations don't have an interesting name(result/output/etc.). An example is shown below: // The following operation: def MyOp : ... { let results = (outs AnyType:$first, Variadic:$middle, AnyType); } // May now be printed as: %first, %middle:2, %0 = "my.op" ... PiperOrigin-RevId: 281834156 --- mlir/lib/TableGen/Operator.cpp | 3 ++ mlir/test/lib/TestDialect/TestDialect.cpp | 15 ------- mlir/test/lib/TestDialect/TestOps.td | 5 +-- mlir/test/mlir-tblgen/op-decl.td | 2 +- mlir/test/mlir-tblgen/pattern.mlir | 50 ++++++++++----------- mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp | 36 +++++++++++++++ 6 files changed, 67 insertions(+), 44 deletions(-) diff --git a/mlir/lib/TableGen/Operator.cpp b/mlir/lib/TableGen/Operator.cpp index 8afffd03fcb1..927f275e0800 100644 --- a/mlir/lib/TableGen/Operator.cpp +++ b/mlir/lib/TableGen/Operator.cpp @@ -153,6 +153,9 @@ bool tblgen::Operator::hasTrait(StringRef trait) const { } else if (auto opTrait = dyn_cast(&t)) { if (opTrait->getTrait() == trait) return true; + } else if (auto opTrait = dyn_cast(&t)) { + if (opTrait->getTrait() == trait) + return true; } } return false; diff --git a/mlir/test/lib/TestDialect/TestDialect.cpp b/mlir/test/lib/TestDialect/TestDialect.cpp index d838f75f7e7c..3c7fbee3671b 100644 --- a/mlir/test/lib/TestDialect/TestDialect.cpp +++ b/mlir/test/lib/TestDialect/TestDialect.cpp @@ -254,21 +254,6 @@ static ParseResult parsePolyForOp(OpAsmParser &parser, OperationState &result) { return parser.parseRegion(*body, ivsInfo, argTypes); } -//===----------------------------------------------------------------------===// -// Test OpAsmInterface. -//===----------------------------------------------------------------------===// - -void AsmInterfaceOp::getAsmResultNames( - function_ref setNameFn) { - // Give a name to the first and middle results. - setNameFn(firstResult(), "first"); - if (!llvm::empty(middleResults())) - setNameFn(*middleResults().begin(), "middle_results"); - - // Use default numbering for the last result. - setNameFn(getResult(getNumResults() - 1), ""); -} - //===----------------------------------------------------------------------===// // Test removing op with inner ops. //===----------------------------------------------------------------------===// diff --git a/mlir/test/lib/TestDialect/TestOps.td b/mlir/test/lib/TestDialect/TestOps.td index 2c07a2557dcc..d804fdc1b78c 100644 --- a/mlir/test/lib/TestDialect/TestOps.td +++ b/mlir/test/lib/TestDialect/TestOps.td @@ -1003,9 +1003,8 @@ def PolyForOp : TEST_Op<"polyfor"> //===----------------------------------------------------------------------===// // Test OpAsmInterface. -def AsmInterfaceOp : TEST_Op<"asm_interface_op", - [DeclareOpInterfaceMethods]> { - let results = (outs AnyType:$firstResult, Variadic:$middleResults, +def AsmInterfaceOp : TEST_Op<"asm_interface_op"> { + let results = (outs AnyType:$first, Variadic:$middle_results, AnyType); } diff --git a/mlir/test/mlir-tblgen/op-decl.td b/mlir/test/mlir-tblgen/op-decl.td index 936bf3a4bfa4..672cfeffbfe8 100644 --- a/mlir/test/mlir-tblgen/op-decl.td +++ b/mlir/test/mlir-tblgen/op-decl.td @@ -52,7 +52,7 @@ def NS_AOp : NS_Op<"a_op", [NoSideEffect]> { // CHECK: ArrayRef tblgen_operands; // CHECK: }; -// CHECK: class AOp : public Op::Impl, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl> { +// CHECK: class AOp : public Op::Impl, OpTrait::HasNoSideEffect, OpTrait::AtLeastNOperands<1>::Impl // CHECK: public: // CHECK: using Op::Op; // CHECK: using OperandAdaptor = AOpOperandAdaptor; diff --git a/mlir/test/mlir-tblgen/pattern.mlir b/mlir/test/mlir-tblgen/pattern.mlir index 7586d841bdbf..1f6da059dcfc 100644 --- a/mlir/test/mlir-tblgen/pattern.mlir +++ b/mlir/test/mlir-tblgen/pattern.mlir @@ -222,46 +222,46 @@ func @rewrite_f64elementsattr() -> () { // CHECK-LABEL: @useMultiResultOpToReplaceWhole func @useMultiResultOpToReplaceWhole() -> (i32, f32, f32) { - // CHECK: %0:3 = "test.another_three_result"() - // CHECK: return %0#0, %0#1, %0#2 + // CHECK: %[[A:.*]], %[[B:.*]], %[[C:.*]] = "test.another_three_result"() + // CHECK: return %[[A]], %[[B]], %[[C]] %0:3 = "test.three_result"() {kind = 1} : () -> (i32, f32, f32) return %0#0, %0#1, %0#2 : i32, f32, f32 } // CHECK-LABEL: @useMultiResultOpToReplacePartial1 func @useMultiResultOpToReplacePartial1() -> (i32, f32, f32) { - // CHECK: %0:2 = "test.two_result"() - // CHECK: %1 = "test.one_result1"() - // CHECK: return %0#0, %0#1, %1 + // CHECK: %[[A:.*]], %[[B:.*]] = "test.two_result"() + // CHECK: %[[C:.*]] = "test.one_result1"() + // CHECK: return %[[A]], %[[B]], %[[C]] %0:3 = "test.three_result"() {kind = 2} : () -> (i32, f32, f32) return %0#0, %0#1, %0#2 : i32, f32, f32 } // CHECK-LABEL: @useMultiResultOpToReplacePartial2 func @useMultiResultOpToReplacePartial2() -> (i32, f32, f32) { - // CHECK: %0 = "test.one_result2"() - // CHECK: %1:2 = "test.another_two_result"() - // CHECK: return %0, %1#0, %1#1 + // CHECK: %[[A:.*]] = "test.one_result2"() + // CHECK: %[[B:.*]], %[[C:.*]] = "test.another_two_result"() + // CHECK: return %[[A]], %[[B]], %[[C]] %0:3 = "test.three_result"() {kind = 3} : () -> (i32, f32, f32) return %0#0, %0#1, %0#2 : i32, f32, f32 } // CHECK-LABEL: @useMultiResultOpResultsSeparately func @useMultiResultOpResultsSeparately() -> (i32, f32, f32) { - // CHECK: %0:2 = "test.two_result"() - // CHECK: %1 = "test.one_result1"() - // CHECK: %2:2 = "test.two_result"() - // CHECK: return %0#0, %1, %2#1 + // CHECK: %[[A:.*]], %[[B:.*]] = "test.two_result"() + // CHECK: %[[C:.*]] = "test.one_result1"() + // CHECK: %[[D:.*]], %[[E:.*]] = "test.two_result"() + // CHECK: return %[[A]], %[[C]], %[[E]] %0:3 = "test.three_result"() {kind = 4} : () -> (i32, f32, f32) return %0#0, %0#1, %0#2 : i32, f32, f32 } // CHECK-LABEL: @constraintOnSourceOpResult func @constraintOnSourceOpResult() -> (i32, f32, i32) { - // CHECK: %0:2 = "test.two_result"() - // CHECK: %1 = "test.one_result2"() - // CHECK: %2 = "test.one_result1"() - // CHECK: return %0#0, %0#1, %1 + // CHECK: %[[A:.*]], %[[B:.*]] = "test.two_result"() + // CHECK: %[[C:.*]] = "test.one_result2"() + // CHECK: %[[D:.*]] = "test.one_result1"() + // CHECK: return %[[A]], %[[B]], %[[C]] %0:2 = "test.two_result"() {kind = 5} : () -> (i32, f32) %1:2 = "test.two_result"() {kind = 5} : () -> (i32, f32) return %0#0, %0#1, %1#0 : i32, f32, i32 @@ -271,11 +271,11 @@ func @constraintOnSourceOpResult() -> (i32, f32, i32) { func @useAuxiliaryOpToReplaceMultiResultOp() -> (i32, f32, f32) { // An auxiliary op is generated to help building the op for replacing the // matched op. - // CHECK: %0:2 = "test.two_result"() + // CHECK: %[[A:.*]], %[[B:.*]] = "test.two_result"() - // CHECK: %1 = "test.one_result3"(%0#1) - // CHECK: %2:2 = "test.another_two_result"() - // CHECK: return %1, %2#0, %2#1 + // CHECK: %[[C:.*]] = "test.one_result3"(%[[B]]) + // CHECK: %[[D:.*]], %[[E:.*]] = "test.another_two_result"() + // CHECK: return %[[C]], %[[D]], %[[E]] %0:3 = "test.three_result"() {kind = 6} : () -> (i32, f32, f32) return %0#0, %0#1, %0#2 : i32, f32, f32 } @@ -312,9 +312,9 @@ func @replaceMixedVariadicInputOp(%arg0: i32, %arg1: f32, %arg2: i32) -> () { // CHECK-LABEL: @replaceMixedVariadicOutputOp func @replaceMixedVariadicOutputOp() -> (f32, i32, f32, i32, i32, i32, f32, i32, i32) { // CHECK: %[[cnt1:.*]] = "test.mixed_variadic_out2"() - // CHECK: %[[cnt3:.*]]:3 = "test.mixed_variadic_out2"() - // CHECK: %[[cnt5:.*]]:5 = "test.mixed_variadic_out2"() - // CHECK: return %[[cnt1]], %[[cnt3]]#0, %[[cnt3]]#1, %[[cnt3]]#2, %[[cnt5]]#0, %[[cnt5]]#1, %[[cnt5]]#2, %[[cnt5]]#3, %[[cnt5]]#4 + // CHECK: %[[cnt3_a:.*]], %[[cnt3_b:.*]], %[[cnt3_c:.*]] = "test.mixed_variadic_out2"() + // CHECK: %[[cnt5_a:.*]]:2, %[[cnt5_b:.*]], %[[cnt5_c:.*]]:2 = "test.mixed_variadic_out2"() + // CHECK: return %[[cnt1]], %[[cnt3_a]], %[[cnt3_b]], %[[cnt3_c]], %[[cnt5_a]]#0, %[[cnt5_a]]#1, %[[cnt5_b]], %[[cnt5_c]]#0, %[[cnt5_c]]#1 %0 = "test.mixed_variadic_out1"() : () -> (f32) %1:3 = "test.mixed_variadic_out1"() : () -> (i32, f32, i32) @@ -324,8 +324,8 @@ func @replaceMixedVariadicOutputOp() -> (f32, i32, f32, i32, i32, i32, f32, i32, // CHECK-LABEL: @generateVariadicOutputOpInNestedPattern func @generateVariadicOutputOpInNestedPattern() -> (i32) { - // CHECK: %[[cnt5:.*]]:5 = "test.mixed_variadic_out3"() - // CHECK: %[[res:.*]] = "test.mixed_variadic_in3"(%[[cnt5]]#0, %[[cnt5]]#1, %[[cnt5]]#2, %[[cnt5]]#3, %[[cnt5]]#4) + // CHECK: %[[cnt5_a:.*]], %[[cnt5_b:.*]]:2, %[[cnt5_c:.*]]:2 = "test.mixed_variadic_out3"() + // CHECK: %[[res:.*]] = "test.mixed_variadic_in3"(%[[cnt5_a]], %[[cnt5_b]]#0, %[[cnt5_b]]#1, %[[cnt5_c]]#0, %[[cnt5_c]]#1) // CHECK: return %[[res]] %0 = "test.one_i32_out"() : () -> (i32) diff --git a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp index 46803b557e49..538aa6e79a7a 100644 --- a/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp +++ b/mlir/tools/mlir-tblgen/OpDefinitionsGen.cpp @@ -459,6 +459,9 @@ private: void emitDecl(raw_ostream &os); void emitDef(raw_ostream &os); + // Generates the OpAsmOpInterface for this operation if possible. + void genOpAsmInterface(); + // Generates the `getOperationName` method for this op. void genOpNameGetter(); @@ -575,6 +578,7 @@ OpEmitter::OpEmitter(const Operator &op) genTraits(); // Generate C++ code for various op methods. The order here determines the // methods in the generated file. + genOpAsmInterface(); genOpNameGetter(); genNamedOperandGetters(); genNamedResultGetters(); @@ -1393,6 +1397,38 @@ void OpEmitter::genOpNameGetter() { method.body() << " return \"" << op.getOperationName() << "\";\n"; } +void OpEmitter::genOpAsmInterface() { + // If the user only has one results or specifically added the Asm trait, + // then don't generate it for them. We specifically only handle multi result + // operations, because the name of a single result in the common case is not + // interesting(generally 'result'/'output'/etc.). + // TODO: We could also add a flag to allow operations to opt in to this + // generation, even if they only have a single operation. + int numResults = op.getNumResults(); + if (numResults <= 1 || op.hasTrait("OpAsmOpInterface::Trait")) + return; + + SmallVector resultNames(numResults); + for (int i = 0; i != numResults; ++i) + resultNames[i] = op.getResultName(i); + + // Don't add the trait if none of the results have a valid name. + if (llvm::all_of(resultNames, [](StringRef name) { return name.empty(); })) + return; + opClass.addTrait("OpAsmOpInterface::Trait"); + + // Generate the right accessor for the number of results. + auto &method = opClass.newMethod("void", "getAsmResultNames", + "OpAsmSetValueNameFn setNameFn"); + auto &body = method.body(); + for (int i = 0; i != numResults; ++i) { + body << " auto resultGroup" << i << " = getODSResults(" << i << ");\n" + << " if (!llvm::empty(resultGroup" << i << "))\n" + << " setNameFn(*resultGroup" << i << ".begin(), \"" + << resultNames[i] << "\");\n"; + } +} + //===----------------------------------------------------------------------===// // OpOperandAdaptor emitter //===----------------------------------------------------------------------===//