[mlir:LSP] Add support for code completing attributes and types

This required changing a bit of how attributes/types are parsed. A new
`KeywordSwitch` class was added to AsmParser that provides a StringSwitch
like API for parsing keywords with a set of potential matches. It intends to
both provide a cleaner API, and enable injection for code completion. This
required changing the API of `generated(Attr|Type)Parser` to handle the
parsing of the keyword, instead of having the user do it. Most upstream
dialects use the autogenerated handling and didn't require a direct update.

Differential Revision: https://reviews.llvm.org/D129267
This commit is contained in:
River Riddle 2022-07-06 22:54:36 -07:00
parent 2e41ea3247
commit fe4f512be7
21 changed files with 565 additions and 179 deletions

View File

@ -116,10 +116,8 @@ RecordType verifyDerived(mlir::AsmParser &parser, RecordType derivedTy,
mlir::Type fir::parseFirType(FIROpsDialect *dialect,
mlir::DialectAsmParser &parser) {
mlir::StringRef typeTag;
if (parser.parseKeyword(&typeTag))
return {};
mlir::Type genType;
auto parseResult = generatedTypeParser(parser, typeTag, genType);
auto parseResult = generatedTypeParser(parser, &typeTag, genType);
if (parseResult.hasValue())
return genType;
parser.emitError(parser.getNameLoc(), "unknown fir type: ") << typeTag;

View File

@ -473,10 +473,10 @@ one for printing. These static functions placed alongside the class definitions
and have the following function signatures:
```c++
static ParseResult generatedAttributeParser(DialectAsmParser& parser, StringRef mnemonic, Type attrType, Attribute &result);
static ParseResult generatedAttributeParser(DialectAsmParser& parser, StringRef *mnemonic, Type attrType, Attribute &result);
static LogicalResult generatedAttributePrinter(Attribute attr, DialectAsmPrinter& printer);
static ParseResult generatedTypeParser(DialectAsmParser& parser, StringRef mnemonic, Type &result);
static ParseResult generatedTypeParser(DialectAsmParser& parser, StringRef *mnemonic, Type &result);
static LogicalResult generatedTypePrinter(Type type, DialectAsmPrinter& printer);
```

View File

@ -571,43 +571,6 @@ public:
/// Parse a quoted string token if present.
virtual ParseResult parseOptionalString(std::string *string) = 0;
/// Parse a given keyword.
ParseResult parseKeyword(StringRef keyword) {
return parseKeyword(keyword, "");
}
virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
/// Parse a keyword into 'keyword'.
ParseResult parseKeyword(StringRef *keyword) {
auto loc = getCurrentLocation();
if (parseOptionalKeyword(keyword))
return emitError(loc, "expected valid keyword");
return success();
}
/// Parse the given keyword if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
/// Parse a keyword, if present, into 'keyword'.
virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
/// Parse a keyword, if present, and if one of the 'allowedValues',
/// into 'keyword'
virtual ParseResult
parseOptionalKeyword(StringRef *keyword,
ArrayRef<StringRef> allowedValues) = 0;
/// Parse a keyword or a quoted string.
ParseResult parseKeywordOrString(std::string *result) {
if (failed(parseOptionalKeywordOrString(result)))
return emitError(getCurrentLocation())
<< "expected valid keyword or string";
return success();
}
/// Parse an optional keyword or string.
virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
/// Parse a `(` token.
virtual ParseResult parseLParen() = 0;
@ -712,6 +675,115 @@ public:
return parseCommaSeparatedList(Delimiter::None, parseElementFn);
}
//===--------------------------------------------------------------------===//
// Keyword Parsing
//===--------------------------------------------------------------------===//
/// This class represents a StringSwitch like class that is useful for parsing
/// expected keywords. On construction, it invokes `parseKeyword` and
/// processes each of the provided cases statements until a match is hit. The
/// provided `ResultT` must be assignable from `failure()`.
template <typename ResultT = ParseResult>
class KeywordSwitch {
public:
KeywordSwitch(AsmParser &parser)
: parser(parser), loc(parser.getCurrentLocation()) {
if (failed(parser.parseKeywordOrCompletion(&keyword)))
result = failure();
}
/// Case that uses the provided value when true.
KeywordSwitch &Case(StringLiteral str, ResultT value) {
return Case(str, [&](StringRef, SMLoc) { return std::move(value); });
}
KeywordSwitch &Default(ResultT value) {
return Default([&](StringRef, SMLoc) { return std::move(value); });
}
/// Case that invokes the provided functor when true. The parameters passed
/// to the functor are the keyword, and the location of the keyword (in case
/// any errors need to be emitted).
template <typename FnT>
std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
Case(StringLiteral str, FnT &&fn) {
if (result)
return *this;
// If the word was empty, record this as a completion.
if (keyword.empty())
parser.codeCompleteExpectedTokens(str);
else if (keyword == str)
result.emplace(std::move(fn(keyword, loc)));
return *this;
}
template <typename FnT>
std::enable_if_t<!std::is_convertible<FnT, ResultT>::value, KeywordSwitch &>
Default(FnT &&fn) {
if (!result)
result.emplace(fn(keyword, loc));
return *this;
}
/// Returns true if this switch has a value yet.
bool hasValue() const { return result.hasValue(); }
/// Return the result of the switch.
LLVM_NODISCARD operator ResultT() {
if (!result)
return parser.emitError(loc, "unexpected keyword: ") << keyword;
return std::move(*result);
}
private:
/// The parser used to construct this switch.
AsmParser &parser;
/// The location of the keyword, used to emit errors as necessary.
SMLoc loc;
/// The parsed keyword itself.
StringRef keyword;
/// The result of the switch statement or none if currently unknown.
Optional<ResultT> result;
};
/// Parse a given keyword.
ParseResult parseKeyword(StringRef keyword) {
return parseKeyword(keyword, "");
}
virtual ParseResult parseKeyword(StringRef keyword, const Twine &msg) = 0;
/// Parse a keyword into 'keyword'.
ParseResult parseKeyword(StringRef *keyword) {
auto loc = getCurrentLocation();
if (parseOptionalKeyword(keyword))
return emitError(loc, "expected valid keyword");
return success();
}
/// Parse the given keyword if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword) = 0;
/// Parse a keyword, if present, into 'keyword'.
virtual ParseResult parseOptionalKeyword(StringRef *keyword) = 0;
/// Parse a keyword, if present, and if one of the 'allowedValues',
/// into 'keyword'
virtual ParseResult
parseOptionalKeyword(StringRef *keyword,
ArrayRef<StringRef> allowedValues) = 0;
/// Parse a keyword or a quoted string.
ParseResult parseKeywordOrString(std::string *result) {
if (failed(parseOptionalKeywordOrString(result)))
return emitError(getCurrentLocation())
<< "expected valid keyword or string";
return success();
}
/// Parse an optional keyword or string.
virtual ParseResult parseOptionalKeywordOrString(std::string *result) = 0;
//===--------------------------------------------------------------------===//
// Attribute/Type Parsing
//===--------------------------------------------------------------------===//
@ -1124,6 +1196,17 @@ protected:
virtual FailureOr<AsmDialectResourceHandle>
parseResourceHandle(Dialect *dialect) = 0;
//===--------------------------------------------------------------------===//
// Code Completion
//===--------------------------------------------------------------------===//
/// Parse a keyword, or an empty string if the current location signals a code
/// completion.
virtual ParseResult parseKeywordOrCompletion(StringRef *keyword) = 0;
/// Signal the code completion of a set of expected tokens.
virtual void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) = 0;
private:
AsmParser(const AsmParser &) = delete;
void operator=(const AsmParser &) = delete;

View File

@ -10,9 +10,13 @@
#define MLIR_PARSER_CODECOMPLETE_H
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/StringMap.h"
#include "llvm/Support/SourceMgr.h"
namespace mlir {
class Attribute;
class Type;
/// This class provides an abstract interface into the parser for hooking in
/// code completion events. This class is only really useful for providing
/// language tooling for MLIR, general clients should not need to use this
@ -28,8 +32,9 @@ public:
// Completion Hooks
//===--------------------------------------------------------------------===//
/// Signal code completion for a dialect name.
virtual void completeDialectName() = 0;
/// Signal code completion for a dialect name, with an optional prefix.
virtual void completeDialectName(StringRef prefix) = 0;
void completeDialectName() { completeDialectName(""); }
/// Signal code completion for an operation name within the given dialect.
virtual void completeOperationName(StringRef dialectName) = 0;
@ -48,6 +53,16 @@ public:
virtual void completeExpectedTokens(ArrayRef<StringRef> tokens,
bool optional) = 0;
/// Signal a completion for an attribute.
virtual void completeAttribute(const llvm::StringMap<Attribute> &aliases) = 0;
virtual void completeDialectAttributeOrAlias(
const llvm::StringMap<Attribute> &aliases) = 0;
/// Signal a completion for a type.
virtual void completeType(const llvm::StringMap<Type> &aliases) = 0;
virtual void
completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) = 0;
protected:
/// Create a new code completion context with the given code complete
/// location.

View File

@ -35,11 +35,9 @@ void PDLDialect::registerTypes() {
static Type parsePDLType(AsmParser &parser) {
StringRef typeTag;
if (parser.parseKeyword(&typeTag))
return Type();
{
Type genType;
auto parseResult = generatedTypeParser(parser, typeTag, genType);
auto parseResult = generatedTypeParser(parser, &typeTag, genType);
if (parseResult.hasValue())
return genType;
}

View File

@ -577,17 +577,11 @@ Attribute SPIRVDialect::parseAttribute(DialectAsmParser &parser,
// Parse the kind keyword first.
StringRef attrKind;
if (parser.parseKeyword(&attrKind))
return {};
Attribute attr;
OptionalParseResult result =
generatedAttributeParser(parser, attrKind, type, attr);
if (result.hasValue()) {
if (failed(result.getValue()))
return {};
generatedAttributeParser(parser, &attrKind, type, attr);
if (result.hasValue())
return attr;
}
if (attrKind == spirv::TargetEnvAttr::getKindName())
return parseTargetEnvAttr(parser);

View File

@ -242,6 +242,56 @@ public:
return success();
}
/// Parse a floating point value from the stream.
ParseResult parseFloat(double &result) override {
bool isNegative = parser.consumeIf(Token::minus);
Token curTok = parser.getToken();
SMLoc loc = curTok.getLoc();
// Check for a floating point value.
if (curTok.is(Token::floatliteral)) {
auto val = curTok.getFloatingPointValue();
if (!val)
return emitError(loc, "floating point value too large");
parser.consumeToken(Token::floatliteral);
result = isNegative ? -*val : *val;
return success();
}
// Check for a hexadecimal float value.
if (curTok.is(Token::integer)) {
Optional<APFloat> apResult;
if (failed(parser.parseFloatFromIntegerLiteral(
apResult, curTok, isNegative, APFloat::IEEEdouble(),
/*typeSizeInBits=*/64)))
return failure();
parser.consumeToken(Token::integer);
result = apResult->convertToDouble();
return success();
}
return emitError(loc, "expected floating point literal");
}
/// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalInteger(APInt &result) override {
return parser.parseOptionalInteger(result);
}
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
ParseResult parseCommaSeparatedList(Delimiter delimiter,
function_ref<ParseResult()> parseElt,
StringRef contextMessage) override {
return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
}
//===--------------------------------------------------------------------===//
// Keyword Parsing
//===--------------------------------------------------------------------===//
ParseResult parseKeyword(StringRef keyword, const Twine &msg) override {
if (parser.getToken().isCodeCompletion())
return parser.codeCompleteExpectedTokens(keyword);
@ -251,6 +301,7 @@ public:
return emitError(loc, "expected '") << keyword << "'" << msg;
return success();
}
using AsmParser::parseKeyword;
/// Parse the given keyword if present.
ParseResult parseOptionalKeyword(StringRef keyword) override {
@ -308,52 +359,6 @@ public:
return parseOptionalString(result);
}
/// Parse a floating point value from the stream.
ParseResult parseFloat(double &result) override {
bool isNegative = parser.consumeIf(Token::minus);
Token curTok = parser.getToken();
SMLoc loc = curTok.getLoc();
// Check for a floating point value.
if (curTok.is(Token::floatliteral)) {
auto val = curTok.getFloatingPointValue();
if (!val)
return emitError(loc, "floating point value too large");
parser.consumeToken(Token::floatliteral);
result = isNegative ? -*val : *val;
return success();
}
// Check for a hexadecimal float value.
if (curTok.is(Token::integer)) {
Optional<APFloat> apResult;
if (failed(parser.parseFloatFromIntegerLiteral(
apResult, curTok, isNegative, APFloat::IEEEdouble(),
/*typeSizeInBits=*/64)))
return failure();
parser.consumeToken(Token::integer);
result = apResult->convertToDouble();
return success();
}
return emitError(loc, "expected floating point literal");
}
/// Parse an optional integer value from the stream.
OptionalParseResult parseOptionalInteger(APInt &result) override {
return parser.parseOptionalInteger(result);
}
/// Parse a list of comma-separated items with an optional delimiter. If a
/// delimiter is provided, then an empty list is allowed. If not, then at
/// least one element will be parsed.
ParseResult parseCommaSeparatedList(Delimiter delimiter,
function_ref<ParseResult()> parseElt,
StringRef contextMessage) override {
return parser.parseCommaSeparatedList(delimiter, parseElt, contextMessage);
}
//===--------------------------------------------------------------------===//
// Attribute Parsing
//===--------------------------------------------------------------------===//
@ -528,6 +533,28 @@ public:
return parser.parseXInDimensionList();
}
//===--------------------------------------------------------------------===//
// Code Completion
//===--------------------------------------------------------------------===//
/// Parse a keyword, or an empty string if the current location signals a code
/// completion.
ParseResult parseKeywordOrCompletion(StringRef *keyword) override {
Token tok = parser.getToken();
if (tok.isCodeCompletion() && tok.getSpelling().empty()) {
*keyword = "";
return success();
}
return parseKeyword(keyword);
}
/// Signal the code completion of a set of expected tokens.
void codeCompleteExpectedTokens(ArrayRef<StringRef> tokens) override {
Token tok = parser.getToken();
if (tok.isCodeCompletion() && tok.getSpelling().empty())
(void)parser.codeCompleteExpectedTokens(tokens);
}
protected:
/// The source location of the dialect symbol.
SMLoc nameLoc;

View File

@ -213,6 +213,12 @@ Attribute Parser::parseAttribute(Type type) {
consumeToken(Token::kw_unit);
return builder.getUnitAttr();
// Handle completion of an attribute.
case Token::code_complete:
if (getToken().isCodeCompletionFor(Token::hash_identifier))
return parseExtendedAttr(type);
return codeCompleteAttribute();
default:
// Parse a type attribute. We parse `Optional` here to allow for providing a
// better error message.

View File

@ -43,9 +43,6 @@ private:
};
} // namespace
/// Parse the body of a dialect symbol, which starts and ends with <>'s, and may
/// be recursive. Return with the 'body' StringRef encompassing the entire
/// body.
///
/// pretty-dialect-sym-body ::= '<' pretty-dialect-sym-contents+ '>'
/// pretty-dialect-sym-contents ::= pretty-dialect-sym-body
@ -54,7 +51,8 @@ private:
/// | '{' pretty-dialect-sym-contents+ '}'
/// | '[^[<({>\])}\0]+'
///
ParseResult Parser::parseDialectSymbolBody(StringRef &body) {
ParseResult Parser::parseDialectSymbolBody(StringRef &body,
bool &isCodeCompletion) {
// Symbol bodies are a relatively unstructured format that contains a series
// of properly nested punctuation, with anything else in the middle. Scan
// ahead to find it and consume it if successful, otherwise emit an error.
@ -65,7 +63,16 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body) {
// go until we find the matching '>' character.
assert(*curPtr == '<');
SmallVector<char, 8> nestedPunctuation;
const char *codeCompleteLoc = state.lex.getCodeCompleteLoc();
do {
// Handle code completions, which may appear in the middle of the symbol
// body.
if (curPtr == codeCompleteLoc) {
isCodeCompletion = true;
nestedPunctuation.clear();
break;
}
char c = *curPtr++;
switch (c) {
case '\0':
@ -107,9 +114,19 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body) {
case '"': {
// Dispatch to the lexer to lex past strings.
resetToken(curPtr - 1);
curPtr = state.curToken.getEndLoc().getPointer();
// Handle code completions, which may appear in the middle of the symbol
// body.
if (state.curToken.isCodeCompletion()) {
isCodeCompletion = true;
nestedPunctuation.clear();
break;
}
// Otherwise, ensure this token was actually a string.
if (state.curToken.isNot(Token::string))
return failure();
curPtr = state.curToken.getEndLoc().getPointer();
break;
}
@ -129,19 +146,24 @@ ParseResult Parser::parseDialectSymbolBody(StringRef &body) {
/// Parse an extended dialect symbol.
template <typename Symbol, typename SymbolAliasMap, typename CreateFn>
static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
SymbolAliasMap &aliases,
static Symbol parseExtendedSymbol(Parser &p, SymbolAliasMap &aliases,
CreateFn &&createSymbol) {
Token tok = p.getToken();
// Handle code completion of the extended symbol.
StringRef identifier = tok.getSpelling().drop_front();
if (tok.isCodeCompletion() && identifier.empty())
return p.codeCompleteDialectSymbol(aliases);
// Parse the dialect namespace.
StringRef identifier = p.getTokenSpelling().drop_front();
SMLoc loc = p.getToken().getLoc();
p.consumeToken(identifierTok);
p.consumeToken();
// Check to see if this is a pretty name.
StringRef dialectName;
StringRef symbolData;
std::tie(dialectName, symbolData) = identifier.split('.');
bool isPrettyName = !symbolData.empty();
bool isPrettyName = !symbolData.empty() || identifier.back() == '.';
// Check to see if the symbol has trailing data, i.e. has an immediately
// following '<'.
@ -167,9 +189,17 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
if (!isPrettyName) {
// Point the symbol data to the end of the dialect name to start.
symbolData = StringRef(dialectName.end(), 0);
if (p.parseDialectSymbolBody(symbolData))
// Parse the body of the symbol.
bool isCodeCompletion = false;
if (p.parseDialectSymbolBody(symbolData, isCodeCompletion))
return nullptr;
symbolData = symbolData.drop_front().drop_back();
symbolData = symbolData.drop_front();
// If the body contained a code completion it won't have the trailing `>`
// token, so don't drop it.
if (!isCodeCompletion)
symbolData = symbolData.drop_back();
} else {
loc = SMLoc::getFromPointer(symbolData.data());
@ -192,7 +222,7 @@ static Symbol parseExtendedSymbol(Parser &p, Token::Kind identifierTok,
Attribute Parser::parseExtendedAttr(Type type) {
MLIRContext *ctx = getContext();
Attribute attr = parseExtendedSymbol<Attribute>(
*this, Token::hash_identifier, state.symbols.attributeAliasDefinitions,
*this, state.symbols.attributeAliasDefinitions,
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Attribute {
// Parse an optional trailing colon type.
Type attrType = type;
@ -238,7 +268,7 @@ Attribute Parser::parseExtendedAttr(Type type) {
Type Parser::parseExtendedType() {
MLIRContext *ctx = getContext();
return parseExtendedSymbol<Type>(
*this, Token::exclamation_identifier, state.symbols.typeAliasDefinitions,
*this, state.symbols.typeAliasDefinitions,
[&](StringRef dialectName, StringRef symbolData, SMLoc loc) -> Type {
// If we found a registered dialect, then ask it to parse the type.
if (auto *dialect = ctx->getOrLoadDialect(dialectName)) {

View File

@ -40,6 +40,10 @@ public:
/// Returns the start of the buffer.
const char *getBufferBegin() { return curBuffer.data(); }
/// Return the code completion location of the lexer, or nullptr if there is
/// none.
const char *getCodeCompleteLoc() const { return codeCompleteLoc; }
private:
// Helpers.
Token formToken(Token::Kind kind, const char *tokStart) {

View File

@ -404,6 +404,26 @@ ParseResult Parser::codeCompleteOptionalTokens(ArrayRef<StringRef> tokens) {
return failure();
}
Attribute Parser::codeCompleteAttribute() {
state.codeCompleteContext->completeAttribute(
state.symbols.attributeAliasDefinitions);
return {};
}
Type Parser::codeCompleteType() {
state.codeCompleteContext->completeType(state.symbols.typeAliasDefinitions);
return {};
}
Attribute
Parser::codeCompleteDialectSymbol(const llvm::StringMap<Attribute> &aliases) {
state.codeCompleteContext->completeDialectAttributeOrAlias(aliases);
return {};
}
Type Parser::codeCompleteDialectSymbol(const llvm::StringMap<Type> &aliases) {
state.codeCompleteContext->completeDialectTypeOrAlias(aliases);
return {};
}
//===----------------------------------------------------------------------===//
// OperationParser
//===----------------------------------------------------------------------===//

View File

@ -57,7 +57,16 @@ public:
return parseCommaSeparatedList(Delimiter::None, parseElementFn);
}
ParseResult parseDialectSymbolBody(StringRef &body);
/// Parse the body of a dialect symbol, which starts and ends with <>'s, and
/// may be recursive. Return with the 'body' StringRef encompassing the entire
/// body. `isCodeCompletion` is set to true if the body contained a code
/// completion location, in which case the body is only populated up to the
/// completion.
ParseResult parseDialectSymbolBody(StringRef &body, bool &isCodeCompletion);
ParseResult parseDialectSymbolBody(StringRef &body) {
bool isCodeCompletion = false;
return parseDialectSymbolBody(body, isCodeCompletion);
}
// We have two forms of parsing methods - those that return a non-null
// pointer on success, and those that return a ParseResult to indicate whether
@ -322,6 +331,12 @@ public:
ParseResult codeCompleteExpectedTokens(ArrayRef<StringRef> tokens);
ParseResult codeCompleteOptionalTokens(ArrayRef<StringRef> tokens);
Attribute codeCompleteAttribute();
Type codeCompleteType();
Attribute
codeCompleteDialectSymbol(const llvm::StringMap<Attribute> &aliases);
Type codeCompleteDialectSymbol(const llvm::StringMap<Type> &aliases);
protected:
/// The Parser is subclassed and reinstantiated. Do not add additional
/// non-trivial state here, add it to the ParserState class.

View File

@ -358,6 +358,12 @@ Type Parser::parseNonFunctionType() {
// extended type
case Token::exclamation_identifier:
return parseExtendedType();
// Handle completion of a dialect type.
case Token::code_complete:
if (getToken().isCodeCompletionFor(Token::exclamation_identifier))
return parseExtendedType();
return codeCompleteType();
}
}

View File

@ -781,8 +781,9 @@ enum class InsertTextFormat {
struct CompletionItem {
CompletionItem() = default;
CompletionItem(StringRef label, CompletionItemKind kind)
: label(label.str()), kind(kind),
CompletionItem(const Twine &label, CompletionItemKind kind,
StringRef sortText = "")
: label(label.str()), kind(kind), sortText(sortText.str()),
insertTextFormat(InsertTextFormat::PlainText) {}
/// The label of this completion item. By default also the text that is

View File

@ -636,15 +636,17 @@ public:
: AsmParserCodeCompleteContext(completeLoc),
completionList(completionList), ctx(ctx) {}
/// Signal code completion for a dialect name.
void completeDialectName() final {
/// Signal code completion for a dialect name, with an optional prefix.
void completeDialectName(StringRef prefix) final {
for (StringRef dialect : ctx->getAvailableDialects()) {
lsp::CompletionItem item(dialect, lsp::CompletionItemKind::Module);
item.sortText = "2";
lsp::CompletionItem item(prefix + dialect,
lsp::CompletionItemKind::Module,
/*sortText=*/"3");
item.detail = "dialect";
completionList.items.emplace_back(item);
}
}
using AsmParserCodeCompleteContext::completeDialectName;
/// Signal code completion for an operation name within the given dialect.
void completeOperationName(StringRef dialectName) final {
@ -658,8 +660,8 @@ public:
lsp::CompletionItem item(
op.getStringRef().drop_front(dialectName.size() + 1),
lsp::CompletionItemKind::Field);
item.sortText = "1";
lsp::CompletionItemKind::Field,
/*sortText=*/"1");
item.detail = "operation";
completionList.items.emplace_back(item);
}
@ -693,13 +695,71 @@ public:
/// Signal a completion for the given expected token.
void completeExpectedTokens(ArrayRef<StringRef> tokens, bool optional) final {
for (StringRef token : tokens) {
lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword);
item.sortText = "0";
lsp::CompletionItem item(token, lsp::CompletionItemKind::Keyword,
/*sortText=*/"0");
item.detail = optional ? "optional" : "";
completionList.items.emplace_back(item);
}
}
/// Signal a completion for an attribute.
void completeAttribute(const llvm::StringMap<Attribute> &aliases) override {
appendSimpleCompletions({"affine_set", "affine_map", "dense", "false",
"loc", "opaque", "sparse", "true", "unit"},
lsp::CompletionItemKind::Field,
/*sortText=*/"1");
completeDialectName("#");
completeAliases(aliases, "#");
}
void completeDialectAttributeOrAlias(
const llvm::StringMap<Attribute> &aliases) override {
completeDialectName();
completeAliases(aliases);
}
/// Signal a completion for a type.
void completeType(const llvm::StringMap<Type> &aliases) override {
appendSimpleCompletions({"memref", "tensor", "complex", "tuple", "vector",
"bf16", "f16", "f32", "f64", "f80", "f128",
"index", "none"},
lsp::CompletionItemKind::Field,
/*sortText=*/"1");
lsp::CompletionItem item("i<N>", lsp::CompletionItemKind::Field,
/*sortText=*/"1");
item.insertText = "i";
completionList.items.emplace_back(item);
completeDialectName("!");
completeAliases(aliases, "!");
}
void
completeDialectTypeOrAlias(const llvm::StringMap<Type> &aliases) override {
completeDialectName();
completeAliases(aliases);
}
/// Add completion results for the given set of aliases.
template <typename T>
void completeAliases(const llvm::StringMap<T> &aliases,
StringRef prefix = "") {
for (const auto &alias : aliases) {
lsp::CompletionItem item(prefix + alias.getKey(),
lsp::CompletionItemKind::Field,
/*sortText=*/"2");
llvm::raw_string_ostream(item.detail) << "alias: " << alias.getValue();
completionList.items.emplace_back(item);
}
}
/// Add a set of simple completions that all have the same kind.
void appendSimpleCompletions(ArrayRef<StringRef> completions,
lsp::CompletionItemKind kind,
StringRef sortText = "") {
for (StringRef completion : completions)
completionList.items.emplace_back(completion, kind, sortText);
}
private:
lsp::CompletionList &completionList;
MLIRContext *ctx;

View File

@ -408,12 +408,9 @@ void TestDialect::registerTypes() {
Type TestDialect::parseTestType(AsmParser &parser,
SetVector<Type> &stack) const {
StringRef typeTag;
if (failed(parser.parseKeyword(&typeTag)))
return Type();
{
Type genType;
auto parseResult = generatedTypeParser(parser, typeTag, genType);
auto parseResult = generatedTypeParser(parser, &typeTag, genType);
if (parseResult.hasValue())
return genType;
}

View File

@ -5,14 +5,14 @@
"uri":"test:///foo.mlir",
"languageId":"mlir",
"version":1,
"text":"func.func private @foo(%arg: i32) -> i32 {\n%cast = \"builtin.unrealized_conversion_cast\"() : () -> (i32)\nreturn %"
"text":"#attr = i32\n!alias = i32\nfunc.func private @foo(%arg: i32) -> i32 {\n%cast = \"builtin.unrealized_conversion_cast\"() : () -> (!pdl.value)\nreturn %"
}}}
// -----
{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
"position":{"line":0,"character":0}
"position":{"line":2,"character":0}
}}
// CHECK: "id": 1
// CHECK-LABEL: "id": 1
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
@ -22,7 +22,7 @@
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 9,
// CHECK: "label": "builtin",
// CHECK: "sortText": "2"
// CHECK: "sortText": "3"
// CHECK: },
// CHECK: {
// CHECK: "detail": "operation",
@ -34,11 +34,11 @@
// CHECK: ]
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
{"jsonrpc":"2.0","id":2,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
"position":{"line":1,"character":9}
"position":{"line":3,"character":9}
}}
// CHECK: "id": 1
// CHECK-LABEL: "id": 2
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
@ -48,17 +48,17 @@
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 9,
// CHECK: "label": "builtin",
// CHECK: "sortText": "2"
// CHECK: "sortText": "3"
// CHECK: },
// CHECK-NOT: "detail": "operation",
// CHECK: ]
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
{"jsonrpc":"2.0","id":3,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
"position":{"line":1,"character":17}
"position":{"line":3,"character":17}
}}
// CHECK: "id": 1
// CHECK-LABEL: "id": 3
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
@ -74,17 +74,17 @@
// CHECK: ]
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
{"jsonrpc":"2.0","id":4,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
"position":{"line":2,"character":8}
"position":{"line":4,"character":8}
}}
// CHECK: "id": 1
// CHECK-LABEL: "id": 4
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
// CHECK-NEXT: "items": [
// CHECK-NEXT: {
// CHECK-NEXT: "detail": "builtin.unrealized_conversion_cast: i32",
// CHECK-NEXT: "detail": "builtin.unrealized_conversion_cast: !pdl.value",
// CHECK-NEXT: "insertText": "cast",
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 6,
@ -100,11 +100,11 @@
// CHECK: ]
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":1,"method":"textDocument/completion","params":{
{"jsonrpc":"2.0","id":5,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
"position":{"line":0,"character":10}
"position":{"line":2,"character":10}
}}
// CHECK: "id": 1
// CHECK-LABEL: "id": 5
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
@ -133,6 +133,134 @@
// CHECK-NEXT: ]
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":3,"method":"shutdown"}
{"jsonrpc":"2.0","id":6,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
"position":{"line":0,"character":8}
}}
// CHECK-LABEL: "id": 6
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
// CHECK-NEXT: "items": [
// CHECK: {
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 5,
// CHECK: "label": "false"
// CHECK: },
// CHECK: {
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 5,
// CHECK: "label": "loc"
// CHECK: },
// CHECK: {
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 5,
// CHECK: "label": "true"
// CHECK: },
// CHECK: {
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 5,
// CHECK: "label": "unit"
// CHECK: }
// CHECK: ]
// CHECK: }
// -----
{"jsonrpc":"2.0","id":7,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
"position":{"line":3,"character":56}
}}
// CHECK-LABEL: "id": 7
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
// CHECK-NEXT: "items": [
// CHECK: {
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 5,
// CHECK: "label": "index"
// CHECK: },
// CHECK: {
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 5,
// CHECK: "label": "none"
// CHECK: },
// CHECK: {
// CHECK: "insertText": "i",
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 5,
// CHECK: "label": "i<N>"
// CHECK: }
// CHECK: ]
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":8,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
"position":{"line":3,"character":57}
}}
// CHECK-LABEL: "id": 8
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
// CHECK-NEXT: "items": [
// CHECK: {
// CHECK: "detail": "dialect",
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 9,
// CHECK: "label": "builtin",
// CHECK: "sortText": "3"
// CHECK: },
// CHECK: {
// CHECK: "detail": "alias: i32",
// CHECK: "insertTextFormat": 1,
// CHECK: "kind": 5,
// CHECK: "label": "alias",
// CHECK: "sortText": "2"
// CHECK: }
// CHECK: ]
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":9,"method":"textDocument/completion","params":{
"textDocument":{"uri":"test:///foo.mlir"},
"position":{"line":3,"character":61}
}}
// CHECK-LABEL: "id": 9
// CHECK-NEXT: "jsonrpc": "2.0",
// CHECK-NEXT: "result": {
// CHECK-NEXT: "isIncomplete": false,
// CHECK-NEXT: "items": [
// CHECK-NEXT: {
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 14,
// CHECK-NEXT: "label": "attribute",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 14,
// CHECK-NEXT: "label": "operation",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 14,
// CHECK-NEXT: "label": "range",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 14,
// CHECK-NEXT: "label": "type",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: },
// CHECK-NEXT: {
// CHECK-NEXT: "insertTextFormat": 1,
// CHECK-NEXT: "kind": 14,
// CHECK-NEXT: "label": "value",
// CHECK-NEXT: "sortText": "0"
// CHECK-NEXT: }
// CHECK-NEXT: ]
// CHECK-NEXT: }
// -----
{"jsonrpc":"2.0","id":10,"method":"shutdown"}
// -----
{"jsonrpc":"2.0","method":"exit"}

View File

@ -21,16 +21,19 @@ include "mlir/IR/OpBase.td"
// DEF-LABEL: ::mlir::OptionalParseResult generatedAttributeParser(
// DEF-SAME: ::mlir::AsmParser &parser,
// DEF-SAME: ::llvm::StringRef mnemonic, ::mlir::Type type,
// DEF-SAME: ::llvm::StringRef *mnemonic, ::mlir::Type type,
// DEF-SAME: ::mlir::Attribute &value) {
// DEF: if (mnemonic == ::test::CompoundAAttr::getMnemonic()) {
// DEF: return ::mlir::AsmParser::KeywordSwitch<::mlir::OptionalParseResult>(parser)
// DEF: .Case(::test::CompoundAAttr::getMnemonic()
// DEF-NEXT: value = ::test::CompoundAAttr::parse(parser, type);
// DEF-NEXT: return ::mlir::success(!!value);
// DEF-NEXT: }
// DEF-NEXT: if (mnemonic == ::test::IndexAttr::getMnemonic()) {
// DEF-NEXT: })
// DEF-NEXT: .Case(::test::IndexAttr::getMnemonic()
// DEF-NEXT: value = ::test::IndexAttr::parse(parser, type);
// DEF-NEXT: return ::mlir::success(!!value);
// DEF: return {};
// DEF: .Default([&](llvm::StringRef keyword,
// DEF-NEXT: *mnemonic = keyword;
// DEF-NEXT: return llvm::None;
def Test_Dialect: Dialect {
// DECL-NOT: TestDialect

View File

@ -27,11 +27,9 @@ def AttrA : TestAttr<"AttrA"> {
// ATTR: ::mlir::Type type) const {
// ATTR: ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
// ATTR: ::llvm::StringRef attrTag;
// ATTR: if (::mlir::failed(parser.parseKeyword(&attrTag)))
// ATTR: return {};
// ATTR: {
// ATTR: ::mlir::Attribute attr;
// ATTR: auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
// ATTR: auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
// ATTR: if (parseResult.hasValue())
// ATTR: return attr;
// ATTR: }
@ -57,10 +55,8 @@ def TypeA : TestType<"TypeA"> {
// TYPE: ::mlir::Type TestDialect::parseType(::mlir::DialectAsmParser &parser) const {
// TYPE: ::llvm::SMLoc typeLoc = parser.getCurrentLocation();
// TYPE: ::llvm::StringRef mnemonic;
// TYPE: if (parser.parseKeyword(&mnemonic))
// TYPE: return ::mlir::Type();
// TYPE: ::mlir::Type genType;
// TYPE: auto parseResult = generatedTypeParser(parser, mnemonic, genType);
// TYPE: auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
// TYPE: if (parseResult.hasValue())
// TYPE: return genType;
// TYPE: parser.emitError(typeLoc) << "unknown type `"

View File

@ -22,16 +22,18 @@ include "mlir/IR/OpBase.td"
// DEF-LABEL: ::mlir::OptionalParseResult generatedTypeParser(
// DEF-SAME: ::mlir::AsmParser &parser,
// DEF-SAME: ::llvm::StringRef mnemonic,
// DEF-SAME: ::llvm::StringRef *mnemonic,
// DEF-SAME: ::mlir::Type &value) {
// DEF: if (mnemonic == ::test::CompoundAType::getMnemonic()) {
// DEF: .Case(::test::CompoundAType::getMnemonic()
// DEF-NEXT: value = ::test::CompoundAType::parse(parser);
// DEF-NEXT: return ::mlir::success(!!value);
// DEF-NEXT: }
// DEF-NEXT: if (mnemonic == ::test::IndexType::getMnemonic()) {
// DEF-NEXT: })
// DEF-NEXT: .Case(::test::IndexType::getMnemonic()
// DEF-NEXT: value = ::test::IndexType::parse(parser);
// DEF-NEXT: return ::mlir::success(!!value);
// DEF: return {};
// DEF: .Default([&](llvm::StringRef keyword,
// DEF-NEXT: *mnemonic = keyword;
// DEF-NEXT: return llvm::None;
def Test_Dialect: Dialect {
// DECL-NOT: TestDialect

View File

@ -673,11 +673,9 @@ static const char *const dialectDefaultAttrPrinterParserDispatch = R"(
::mlir::Type type) const {{
::llvm::SMLoc typeLoc = parser.getCurrentLocation();
::llvm::StringRef attrTag;
if (::mlir::failed(parser.parseKeyword(&attrTag)))
return {{};
{{
::mlir::Attribute attr;
auto parseResult = generatedAttributeParser(parser, attrTag, type, attr);
auto parseResult = generatedAttributeParser(parser, &attrTag, type, attr);
if (parseResult.hasValue())
return attr;
}
@ -723,10 +721,8 @@ static const char *const dialectDefaultTypePrinterParserDispatch = R"(
::mlir::Type {0}::parseType(::mlir::DialectAsmParser &parser) const {{
::llvm::SMLoc typeLoc = parser.getCurrentLocation();
::llvm::StringRef mnemonic;
if (parser.parseKeyword(&mnemonic))
return ::mlir::Type();
::mlir::Type genType;
auto parseResult = generatedTypeParser(parser, mnemonic, genType);
auto parseResult = generatedTypeParser(parser, &mnemonic, genType);
if (parseResult.hasValue())
return genType;
{1}
@ -771,7 +767,7 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
}
// Declare the parser.
SmallVector<MethodParameter> params = {{"::mlir::AsmParser &", "parser"},
{"::llvm::StringRef", "mnemonic"}};
{"::llvm::StringRef *", "mnemonic"}};
if (isAttrGenerator)
params.emplace_back("::mlir::Type", "type");
params.emplace_back(strfmt("::mlir::{0} &", valueType), "value");
@ -784,14 +780,18 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
{{strfmt("::mlir::{0}", valueType), "def"},
{"::mlir::AsmPrinter &", "printer"}});
// The parser dispatch is just a list of if-elses, matching on the mnemonic
// and calling the def's parse function.
// The parser dispatch uses a KeywordSwitch, matching on the mnemonic and
// calling the def's parse function.
parse.body() << " return "
"::mlir::AsmParser::KeywordSwitch<::mlir::"
"OptionalParseResult>(parser)\n";
const char *const getValueForMnemonic =
R"( if (mnemonic == {0}::getMnemonic()) {{
value = {0}::{1};
return ::mlir::success(!!value);
}
R"( .Case({0}::getMnemonic(), [&](llvm::StringRef, llvm::SMLoc) {{
value = {0}::{1};
return ::mlir::success(!!value);
})
)";
// The printer dispatch uses llvm::TypeSwitch to find and call the correct
// printer.
printer.body() << " return ::llvm::TypeSwitch<::mlir::" << valueType
@ -822,7 +822,10 @@ void DefGenerator::emitParsePrintDispatch(ArrayRef<AttrOrTypeDef> defs) {
printDef = "\nt.print(printer);";
printer.body() << llvm::formatv(printValue, defClass, printDef);
}
parse.body() << " return {};";
parse.body() << " .Default([&](llvm::StringRef keyword, llvm::SMLoc) {\n"
" *mnemonic = keyword;\n"
" return llvm::None;\n"
" });";
printer.body() << " .Default([](auto) { return ::mlir::failure(); });";
raw_indented_ostream indentedOs(os);