From c234b65cef07b38c91b9ab7dec6a35f8b390e658 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Mon, 14 Dec 2020 11:53:34 -0800 Subject: [PATCH] [mlir][OpFormat] Add support for emitting newlines from the custom format of an operation This revision adds a new `printNewline` hook to OpAsmPrinter that allows for printing a newline within the custom format of an operation, that is then indented to the start of the operation. Support for the declarative assembly format is also added, in the form of a `\n` literal. Differential Revision: https://reviews.llvm.org/D93151 --- mlir/docs/OpDefinitions.md | 24 +++++++++ mlir/include/mlir/IR/OpImplementation.h | 4 ++ mlir/lib/IR/AsmPrinter.cpp | 8 +++ mlir/test/lib/Dialect/Test/TestOps.td | 3 +- mlir/test/mlir-tblgen/op-format-spec.td | 2 +- mlir/test/mlir-tblgen/op-format.mlir | 6 ++- mlir/tools/mlir-tblgen/OpFormatGen.cpp | 65 +++++++++++++++++++------ 7 files changed, 93 insertions(+), 19 deletions(-) diff --git a/mlir/docs/OpDefinitions.md b/mlir/docs/OpDefinitions.md index a267a60adc3e..189cd0825af7 100644 --- a/mlir/docs/OpDefinitions.md +++ b/mlir/docs/OpDefinitions.md @@ -646,6 +646,30 @@ The following are the set of valid punctuation: `:`, `,`, `=`, `<`, `>`, `(`, `)`, `{`, `}`, `[`, `]`, `->`, `?`, `+`, `*` +The following are valid whitespace punctuation: + +`\n`, ` ` + +The `\n` literal emits a newline an indents to the start of the operation. An +example is shown below: + +```tablegen +let assemblyFormat = [{ + `{` `\n` ` ` ` ` `this_is_on_a_newline` `\n` `}` attr-dict +}]; +``` + +```mlir +%results = my.operation { + this_is_on_a_newline +} +``` + +An empty literal \`\` may be used to remove a space that is inserted implicitly +after certain literal elements, such as `)`/`]`/etc. For example, "`]`" may +result in an output of `]` it is not the last element in the format. "`]` \`\`" +would trim the trailing space in this situation. + #### Variables A variable is an entity that has been registered on the operation itself, i.e. diff --git a/mlir/include/mlir/IR/OpImplementation.h b/mlir/include/mlir/IR/OpImplementation.h index a7e87dc0ab06..31d3b42c8493 100644 --- a/mlir/include/mlir/IR/OpImplementation.h +++ b/mlir/include/mlir/IR/OpImplementation.h @@ -36,6 +36,10 @@ public: virtual ~OpAsmPrinter(); virtual raw_ostream &getStream() const = 0; + /// Print a newline and indent the printer to the start of the current + /// operation. + virtual void printNewline() = 0; + /// Print implementations for various things an operation contains. virtual void printOperand(Value value) = 0; virtual void printOperand(Value value, raw_ostream &os) = 0; diff --git a/mlir/lib/IR/AsmPrinter.cpp b/mlir/lib/IR/AsmPrinter.cpp index 49e7048cfb13..1c2caa0bdfd6 100644 --- a/mlir/lib/IR/AsmPrinter.cpp +++ b/mlir/lib/IR/AsmPrinter.cpp @@ -429,6 +429,7 @@ private: /// The following are hooks of `OpAsmPrinter` that are not necessary for /// determining potential aliases. void printAffineMapOfSSAIds(AffineMapAttr, ValueRange) override {} + void printNewline() override {} void printOperand(Value) override {} void printOperand(Value, raw_ostream &os) override { // Users expect the output string to have at least the prefixed % to signal @@ -2218,6 +2219,13 @@ public: /// Return the current stream of the printer. raw_ostream &getStream() const override { return os; } + /// Print a newline and indent the printer to the start of the current + /// operation. + void printNewline() override { + os << newLine; + os.indent(currentIndent); + } + /// Print the given type. void printType(Type type) override { ModulePrinter::printType(type); } diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index 6a7291abfec7..9a7eb5940fb9 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -1393,7 +1393,8 @@ def AsmDialectInterfaceOp : TEST_Op<"asm_dialect_interface_op"> { def FormatLiteralOp : TEST_Op<"format_literal_op"> { let assemblyFormat = [{ - `keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` `` `(` ` ` `)` `?` `+` `*` attr-dict + `keyword_$.` `->` `:` `,` `=` `<` `>` `(` `)` `[` `]` `` `(` ` ` `)` + `?` `+` `*` `{` `\n` `}` attr-dict }]; } diff --git a/mlir/test/mlir-tblgen/op-format-spec.td b/mlir/test/mlir-tblgen/op-format-spec.td index 7817920f8955..424dbb83c276 100644 --- a/mlir/test/mlir-tblgen/op-format-spec.td +++ b/mlir/test/mlir-tblgen/op-format-spec.td @@ -309,7 +309,7 @@ def LiteralInvalidB : TestFormat_Op<"literal_invalid_b", [{ }]>; // CHECK-NOT: error def LiteralValid : TestFormat_Op<"literal_valid", [{ - `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `?` `+` `*` ` ` `` `->` `abc$._` + `_` `:` `,` `=` `<` `>` `(` `)` `[` `]` `?` `+` `*` ` ` `` `->` `\n` `abc$._` attr-dict }]>; diff --git a/mlir/test/mlir-tblgen/op-format.mlir b/mlir/test/mlir-tblgen/op-format.mlir index 6286f7655146..334313debda1 100644 --- a/mlir/test/mlir-tblgen/op-format.mlir +++ b/mlir/test/mlir-tblgen/op-format.mlir @@ -7,8 +7,10 @@ // CHECK: %[[MEMREF:.*]] = %memref = "foo.op"() : () -> (memref<1xf64>) -// CHECK: test.format_literal_op keyword_$. -> :, = <> () []( ) ? + * {foo.some_attr} -test.format_literal_op keyword_$. -> :, = <> () []( ) ? + * {foo.some_attr} +// CHECK: test.format_literal_op keyword_$. -> :, = <> () []( ) ? + * { +// CHECK-NEXT: } {foo.some_attr} +test.format_literal_op keyword_$. -> :, = <> () []( ) ? + * { +} {foo.some_attr} // CHECK: test.format_attr_op 10 // CHECK-NOT: {attr diff --git a/mlir/tools/mlir-tblgen/OpFormatGen.cpp b/mlir/tools/mlir-tblgen/OpFormatGen.cpp index e09cdd2ac6d4..6cc7c75dc8a4 100644 --- a/mlir/tools/mlir-tblgen/OpFormatGen.cpp +++ b/mlir/tools/mlir-tblgen/OpFormatGen.cpp @@ -58,7 +58,8 @@ public: /// This element is a literal. Literal, - /// This element prints or omits a space. It is ignored by the parser. + /// This element is a whitespace. + Newline, Space, /// This element is an variable value. @@ -296,14 +297,35 @@ bool LiteralElement::isValidLiteral(StringRef value) { } //===----------------------------------------------------------------------===// -// SpaceElement +// WhitespaceElement namespace { +/// This class represents a whitespace element, e.g. newline or space. It's a +/// literal that is printed but never parsed. +class WhitespaceElement : public Element { +public: + WhitespaceElement(Kind kind) : Element{kind} {} + static bool classof(const Element *element) { + Kind kind = element->getKind(); + return kind == Kind::Newline || kind == Kind::Space; + } +}; + +/// This class represents an instance of a newline element. It's a literal that +/// prints a newline. It is ignored by the parser. +class NewlineElement : public WhitespaceElement { +public: + NewlineElement() : WhitespaceElement(Kind::Newline) {} + static bool classof(const Element *element) { + return element->getKind() == Kind::Newline; + } +}; + /// This class represents an instance of a space element. It's a literal that /// prints or omits printing a space. It is ignored by the parser. -class SpaceElement : public Element { +class SpaceElement : public WhitespaceElement { public: - SpaceElement(bool value) : Element{Kind::Space}, value(value) {} + SpaceElement(bool value) : WhitespaceElement(Kind::Space), value(value) {} static bool classof(const Element *element) { return element->getKind() == Kind::Space; } @@ -347,7 +369,8 @@ private: std::vector> elements; /// The index of the element that acts as the anchor for the optional group. unsigned anchor; - /// The index of the first element that is parsed (is not a SpaceElement). + /// The index of the first element that is parsed (is not a + /// WhitespaceElement). unsigned parseStart; }; } // end anonymous namespace @@ -1098,8 +1121,8 @@ void OperationFormat::genElementParser(Element *element, OpMethodBody &body, genLiteralParser(literal->getLiteral(), body); body << ")\n return ::mlir::failure();\n"; - /// Spaces. - } else if (isa(element)) { + /// Whitespaces. + } else if (isa(element)) { // Nothing to parse. /// Arguments. @@ -1620,6 +1643,11 @@ void OperationFormat::genElementPrinter(Element *element, OpMethodBody &body, return genLiteralPrinter(literal->getLiteral(), body, shouldEmitSpace, lastWasPunctuation); + // Emit a whitespace element. + if (NewlineElement *newline = dyn_cast(element)) { + body << " p.printNewline();\n"; + return; + } if (SpaceElement *space = dyn_cast(element)) return genSpacePrinter(space->getValue(), body, shouldEmitSpace, lastWasPunctuation); @@ -2272,9 +2300,10 @@ LogicalResult FormatParser::verifyAttributes( for (auto &nextItPair : iteratorStack) { ElementsIterT nextIt = nextItPair.first, nextE = nextItPair.second; for (; nextIt != nextE; ++nextIt) { - // Skip any trailing spaces, attribute dictionaries, or optional groups. - if (isa(*nextIt) || isa(*nextIt) || - isa(*nextIt)) + // Skip any trailing whitespace, attribute dictionaries, or optional + // groups. + if (isa(*nextIt) || + isa(*nextIt) || isa(*nextIt)) continue; // We are only interested in `:` literals. @@ -2600,6 +2629,11 @@ LogicalResult FormatParser::parseLiteral(std::unique_ptr &element) { element = std::make_unique(!value.empty()); return ::mlir::success(); } + // The parsed literal is a newline element. + if (value == "\\n") { + element = std::make_unique(); + return ::mlir::success(); + } // Check that the parsed literal is valid. if (!LiteralElement::isValidLiteral(value)) @@ -2635,8 +2669,9 @@ LogicalResult FormatParser::parseOptional(std::unique_ptr &element, // The first parsable element of the group must be able to be parsed in an // optional fashion. - auto parseBegin = llvm::find_if_not( - elements, [](auto &element) { return isa(element.get()); }); + auto parseBegin = llvm::find_if_not(elements, [](auto &element) { + return isa(element.get()); + }); Element *firstElement = parseBegin->get(); if (!isa(firstElement) && !isa(firstElement) && @@ -2718,9 +2753,9 @@ LogicalResult FormatParser::parseOptionalChildElement( // a check here. return ::mlir::success(); }) - // Literals, spaces, custom directives, and type directives may be used, - // but they can't anchor the group. - .Case([&](Element *) { if (isAnchor)