forked from OSchip/llvm-project
2467 lines
86 KiB
C++
2467 lines
86 KiB
C++
//===- mlir-linalg-ods-gen.cpp - Linalg ODS generation from math form -----===//
|
|
//
|
|
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
|
// See https://llvm.org/LICENSE.txt for license information.
|
|
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
//
|
|
// This file contains the implementation for the Tensor Comprehension-inspired
|
|
// parser and ODS pretty-printer for specifying Linalg "named ops" from a
|
|
// mathematical form.
|
|
//
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
#include "mlir/IR/AffineExpr.h"
|
|
#include "mlir/IR/AffineMap.h"
|
|
#include "mlir/IR/MLIRContext.h"
|
|
#include "mlir/IR/OpImplementation.h"
|
|
#include "mlir/Support/FileUtilities.h"
|
|
#include "mlir/Support/LLVM.h"
|
|
#include "mlir/Support/LogicalResult.h"
|
|
#include "llvm/ADT/DenseMap.h"
|
|
#include "llvm/ADT/Optional.h"
|
|
#include "llvm/ADT/STLExtras.h"
|
|
#include "llvm/ADT/SetVector.h"
|
|
#include "llvm/ADT/SmallVector.h"
|
|
#include "llvm/ADT/StringExtras.h"
|
|
#include "llvm/ADT/StringRef.h"
|
|
#include "llvm/ADT/StringSwitch.h"
|
|
#include "llvm/ADT/Twine.h"
|
|
#include "llvm/Support/Casting.h"
|
|
#include "llvm/Support/CommandLine.h"
|
|
#include "llvm/Support/Debug.h"
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"

#include <map>
#include <set>
#include <utility>
|
|
|
|
#define DEBUG_TYPE "linalg-ods-gen"
|
|
|
|
static llvm::cl::OptionCategory ODSGenCat("Linalg ODS Gen");
|
|
|
|
// Commandline options
|
|
static llvm::cl::opt<std::string>
|
|
inputFilename(llvm::cl::Positional, llvm::cl::desc("<input file>"),
|
|
llvm::cl::init("-"), llvm::cl::value_desc("filename"));
|
|
|
|
static llvm::cl::opt<std::string>
|
|
outputFilename("o", llvm::cl::desc("Output filename"),
|
|
llvm::cl::value_desc("filename"), llvm::cl::init("-"));
|
|
|
|
static llvm::cl::opt<bool>
|
|
genODSDecl("gen-ods-decl", llvm::cl::desc("Emit the ODS ops declarations."),
|
|
llvm::cl::cat(ODSGenCat));
|
|
|
|
static llvm::cl::opt<bool>
|
|
genODSImpl("gen-impl", llvm::cl::desc("Emit the ops implementations"),
|
|
llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
|
|
|
|
static llvm::cl::opt<bool> testEmitIncludeTdHeader(
|
|
"test-emit-include-td-header",
|
|
llvm::cl::desc("Include LinalgStructuredOps.td for end-to-end "
|
|
"tblgen testing."),
|
|
llvm::cl::init(false), llvm::cl::cat(ODSGenCat));
|
|
|
|
using llvm::SMLoc;
|
|
using llvm::StringRef;
|
|
using llvm::Twine;
|
|
|
|
using namespace mlir;
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Special "op aliases" substitutions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Perform textual substitutions of known special ops ("op aliases").
///
/// This is a poor man's way of giving an existing op an additional name: every
/// occurrence of an alias spelling such as `b.create<CmpIOpSGT>(` is rewritten
/// in place to the builder call it stands for, e.g.
/// `b.create<CmpIOp>(CmpIPredicate::sgt, `. Aliases are applied in table
/// order, and each one replaces all non-overlapping occurrences scanning from
/// left to right.
///
/// This is hacky and temporary until migration to the python opdsl is complete.
static void substituteOpAliases(std::string &expressionsStr) {
  // Built once and iterated by reference, instead of constructing a fresh
  // container of strings (and copying each pair) on every call.
  static const std::pair<std::string, std::string> kOpAliases[] = {
      {"b.create<CmpIOpSGT>(", "b.create<CmpIOp>(CmpIPredicate::sgt, "},
      {"b.create<CmpFOpOGT>(", "b.create<CmpFOp>(CmpFPredicate::OGT, "},
      {"b.create<CmpFOpOLT>(", "b.create<CmpFOp>(CmpFPredicate::OLT, "},
      {"b.create<SignExtendIOp32>(",
       "b.create<SignExtendIOp>(b.getI32Type(), "},
  };

  for (const auto &alias : kOpAliases) {
    const std::string &from = alias.first;
    const std::string &to = alias.second;
    // Resume searching just past the inserted text so a replacement is never
    // re-scanned.
    for (size_t pos = expressionsStr.find(from); pos != std::string::npos;
         pos = expressionsStr.find(from, pos + to.size()))
      expressionsStr.replace(pos, from.size(), to);
  }
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Lexer
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
/// This class represents a specific token in the input format.
|
|
class Token {
|
|
public:
|
|
enum class Kind {
|
|
// Markers.
|
|
eof,
|
|
error,
|
|
|
|
// Tokens with no info.
|
|
colon,
|
|
comma,
|
|
doc_str,
|
|
equal,
|
|
gt,
|
|
l_brace,
|
|
l_paren,
|
|
l_square,
|
|
lt,
|
|
minus,
|
|
plus,
|
|
question,
|
|
r_brace,
|
|
r_paren,
|
|
r_square,
|
|
semicolon,
|
|
star,
|
|
|
|
// Keywords.
|
|
kw_def,
|
|
FIRST_KEYWORD = kw_def,
|
|
kw_ods_def,
|
|
kw_implements_interface,
|
|
kw_attr_def,
|
|
kw_floordiv,
|
|
kw_ceildiv,
|
|
kw_mod,
|
|
LAST_KEYWORD = kw_mod,
|
|
|
|
// String valued tokens.
|
|
id,
|
|
integer,
|
|
};
|
|
|
|
Token(Kind kind, StringRef spelling) : kind(kind), spelling(spelling) {}
|
|
|
|
/// Return the bytes that make up this token.
|
|
StringRef getSpelling() const { return spelling; }
|
|
|
|
/// Return the kind of this token.
|
|
Kind getKind() const { return kind; }
|
|
|
|
/// Return a location for this token.
|
|
llvm::SMLoc getLoc() const {
|
|
return llvm::SMLoc::getFromPointer(spelling.data());
|
|
}
|
|
|
|
/// Return if this token is a keyword.
|
|
bool isKeyword() const {
|
|
return kind >= Kind::FIRST_KEYWORD && kind <= Kind::LAST_KEYWORD;
|
|
}
|
|
bool is(Kind k) const { return kind == k; }
|
|
bool isNot(Kind k) const { return kind != k; }
|
|
|
|
Optional<uint64_t> getUInt64IntegerValue() const {
|
|
bool isHex = spelling.size() > 1 && spelling[1] == 'x';
|
|
|
|
uint64_t result = 0;
|
|
if (spelling.getAsInteger(isHex ? 0 : 10, result))
|
|
return None;
|
|
return result;
|
|
}
|
|
|
|
private:
|
|
/// Discriminator that indicates the kind of token this is.
|
|
Kind kind;
|
|
|
|
/// A reference to the entire token contents; this is always a pointer into
|
|
/// a memory buffer owned by the source manager.
|
|
StringRef spelling;
|
|
};
|
|
|
|
/// This class implements a simple lexer.
|
|
class Lexer {
|
|
public:
|
|
Lexer(llvm::SourceMgr &mgr);
|
|
|
|
/// Lex the next token and return it.
|
|
Token lexToken();
|
|
|
|
/// Emit an error to the lexer with the given location and message.
|
|
Token emitError(llvm::SMLoc loc, const Twine &msg);
|
|
Token emitError(const char *loc, const Twine &msg);
|
|
|
|
/// Change the position of the lexer cursor. The next token we lex will start
|
|
/// at the designated point in the input.
|
|
void resetPointer(const char *newPtr) { curPtr = newPtr; }
|
|
|
|
private:
|
|
Token formToken(Token::Kind kind, const char *tokStart) {
|
|
return Token(kind, StringRef(tokStart, curPtr - tokStart));
|
|
}
|
|
|
|
/// Return the next character in the stream.
|
|
int getNextChar();
|
|
|
|
/// Lex an identifier.
|
|
Token lexIdentifier(const char *tokStart);
|
|
|
|
// Lex an integer.
|
|
Token lexInteger(const char *tokStart);
|
|
|
|
// Lex a string.
|
|
Token lexString(const char *tokStart);
|
|
|
|
// Skip a comment line, starting with a '//'.
|
|
void skipComment();
|
|
|
|
llvm::SourceMgr &srcMgr;
|
|
StringRef curBuffer;
|
|
const char *curPtr;
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
/// Positions the lexer at the beginning of the source manager's main buffer.
Lexer::Lexer(llvm::SourceMgr &mgr)
    : srcMgr(mgr),
      curBuffer(mgr.getMemoryBuffer(mgr.getMainFileID())->getBuffer()),
      curPtr(curBuffer.begin()) {}
|
|
|
|
/// Prints `msg` as an error diagnostic at `loc` and returns an error token
/// running from `loc` to the current cursor.
Token Lexer::emitError(llvm::SMLoc loc, const Twine &msg) {
  srcMgr.PrintMessage(loc, llvm::SourceMgr::DK_Error, msg);
  const char *errStart = loc.getPointer();
  return formToken(Token::Kind::error, errStart);
}

/// Convenience overload taking a raw pointer into the buffer.
Token Lexer::emitError(const char *loc, const Twine &msg) {
  llvm::SMLoc smLoc = llvm::SMLoc::getFromPointer(loc);
  return emitError(smLoc, msg);
}
|
|
|
|
int Lexer::getNextChar() {
|
|
char curChar = *curPtr++;
|
|
switch (curChar) {
|
|
default:
|
|
return (unsigned char)curChar;
|
|
case 0: {
|
|
// A nul character in the stream is either the end of the current buffer
|
|
// or a random nul in the file. Disambiguate that here.
|
|
if (curPtr - 1 != curBuffer.end())
|
|
return 0;
|
|
|
|
// Otherwise, return end of file.
|
|
--curPtr;
|
|
return EOF;
|
|
}
|
|
case '\n':
|
|
case '\r':
|
|
// Handle the newline character by ignoring it and incrementing the line
|
|
// count. However, be careful about 'dos style' files with \n\r in them.
|
|
// Only treat a \n\r or \r\n as a single line.
|
|
if ((*curPtr == '\n' || (*curPtr == '\r')) && *curPtr != curChar)
|
|
++curPtr;
|
|
return '\n';
|
|
}
|
|
}
|
|
|
|
/// Lexes and returns the next token, skipping whitespace and `//` comments.
Token Lexer::lexToken() {
  // Skipped input (whitespace, comments) is handled by looping rather than by
  // recursing, so long runs of blanks cannot grow the call stack.
  while (true) {
    const char *tokStart = curPtr;

    // Consumes one character, except at end of input, where EOF is returned
    // and the cursor is left in place.
    int curChar = getNextChar();
    switch (curChar) {
    default:
      // Handle identifiers: [a-zA-Z_]
      if (isalpha(curChar) || curChar == '_')
        return lexIdentifier(tokStart);

      // Handle integers: [0-9]
      if (isdigit(curChar))
        return lexInteger(tokStart);

      // Unknown character, emit an error.
      return emitError(tokStart, "unexpected character");

    case EOF:
      // Return EOF denoting the end of lexing.
      return formToken(Token::Kind::eof, tokStart);

    // Lex punctuation.
    case ':':
      return formToken(Token::Kind::colon, tokStart);
    case ',':
      return formToken(Token::Kind::comma, tokStart);
    case '=':
      return formToken(Token::Kind::equal, tokStart);
    case '{':
      return formToken(Token::Kind::l_brace, tokStart);
    case '(':
      return formToken(Token::Kind::l_paren, tokStart);
    case '[':
      return formToken(Token::Kind::l_square, tokStart);
    case '}':
      return formToken(Token::Kind::r_brace, tokStart);
    case ')':
      return formToken(Token::Kind::r_paren, tokStart);
    case ']':
      return formToken(Token::Kind::r_square, tokStart);
    case '<':
      return formToken(Token::Kind::lt, tokStart);
    case '>':
      return formToken(Token::Kind::gt, tokStart);
    case '+':
      return formToken(Token::Kind::plus, tokStart);
    case '-':
      return formToken(Token::Kind::minus, tokStart);
    case ';':
      return formToken(Token::Kind::semicolon, tokStart);
    case '*':
      return formToken(Token::Kind::star, tokStart);
    case '?':
      return formToken(Token::Kind::question, tokStart);
    case '"':
      return lexString(tokStart);

    case '/':
      if (*curPtr == '/') {
        skipComment();
        continue;
      }
      // Unknown character, emit an error.
      return emitError(tokStart, "unexpected character: not a comment");

    // Ignore whitespace characters and stray NULs. getNextChar() already maps
    // '\r' to '\n'.
    case 0:
    case ' ':
    case '\t':
    case '\n':
      continue;
    }
  }
}
|
|
|
|
/// Lexes the remainder of an identifier whose first character is at
/// `tokStart`, then classifies it as a keyword or a plain `id`.
Token Lexer::lexIdentifier(const char *tokStart) {
  // Match the rest of the identifier regex: [0-9a-zA-Z_\-]*
  // The cast matters: passing a negative `char` (a non-ASCII byte on targets
  // where char is signed) to isalnum is undefined behavior.
  while (isalnum(static_cast<unsigned char>(*curPtr)) || *curPtr == '_' ||
         *curPtr == '-')
    ++curPtr;

  // Check to see if this identifier is a keyword.
  StringRef str(tokStart, curPtr - tokStart);
  Token::Kind kind =
      StringSwitch<Token::Kind>(str)
          .Case("attr", Token::Kind::kw_attr_def)
          .Case("def", Token::Kind::kw_def)
          .Case("ods_def", Token::Kind::kw_ods_def)
          .Case("implements_interface", Token::Kind::kw_implements_interface)
          .Case("floordiv", Token::Kind::kw_floordiv)
          .Case("ceildiv", Token::Kind::kw_ceildiv)
          .Case("mod", Token::Kind::kw_mod)
          .Default(Token::Kind::id);

  return Token(kind, str);
}
|
|
|
|
/// Lexes the remainder of a decimal integer literal whose first digit is at
/// `tokStart`.
Token Lexer::lexInteger(const char *tokStart) {
  // Match the rest of the integer regex: [0-9]*
  // Cast to unsigned char: isdigit on a negative `char` is undefined behavior.
  while (isdigit(static_cast<unsigned char>(*curPtr)))
    ++curPtr;

  StringRef str(tokStart, curPtr - tokStart);
  return Token(Token::Kind::integer, str);
}
|
|
|
|
/// Lexes a doc string of the form `"""..."""`. On entry, the first `"` at
/// `tokStart` has already been consumed. The resulting doc_str token's
/// spelling excludes both triple-quote delimiters. Any other use of `"` is
/// reported as an error.
Token Lexer::lexString(const char *tokStart) {
  assert(curPtr[-1] == '"');

  if (*curPtr == '"' && *(curPtr + 1) == '"') {
    curPtr += 2;
    while (true) {
      switch (*curPtr++) {
      case '"':
        if (*curPtr == '"' && *(curPtr + 1) == '"') {
          // Content runs from just after the opening `"""` up to, but not
          // including, the `"` consumed above (at curPtr - 1).
          Token token(Token::Kind::doc_str,
                      StringRef(tokStart + 3, curPtr - tokStart - 4));
          curPtr += 2;
          return token;
        }
        continue;
      case 0:
        // If this is a random nul character in the middle of the doc string,
        // just include it. If it is the end of file, then it is an error.
        if (curPtr - 1 != curBuffer.end())
          continue;
        // Step back onto the terminator so the cursor never points past the
        // end of the buffer, as getNextChar() and skipComment() also ensure.
        --curPtr;
        return emitError(curPtr, "expected '\"\"\"' to end doc string");
      default:
        continue;
      }
    }
  }

  return emitError(curPtr - 1, "expected '\"\"\"' to start doc string");
}
|
|
|
|
/// Skip a comment line, starting with a '//'.
|
|
void Lexer::skipComment() {
|
|
// Advance over the second '/' in a '//' comment.
|
|
assert(*curPtr == '/');
|
|
++curPtr;
|
|
|
|
while (true) {
|
|
switch (*curPtr++) {
|
|
case '\n':
|
|
case '\r':
|
|
// Newline is end of comment.
|
|
return;
|
|
case 0:
|
|
// If this is the end of the buffer, end the comment.
|
|
if (curPtr - 1 == curBuffer.end()) {
|
|
--curPtr;
|
|
return;
|
|
}
|
|
LLVM_FALLTHROUGH;
|
|
default:
|
|
// Skip over other characters.
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
namespace {
|
|
|
|
class Parser {
|
|
public:
|
|
Parser(llvm::SourceMgr &mgr, MLIRContext *ctx)
|
|
: lexer(mgr), curToken(lexer.lexToken()), context(ctx) {}
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Lexer Utilities
|
|
//===--------------------------------------------------------------------===//
|
|
|
|
LogicalResult parseInteger(uint64_t &value) {
|
|
if (!curToken.is(Token::Kind::integer))
|
|
return emitError(curToken.getLoc(), "expected integer");
|
|
value = curToken.getUInt64IntegerValue().getValue();
|
|
consumeToken();
|
|
return success();
|
|
}
|
|
|
|
/// Advance the current lexer onto the next token.
|
|
void consumeToken() {
|
|
assert(curToken.getKind() != Token::Kind::eof &&
|
|
curToken.getKind() != Token::Kind::error &&
|
|
"shouldn't advance past EOF or errors");
|
|
curToken = lexer.lexToken();
|
|
}
|
|
|
|
void consumeToken(Token::Kind kind) {
|
|
assert(curToken.getKind() == kind && "unexpected token");
|
|
curToken = lexer.lexToken();
|
|
}
|
|
|
|
LogicalResult parseToken(Token::Kind kind, const Twine &msg) {
|
|
if (curToken.getKind() != kind)
|
|
return emitError(curToken.getLoc(), msg);
|
|
consumeToken();
|
|
return success();
|
|
}
|
|
|
|
/// Parses an optional token and returns failure if failed to parse.
|
|
LogicalResult parseOptionalToken(Token::Kind kind) {
|
|
return success(consumeIf(kind));
|
|
}
|
|
|
|
LogicalResult emitError(llvm::SMLoc loc, const Twine &msg) {
|
|
lexer.emitError(loc, msg);
|
|
return failure();
|
|
}
|
|
|
|
LogicalResult emitError(const Twine &msg) {
|
|
return emitError(curToken.getLoc(), msg);
|
|
}
|
|
|
|
bool consumeIf(Token::Kind kind) {
|
|
if (curToken.isNot(kind))
|
|
return false;
|
|
consumeToken(kind);
|
|
return true;
|
|
}
|
|
|
|
LogicalResult
|
|
parseCommaSeparatedList(llvm::function_ref<ParseResult()> parseElement) {
|
|
// Non-empty case starts with an element.
|
|
if (parseElement())
|
|
return failure();
|
|
|
|
// Otherwise we have a list of comma separated elements.
|
|
while (consumeIf(Token::Kind::comma)) {
|
|
if (parseElement())
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
LogicalResult
|
|
parseCommaSeparatedListUntil(Token::Kind rightToken,
|
|
llvm::function_ref<ParseResult()> parseElement,
|
|
bool allowEmptyList) {
|
|
// Handle the empty case.
|
|
if (curToken.is(rightToken)) {
|
|
if (!allowEmptyList)
|
|
return emitError("expected list element");
|
|
consumeToken(rightToken);
|
|
return success();
|
|
}
|
|
|
|
if (failed(parseCommaSeparatedList(parseElement)) ||
|
|
failed(
|
|
parseToken(rightToken, "expected ',' or right-terminating token")))
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
Lexer lexer;
|
|
Token curToken;
|
|
MLIRContext *context;
|
|
};
|
|
} // namespace
|
|
|
|
/// Encodes an attribute use of the form:
|
|
///
|
|
/// index-list ::= integer-literal (`,` integer-literal)*
|
|
/// attr-use ::= bare-id `[` index-list `]`
|
|
struct AttrUse {
|
|
// Referenced attribute
|
|
StringRef attrName;
|
|
// Indices into the attribute
|
|
SmallVector<uint64_t, 4> indices;
|
|
/// Affine symbol for this usage.
|
|
/// This is represented as an affine symbol because at the time of parsing the
|
|
/// spec and generating the op's ODS/C++, we don't know the concrete constant
|
|
/// value. But they should be replaced with constants read from the attribute
|
|
/// and thus folded away for concrete op instances.
|
|
AffineExpr symbol;
|
|
|
|
std::string getKey() {
|
|
SmallVector<std::string, 4> indexStrs;
|
|
for (uint64_t index : indices)
|
|
indexStrs.push_back(std::to_string(index));
|
|
return llvm::formatv("{0}[{1}]", attrName, llvm::join(indexStrs, ","));
|
|
}
|
|
};
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Affine parsing.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
/// Lower precedence ops (all at the same precedence level). LNoOp is false in
|
|
/// the boolean sense.
|
|
enum AffineLowPrecOp {
|
|
/// Null value.
|
|
LNoOp,
|
|
Add,
|
|
Sub
|
|
};
|
|
|
|
/// Higher precedence ops - all at the same precedence level. HNoOp is false
|
|
/// in the boolean sense.
|
|
enum AffineHighPrecOp {
|
|
/// Null value.
|
|
HNoOp,
|
|
Mul,
|
|
FloorDiv,
|
|
CeilDiv,
|
|
Mod
|
|
};
|
|
|
|
using AffineDimList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
|
|
using AffineSymbolList = SmallVector<std::pair<StringRef, AffineExpr>, 4>;
|
|
|
|
/// This is a specialized parser for affine expressions.
|
|
class AffineParser {
|
|
public:
|
|
/// Creates an affine parser that parses tokens from `p`.
|
|
///
|
|
/// The affine parser introduces new dimensions and symbols eagerly as new
|
|
/// `id` are discovered. To additionally support attribute use `id`s, for a
|
|
/// parsed `id`, the resolution mechanism proceeds as follows:
|
|
/// 1. Try to parse `id` as an attribute use (using the `attrUseParsingHook`).
|
|
/// 2. If unsuccessful, try to match `id` to a known dim or symbol.
|
|
/// 3. If still unsuccessful, eagerly create a new dim or symbol and add it to
|
|
/// the known dims or symbols (using the `bareIdParsingHook`).
|
|
explicit AffineParser(
|
|
Parser &p, std::function<AffineExpr(StringRef)> bareIdParsingHook,
|
|
std::function<llvm::Optional<AffineExpr>()> attrUseParsingHook,
|
|
AffineDimList &dimList, AffineSymbolList &symbolList)
|
|
: parser(p), bareIdFallback(bareIdParsingHook),
|
|
attrUseCallback(attrUseParsingHook), dims(dimList),
|
|
symbols(symbolList) {}
|
|
|
|
/// Parse a comma-separated list of affine exprs.
|
|
SmallVector<AffineExpr, 4>
|
|
parseAffineExprs(Token::Kind lDelim = Token::Kind::l_paren,
|
|
Token::Kind rDelim = Token::Kind::r_paren);
|
|
|
|
/// Parse a single affine expr.`.
|
|
AffineExpr parseAffineExpr();
|
|
|
|
private:
|
|
// Binary affine op parsing.
|
|
AffineLowPrecOp consumeIfLowPrecOp();
|
|
AffineHighPrecOp consumeIfHighPrecOp();
|
|
|
|
// AffineExpr parsing.
|
|
AffineExpr parseParentheticalExpr();
|
|
AffineExpr parseNegateExpression(AffineExpr lhs);
|
|
AffineExpr parseIntegerExpr();
|
|
AffineExpr parseAttrUseOrBareIdExpr();
|
|
AffineExpr parseBareIdExpr();
|
|
|
|
AffineExpr getAffineBinaryOpExpr(AffineHighPrecOp op, AffineExpr lhs,
|
|
AffineExpr rhs, SMLoc opLoc);
|
|
AffineExpr getAffineBinaryOpExpr(AffineLowPrecOp op, AffineExpr lhs,
|
|
AffineExpr rhs);
|
|
AffineExpr parseAffineOperandExpr(AffineExpr lhs);
|
|
AffineExpr parseAffineLowPrecOpExpr(AffineExpr llhs, AffineLowPrecOp llhsOp);
|
|
AffineExpr parseAffineHighPrecOpExpr(AffineExpr llhs, AffineHighPrecOp llhsOp,
|
|
SMLoc llhsOpLoc);
|
|
|
|
Parser &parser;
|
|
std::function<AffineExpr(StringRef)> bareIdFallback;
|
|
std::function<llvm::Optional<AffineExpr>()> attrUseCallback;
|
|
AffineDimList &dims;
|
|
AffineSymbolList &symbols;
|
|
};
|
|
} // end anonymous namespace
|
|
|
|
/// Create an affine binary high precedence op expression (mul's, div's, mod).
|
|
/// opLoc is the location of the op token to be used to report errors
|
|
/// for non-conforming expressions.
|
|
AffineExpr AffineParser::getAffineBinaryOpExpr(AffineHighPrecOp op,
|
|
AffineExpr lhs, AffineExpr rhs,
|
|
SMLoc opLoc) {
|
|
switch (op) {
|
|
case Mul:
|
|
if (!lhs.isSymbolicOrConstant() && !rhs.isSymbolicOrConstant()) {
|
|
(void)parser.emitError(
|
|
opLoc, "non-affine expression: at least one of the multiply "
|
|
"operands has to be either a constant or symbolic");
|
|
return nullptr;
|
|
}
|
|
return lhs * rhs;
|
|
case FloorDiv:
|
|
if (!rhs.isSymbolicOrConstant()) {
|
|
(void)parser.emitError(opLoc,
|
|
"non-affine expression: right operand of floordiv "
|
|
"has to be either a constant or symbolic");
|
|
return nullptr;
|
|
}
|
|
return lhs.floorDiv(rhs);
|
|
case CeilDiv:
|
|
if (!rhs.isSymbolicOrConstant()) {
|
|
(void)parser.emitError(opLoc,
|
|
"non-affine expression: right operand of ceildiv "
|
|
"has to be either a constant or symbolic");
|
|
return nullptr;
|
|
}
|
|
return lhs.ceilDiv(rhs);
|
|
case Mod:
|
|
if (!rhs.isSymbolicOrConstant()) {
|
|
(void)parser.emitError(opLoc,
|
|
"non-affine expression: right operand of mod "
|
|
"has to be either a constant or symbolic");
|
|
return nullptr;
|
|
}
|
|
return lhs % rhs;
|
|
case HNoOp:
|
|
llvm_unreachable("can't create affine expression for null high prec op");
|
|
return nullptr;
|
|
}
|
|
llvm_unreachable("Unknown AffineHighPrecOp");
|
|
}
|
|
|
|
/// Create an affine binary low precedence op expression (add, sub).
|
|
AffineExpr AffineParser::getAffineBinaryOpExpr(AffineLowPrecOp op,
|
|
AffineExpr lhs, AffineExpr rhs) {
|
|
switch (op) {
|
|
case AffineLowPrecOp::Add:
|
|
return lhs + rhs;
|
|
case AffineLowPrecOp::Sub:
|
|
return lhs - rhs;
|
|
case AffineLowPrecOp::LNoOp:
|
|
llvm_unreachable("can't create affine expression for null low prec op");
|
|
return nullptr;
|
|
}
|
|
llvm_unreachable("Unknown AffineLowPrecOp");
|
|
}
|
|
|
|
/// Consume this token if it is a lower precedence affine op (there are only
|
|
/// two precedence levels).
|
|
AffineLowPrecOp AffineParser::consumeIfLowPrecOp() {
|
|
switch (parser.curToken.getKind()) {
|
|
case Token::Kind::plus:
|
|
parser.consumeToken();
|
|
return AffineLowPrecOp::Add;
|
|
case Token::Kind::minus:
|
|
parser.consumeToken();
|
|
return AffineLowPrecOp::Sub;
|
|
default:
|
|
return AffineLowPrecOp::LNoOp;
|
|
}
|
|
}
|
|
|
|
/// Consume this token if it is a higher precedence affine op (there are only
|
|
/// two precedence levels)
|
|
AffineHighPrecOp AffineParser::consumeIfHighPrecOp() {
|
|
switch (parser.curToken.getKind()) {
|
|
case Token::Kind::star:
|
|
parser.consumeToken(Token::Kind::star);
|
|
return Mul;
|
|
case Token::Kind::kw_floordiv:
|
|
parser.consumeToken(Token::Kind::kw_floordiv);
|
|
return FloorDiv;
|
|
case Token::Kind::kw_ceildiv:
|
|
parser.consumeToken(Token::Kind::kw_ceildiv);
|
|
return CeilDiv;
|
|
case Token::Kind::kw_mod:
|
|
parser.consumeToken(Token::Kind::kw_mod);
|
|
return Mod;
|
|
default:
|
|
return HNoOp;
|
|
}
|
|
}
|
|
|
|
/// Parse a high precedence op expression list: mul, div, and mod are high
|
|
/// precedence binary ops, i.e., parse a
|
|
/// expr_1 op_1 expr_2 op_2 ... expr_n
|
|
/// where op_1, op_2 are all a AffineHighPrecOp (mul, div, mod).
|
|
/// All affine binary ops are left associative.
|
|
/// Given llhs, returns (llhs llhsOp lhs) op rhs, or (lhs op rhs) if llhs is
|
|
/// null. If no rhs can be found, returns (llhs llhsOp lhs) or lhs if llhs is
|
|
/// null. llhsOpLoc is the location of the llhsOp token that will be used to
|
|
/// report an error for non-conforming expressions.
|
|
AffineExpr AffineParser::parseAffineHighPrecOpExpr(AffineExpr llhs,
|
|
AffineHighPrecOp llhsOp,
|
|
SMLoc llhsOpLoc) {
|
|
AffineExpr lhs = parseAffineOperandExpr(llhs);
|
|
if (!lhs)
|
|
return nullptr;
|
|
|
|
// Found an LHS. Parse the remaining expression.
|
|
auto opLoc = parser.curToken.getLoc();
|
|
if (AffineHighPrecOp op = consumeIfHighPrecOp()) {
|
|
if (llhs) {
|
|
AffineExpr expr = getAffineBinaryOpExpr(llhsOp, llhs, lhs, opLoc);
|
|
if (!expr)
|
|
return nullptr;
|
|
return parseAffineHighPrecOpExpr(expr, op, opLoc);
|
|
}
|
|
// No LLHS, get RHS
|
|
return parseAffineHighPrecOpExpr(lhs, op, opLoc);
|
|
}
|
|
|
|
// This is the last operand in this expression.
|
|
if (llhs)
|
|
return getAffineBinaryOpExpr(llhsOp, llhs, lhs, llhsOpLoc);
|
|
|
|
// No llhs, 'lhs' itself is the expression.
|
|
return lhs;
|
|
}
|
|
|
|
/// Parse an affine expression inside parentheses.
|
|
///
|
|
/// affine-expr ::= `(` affine-expr `)`
|
|
AffineExpr AffineParser::parseParentheticalExpr() {
|
|
if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
|
|
return nullptr;
|
|
if (parser.curToken.is(Token::Kind::r_paren))
|
|
return ((void)parser.emitError("no expression inside parentheses"),
|
|
nullptr);
|
|
|
|
auto expr = parseAffineExpr();
|
|
if (!expr)
|
|
return nullptr;
|
|
if (failed(parser.parseToken(Token::Kind::r_paren, "expected ')'")))
|
|
return nullptr;
|
|
|
|
return expr;
|
|
}
|
|
|
|
/// Parse the negation expression.
|
|
///
|
|
/// affine-expr ::= `-` affine-expr
|
|
AffineExpr AffineParser::parseNegateExpression(AffineExpr lhs) {
|
|
if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")))
|
|
return nullptr;
|
|
|
|
AffineExpr operand = parseAffineOperandExpr(lhs);
|
|
// Since negation has the highest precedence of all ops (including high
|
|
// precedence ops) but lower than parentheses, we are only going to use
|
|
// parseAffineOperandExpr instead of parseAffineExpr here.
|
|
if (!operand)
|
|
// Extra error message although parseAffineOperandExpr would have
|
|
// complained. Leads to a better diagnostic.
|
|
return ((void)parser.emitError("missing operand of negation"), nullptr);
|
|
return (-1) * operand;
|
|
}
|
|
|
|
/// Tries the attribute-use hook first; if it yields nothing, parses the
/// current token as a bare id.
AffineExpr AffineParser::parseAttrUseOrBareIdExpr() {
  llvm::Optional<AffineExpr> attrUse = attrUseCallback();
  return attrUse ? *attrUse : parseBareIdExpr();
}
|
|
|
|
/// Parse a bare id that may appear in an affine expression.
|
|
///
|
|
/// affine-expr ::= bare-id
|
|
AffineExpr AffineParser::parseBareIdExpr() {
|
|
if (parser.curToken.isNot(Token::Kind::id))
|
|
return ((void)parser.emitError("expected id"), nullptr);
|
|
|
|
StringRef sRef = parser.curToken.getSpelling();
|
|
for (auto &list : {dims, symbols}) {
|
|
for (auto entry : list) {
|
|
if (entry.first == sRef) {
|
|
parser.consumeToken(Token::Kind::id);
|
|
return entry.second;
|
|
}
|
|
}
|
|
}
|
|
|
|
// Not found, check fallback path.
|
|
AffineExpr expr = bareIdFallback(sRef);
|
|
if (expr) {
|
|
parser.consumeToken(Token::Kind::id);
|
|
return expr;
|
|
}
|
|
|
|
return ((void)parser.emitError("use of undeclared id"), nullptr);
|
|
}
|
|
|
|
/// Parse a positive integral constant appearing in an affine expression.
|
|
///
|
|
/// affine-expr ::= integer-literal
|
|
AffineExpr AffineParser::parseIntegerExpr() {
|
|
auto val = parser.curToken.getUInt64IntegerValue();
|
|
if (!val.hasValue() || (int64_t)val.getValue() < 0)
|
|
return ((void)parser.emitError("constant too large for index"), nullptr);
|
|
|
|
parser.consumeToken(Token::Kind::integer);
|
|
return getAffineConstantExpr((int64_t)val.getValue(), parser.context);
|
|
}
|
|
|
|
/// Parses an expression that can be a valid operand of an affine expression.
|
|
/// lhs: if non-null, lhs is an affine expression that is the lhs of a binary
|
|
/// operator, the rhs of which is being parsed. This is used to determine
|
|
/// whether an error should be emitted for a missing right operand.
|
|
// Eg: for an expression without parentheses (like i + j + k + l), each
|
|
// of the four identifiers is an operand. For i + j*k + l, j*k is not an
|
|
// operand expression, it's an op expression and will be parsed via
|
|
// parseAffineHighPrecOpExpression(). However, for i + (j*k) + -l, (j*k) and
|
|
// -l are valid operands that will be parsed by this function.
|
|
AffineExpr AffineParser::parseAffineOperandExpr(AffineExpr lhs) {
|
|
switch (parser.curToken.getKind()) {
|
|
case Token::Kind::id:
|
|
return parseAttrUseOrBareIdExpr();
|
|
case Token::Kind::integer:
|
|
return parseIntegerExpr();
|
|
case Token::Kind::l_paren:
|
|
return parseParentheticalExpr();
|
|
case Token::Kind::minus:
|
|
return parseNegateExpression(lhs);
|
|
case Token::Kind::kw_ceildiv:
|
|
case Token::Kind::kw_floordiv:
|
|
case Token::Kind::kw_mod:
|
|
case Token::Kind::plus:
|
|
case Token::Kind::star:
|
|
if (lhs)
|
|
(void)parser.emitError("missing right operand of binary operator");
|
|
else
|
|
(void)parser.emitError("missing left operand of binary operator");
|
|
return nullptr;
|
|
default:
|
|
if (lhs)
|
|
(void)parser.emitError("missing right operand of binary operator");
|
|
else
|
|
(void)parser.emitError("expected affine expression");
|
|
return nullptr;
|
|
}
|
|
}
|
|
|
|
/// Parse affine expressions that are bare-id's, integer constants,
|
|
/// parenthetical affine expressions, and affine op expressions that are a
|
|
/// composition of those.
|
|
///
|
|
/// All binary op's associate from left to right.
|
|
///
|
|
/// {add, sub} have lower precedence than {mul, div, and mod}.
|
|
///
|
|
/// Add, sub'are themselves at the same precedence level. Mul, floordiv,
|
|
/// ceildiv, and mod are at the same higher precedence level. Negation has
|
|
/// higher precedence than any binary op.
|
|
///
|
|
/// llhs: the affine expression appearing on the left of the one being parsed.
|
|
/// This function will return ((llhs llhsOp lhs) op rhs) if llhs is non null,
|
|
/// and lhs op rhs otherwise; if there is no rhs, llhs llhsOp lhs is returned
|
|
/// if llhs is non-null; otherwise lhs is returned. This is to deal with left
|
|
/// associativity.
|
|
///
|
|
/// Eg: when the expression is e1 + e2*e3 + e4, with e1 as llhs, this function
|
|
/// will return the affine expr equivalent of (e1 + (e2*e3)) + e4, where
|
|
/// (e2*e3) will be parsed using parseAffineHighPrecOpExpr().
|
|
AffineExpr AffineParser::parseAffineLowPrecOpExpr(AffineExpr llhs,
|
|
AffineLowPrecOp llhsOp) {
|
|
AffineExpr lhs;
|
|
if (!(lhs = parseAffineOperandExpr(llhs)))
|
|
return nullptr;
|
|
|
|
// Found an LHS. Deal with the ops.
|
|
if (AffineLowPrecOp lOp = consumeIfLowPrecOp()) {
|
|
if (llhs) {
|
|
AffineExpr sum = getAffineBinaryOpExpr(llhsOp, llhs, lhs);
|
|
return parseAffineLowPrecOpExpr(sum, lOp);
|
|
}
|
|
// No LLHS, get RHS and form the expression.
|
|
return parseAffineLowPrecOpExpr(lhs, lOp);
|
|
}
|
|
auto opLoc = parser.curToken.getLoc();
|
|
if (AffineHighPrecOp hOp = consumeIfHighPrecOp()) {
|
|
// We have a higher precedence op here. Get the rhs operand for the llhs
|
|
// through parseAffineHighPrecOpExpr.
|
|
AffineExpr highRes = parseAffineHighPrecOpExpr(lhs, hOp, opLoc);
|
|
if (!highRes)
|
|
return nullptr;
|
|
|
|
// If llhs is null, the product forms the first operand of the yet to be
|
|
// found expression. If non-null, the op to associate with llhs is llhsOp.
|
|
AffineExpr expr =
|
|
llhs ? getAffineBinaryOpExpr(llhsOp, llhs, highRes) : highRes;
|
|
|
|
// Recurse for subsequent low prec op's after the affine high prec op
|
|
// expression.
|
|
if (AffineLowPrecOp nextOp = consumeIfLowPrecOp())
|
|
return parseAffineLowPrecOpExpr(expr, nextOp);
|
|
return expr;
|
|
}
|
|
// Last operand in the expression list.
|
|
if (llhs)
|
|
return getAffineBinaryOpExpr(llhsOp, llhs, lhs);
|
|
// No llhs, 'lhs' itself is the expression.
|
|
return lhs;
|
|
}
|
|
|
|
/// Parse an affine expression.
|
|
/// affine-expr ::= `(` affine-expr `)`
|
|
/// | `-` affine-expr
|
|
/// | affine-expr `+` affine-expr
|
|
/// | affine-expr `-` affine-expr
|
|
/// | affine-expr `*` affine-expr
|
|
/// | affine-expr `floordiv` affine-expr
|
|
/// | affine-expr `ceildiv` affine-expr
|
|
/// | affine-expr `mod` affine-expr
|
|
/// | bare-id
|
|
/// | integer-literal
|
|
///
|
|
/// Additional conditions are checked depending on the production. For eg.,
|
|
/// one of the operands for `*` has to be either constant/symbolic; the second
|
|
/// operand for floordiv, ceildiv, and mod has to be a positive integer.
|
|
AffineExpr AffineParser::parseAffineExpr() {
|
|
return parseAffineLowPrecOpExpr(nullptr, AffineLowPrecOp::LNoOp);
|
|
}
|
|
|
|
/// Parses a possibly empty, comma-separated list of affine expressions
/// enclosed by `lDelim` and `rDelim`.
///
/// Returns the parsed expressions. If the opening delimiter or any element
/// fails to parse, an empty vector is returned; the result never contains
/// null expressions.
SmallVector<AffineExpr, 4> AffineParser::parseAffineExprs(Token::Kind lDelim,
                                                          Token::Kind rDelim) {
  if (failed(parser.parseToken(lDelim,
                               "expected lDelim at start of affine expr list")))
    return {};

  SmallVector<AffineExpr, 4> exprs;
  auto parseElt = [&]() -> LogicalResult {
    AffineExpr elt = parseAffineExpr();
    if (!elt)
      return failure();
    exprs.push_back(elt);
    return success();
  };

  // A failure here is caused by malformed user input, so it must be reported
  // to the caller rather than treated as unreachable (llvm_unreachable is
  // undefined behavior in release builds). Mirror the lDelim failure above.
  if (failed(parser.parseCommaSeparatedListUntil(rDelim, parseElt,
                                                 /*allowEmptyList=*/true)))
    return {};

  return exprs;
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TC parsing.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
namespace {
|
|
|
|
/// Common base of the nodes built while parsing a TC comprehension. The
/// `kind` tag is what the subclasses' `classof` hooks inspect.
struct Expression {
  enum class Kind {
    Uninitialized = 0,
    TensorExpr,
    TensorUse,
  };

  explicit Expression(Kind k = Kind::Uninitialized) : kind(k) {}
  virtual ~Expression() = default;

  /// True once the node carries a concrete kind.
  /// NOTE(review): Expression defines no comparison operators and this
  /// conversion is implicit, so `a != b` on two Expressions compiles and only
  /// compares these booleans.
  operator bool() const { return !(kind == Kind::Uninitialized); }

  Kind kind;
};
|
|
|
|
/// Encodes a tensor use of the form:
|
|
///
|
|
/// affine-expr-list ::= affine-expr (`,` affine-expr)*
|
|
/// tensor-use ::= bare-id `(` `)`
|
|
/// | bare-id `(` affine-expr-list `)`
|
|
///
|
|
/// The affine-expr-list is stored as an AffineMap.
|
|
struct TensorUse : public Expression {
|
|
TensorUse() : TensorUse("", AffineMap()) {}
|
|
TensorUse(StringRef name, AffineMap map)
|
|
: Expression(Kind::TensorUse), tensorId(name), indexingMap(map) {}
|
|
|
|
static bool classof(const Expression *e) {
|
|
return e->kind == Kind::TensorUse;
|
|
}
|
|
|
|
bool operator==(const TensorUse &other) const {
|
|
return tensorId == other.tensorId && indexingMap == other.indexingMap;
|
|
}
|
|
|
|
/// Visitation function. Performs preorder or postorder traversal depending on
|
|
/// `PreOrder` and applies `callback` on each node.
|
|
template <typename Lambda, bool PreOrder> void visit(Lambda callback) const;
|
|
|
|
StringRef tensorId;
|
|
AffineMap indexingMap;
|
|
};
|
|
|
|
/// Encodes a tensor expression of the form:
|
|
///
|
|
/// op-spec ::= bare-id `<` reduction-dims-list `>`
|
|
/// | bare-id
|
|
/// op-arg ::= tensor-expr
|
|
/// | tensor-use
|
|
/// op-arg-list ::= op-arg (`,` op-arg)*
|
|
/// tensor-expr ::= op-spec `(` op-arg-list `)`
|
|
///
|
|
/// Underlying op-arg are stored by unique_ptr to base class.
|
|
struct TensorExpr : public Expression {
|
|
TensorExpr(StringRef name,
|
|
SmallVectorImpl<std::unique_ptr<Expression>> &&exprs,
|
|
ArrayRef<unsigned> reductionDims)
|
|
: Expression(Kind::TensorExpr), operationName(name),
|
|
expressions(std::move(exprs)),
|
|
reductionDimensions(reductionDims.begin(), reductionDims.end()) {}
|
|
|
|
static bool classof(const Expression *e) {
|
|
return e->kind == Kind::TensorExpr;
|
|
}
|
|
|
|
bool operator==(const TensorExpr &other) const {
|
|
if (operationName != other.operationName)
|
|
return false;
|
|
if (expressions.size() != other.expressions.size())
|
|
return false;
|
|
for (unsigned i = 0, e = expressions.size(); i < e; ++i)
|
|
if (*expressions[i] != *other.expressions[i])
|
|
return false;
|
|
for (unsigned i = 0, e = reductionDimensions.size(); i < e; ++i)
|
|
if (reductionDimensions[i] != other.reductionDimensions[i])
|
|
return false;
|
|
return true;
|
|
}
|
|
|
|
/// Visitation function. Performs preorder or postorder traversal depending on
|
|
/// `PreOrder` and applies `callback` on each node.
|
|
template <typename Lambda, bool PreOrder> void visit(Lambda callback) const;
|
|
|
|
StringRef operationName;
|
|
SmallVector<std::unique_ptr<Expression>, 4> expressions;
|
|
SetVector<unsigned> reductionDimensions;
|
|
};
|
|
|
|
/// This is a specialized parser for a TCDef.
|
|
/// This maintains the dims it finds in an eager fashion.
|
|
class TCParser {
|
|
enum class EagerDiscoveryMode { None = 0, Symbols, Dimensions };
|
|
|
|
public:
|
|
explicit TCParser(Parser &p);
|
|
|
|
/// Uses the AffineParser to parse the affine exprs used in a tensor
|
|
/// definition. If `discoveryMode` is set to Symbols (resp. Dimensions), new
|
|
/// symbols (resp. dimensions) are added eagerly. Otherwise, an error is
|
|
/// emitted on new identifiers.
|
|
SmallVector<AffineExpr, 4>
|
|
parseAffineExprs(EagerDiscoveryMode discoveryMode, AffineDimList &dims,
|
|
Token::Kind lDelim = Token::Kind::l_paren,
|
|
Token::Kind rDelim = Token::Kind::r_paren);
|
|
|
|
/// Parse the information for a tensor def.
|
|
/// All the affine-expr must be dimensionless (i.e. contain only expressions
|
|
/// involving symbols and constants), but can otherwise contain arbitrary
|
|
/// affine expressions.
|
|
LogicalResult parseTensorDef(bool isOutput);
|
|
|
|
/// Parses a tensor use.
|
|
struct ComprehensionParsingState {
|
|
/// The number of operands (which includes inputs and outputs) in a
|
|
/// comprehension.
|
|
size_t numArgs;
|
|
AffineDimList dims;
|
|
SmallVector<std::unique_ptr<Expression>, 4> expressions;
|
|
llvm::DenseMap<TensorUse, unsigned> orderedTensorArgs;
|
|
};
|
|
LogicalResult parseTensorUse(TensorUse &result,
|
|
ComprehensionParsingState &state);
|
|
|
|
/// Parses an attribute definition.
|
|
LogicalResult parseAttrDef();
|
|
|
|
/// Parses an optional attribute use.
|
|
LogicalResult parseAttrUse(AttrUse &result);
|
|
|
|
/// Parses a tensor expression.
|
|
LogicalResult parseExpression(TensorUse currentDefinition,
|
|
std::unique_ptr<Expression> &result,
|
|
ComprehensionParsingState &state);
|
|
|
|
/// Parse a single comprehension.
|
|
LogicalResult parseOneComprehension(StringRef cppOpName,
|
|
StringRef linalgOpName,
|
|
ComprehensionParsingState &state);
|
|
|
|
/// Parse and print the information for a TC def.
|
|
/// When `gen-ods-decl` is used, this prints the ODS declaration for the TC.
|
|
/// When `gen-impl` is used, this prints the C++ implementation for the extra
|
|
/// methods defined in ODS (`iterator_types`, `indexing_maps` and
|
|
/// `regionBuilder`).
|
|
LogicalResult parseAndEmitODSDef(llvm::raw_ostream &os);
|
|
|
|
/// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
|
|
void printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
|
StringRef linalgOpName, ArrayRef<StringRef> interfaces,
|
|
ComprehensionParsingState &state);
|
|
|
|
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
|
|
void printReferenceIterators(llvm::raw_ostream &os, StringRef cppOpName,
|
|
ComprehensionParsingState &state);
|
|
|
|
/// Print methods related to indexing map required attributes.
|
|
///
|
|
/// Specifically, this prints the definitions for the following methods:
|
|
/// bool hasDynamicIndexingMaps();
|
|
/// LogicalResult verifyIndexingMapRequiredAttributes();
|
|
void printIndexingMapRequiredAttrMethods(llvm::raw_ostream &os,
|
|
StringRef cppOpName,
|
|
ComprehensionParsingState &state);
|
|
|
|
/// Print the C++ StructuredOpsInterface impl of `indexing_maps`.
|
|
void printReferenceIndexingMaps(llvm::raw_ostream &os, StringRef cppOpName,
|
|
ComprehensionParsingState &state);
|
|
|
|
/// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
|
|
void printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
|
|
ComprehensionParsingState &state);
|
|
|
|
/// Print the C++ impl for named ops canonicalizers and folders.
|
|
void printCanonicalizersAndFolders(llvm::raw_ostream &os,
|
|
StringRef cppOpName);
|
|
|
|
private:
|
|
//===--------------------------------------------------------------------===//
|
|
// Internal bookkeeping of tensors.
|
|
//===--------------------------------------------------------------------===//
|
|
struct RegisteredTensor {
|
|
StringRef type;
|
|
AffineMap shape;
|
|
bool isOutput;
|
|
AffineMap indexingMap;
|
|
unsigned index;
|
|
};
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Internal bookkeeping of attributes.
|
|
//===--------------------------------------------------------------------===//
|
|
struct RegisteredAttr {
|
|
StringRef elementType;
|
|
SmallVector<uint64_t, 4> vectorDims;
|
|
bool isArray;
|
|
bool isOptional;
|
|
|
|
// Returns the function to get values at the given indices from this
|
|
// attribute.
|
|
llvm::Optional<std::string> getValueFn(ArrayRef<uint64_t> indices) const;
|
|
};
|
|
|
|
//===--------------------------------------------------------------------===//
|
|
// Per-TC def state.
|
|
//===--------------------------------------------------------------------===//
|
|
/// Symbols are per TC def.
|
|
AffineSymbolList symbols;
|
|
|
|
/// Attribute usages in all affine expressions.
|
|
SmallVector<AttrUse, 8> attrUses;
|
|
|
|
/// Tensors are per TC def.
|
|
llvm::StringMap<RegisteredTensor> registeredTensors;
|
|
unsigned nextRegisteredTensorIndex;
|
|
|
|
/// Attributes are per TC def.
|
|
std::map<std::string, RegisteredAttr> registeredAttrs;
|
|
|
|
/// A map from AttrUse to AffineExpr symbol.
|
|
llvm::StringMap<AffineExpr> registeredAttrUseToSymbol;
|
|
|
|
StringRef docString;
|
|
|
|
Parser &parser;
|
|
};
|
|
} // namespace
|
|
|
|
namespace llvm {
|
|
|
|
template <> struct DenseMapInfo<TensorUse> {
|
|
static TensorUse getEmptyKey() { return TensorUse("", AffineMap()); }
|
|
static TensorUse getTombstoneKey() {
|
|
return TensorUse(DenseMapInfo<StringRef>::getTombstoneKey(),
|
|
DenseMapInfo<AffineMap>::getTombstoneKey());
|
|
}
|
|
static unsigned getHashValue(const TensorUse &val) {
|
|
return ::llvm::hash_value(val.tensorId); // don't care about collisions.
|
|
}
|
|
static bool isEqual(const TensorUse &LHS, const TensorUse &RHS) {
|
|
return LHS == RHS;
|
|
}
|
|
};
|
|
|
|
} // namespace llvm
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Visitation functions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
template <typename Lambda, bool PreOrder>
|
|
void visit(const Expression &expr, Lambda callback) {
|
|
switch (expr.kind) {
|
|
default:
|
|
llvm_unreachable("Unexpected kind");
|
|
case Expression::Kind::TensorExpr:
|
|
static_cast<const TensorExpr &>(expr).visit<Lambda, PreOrder>(callback);
|
|
break;
|
|
case Expression::Kind::TensorUse:
|
|
static_cast<const TensorUse &>(expr).visit<Lambda, PreOrder>(callback);
|
|
break;
|
|
}
|
|
}
|
|
|
|
template <typename Lambda>
|
|
void visitPreorder(const Expression &expr, Lambda callback) {
|
|
visit<Lambda, false>(expr, callback);
|
|
}
|
|
|
|
template <typename Lambda>
|
|
void visitPostorder(Expression &expr, Lambda callback) {
|
|
visit<Lambda, true>(expr, callback);
|
|
}
|
|
|
|
template <typename Lambda, bool PreOrder>
|
|
void TensorExpr::visit(Lambda callback) const {
|
|
if (!PreOrder)
|
|
callback(*this);
|
|
for (auto &e : expressions)
|
|
::visit<Lambda, PreOrder>(*e, callback);
|
|
if (PreOrder)
|
|
callback(*this);
|
|
}
|
|
|
|
template <typename Lambda, bool PreOrder>
|
|
void TensorUse::visit(Lambda callback) const {
|
|
callback(*this);
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// TC parsing functions.
|
|
//===----------------------------------------------------------------------===//
|
|
/// Creates a parser for one TC def on top of `p`. The per-def tables start
/// default-constructed, and tensor indices are handed out starting at 0.
TCParser::TCParser(Parser &p) : nextRegisteredTensorIndex(0), parser(p) {}
|
|
|
|
/// Uses the AffineParser to parse the affine exprs used in a tensor
|
|
/// definition. All identifiers are interpreted as symbols, new symbols are
|
|
/// added eagerly.
|
|
SmallVector<AffineExpr, 4>
|
|
TCParser::parseAffineExprs(EagerDiscoveryMode discoveryMode,
|
|
AffineDimList &dims, Token::Kind lDelim,
|
|
Token::Kind rDelim) {
|
|
auto createAffineBareId = [&](StringRef sRef) {
|
|
AffineExpr expr;
|
|
if (discoveryMode == EagerDiscoveryMode::Symbols) {
|
|
expr = getAffineSymbolExpr(symbols.size(), parser.context);
|
|
symbols.emplace_back(sRef, expr);
|
|
} else if (discoveryMode == EagerDiscoveryMode::Dimensions) {
|
|
expr = getAffineDimExpr(dims.size(), parser.context);
|
|
dims.emplace_back(sRef, expr);
|
|
}
|
|
return expr;
|
|
};
|
|
|
|
auto tryToParseAttrUse = [&]() -> llvm::Optional<AffineExpr> {
|
|
if (!parser.curToken.is(Token::Kind::id))
|
|
return llvm::None;
|
|
|
|
StringRef attrName = parser.curToken.getSpelling();
|
|
auto it = registeredAttrs.find(attrName.str());
|
|
if (it == registeredAttrs.end())
|
|
return llvm::None;
|
|
|
|
AttrUse result;
|
|
if (failed(parseAttrUse(result)))
|
|
return llvm::None;
|
|
|
|
auto symbolIt = registeredAttrUseToSymbol.find(result.getKey());
|
|
if (symbolIt == registeredAttrUseToSymbol.end()) {
|
|
result.symbol = getAffineSymbolExpr(symbols.size(), parser.context);
|
|
symbols.emplace_back("<attr-use>", result.symbol);
|
|
registeredAttrUseToSymbol[result.getKey()] = result.symbol;
|
|
attrUses.push_back(result);
|
|
} else {
|
|
result.symbol = symbolIt->second;
|
|
}
|
|
|
|
return result.symbol;
|
|
};
|
|
|
|
AffineParser affineParser(parser, createAffineBareId, tryToParseAttrUse, dims,
|
|
symbols);
|
|
return affineParser.parseAffineExprs(lDelim, rDelim);
|
|
}
|
|
|
|
/// Parse the information for a tensor def of the form:
|
|
///
|
|
/// affine-expr-list ::= affine-expr (`,` affine-expr )*
|
|
/// tensor-typedef ::= type `(` `)`
|
|
/// | type `(` affine-expr-list `)`
|
|
/// tensor-def ::= bare-id `:` tensor-typedef
|
|
LogicalResult TCParser::parseTensorDef(bool isOutput) {
|
|
StringRef tensorId = parser.curToken.getSpelling();
|
|
if (failed(parser.parseToken(Token::Kind::id, "expected an id")) ||
|
|
failed(parser.parseToken(Token::Kind::colon, "expected colon")))
|
|
return failure();
|
|
|
|
StringRef tensorType = parser.curToken.getSpelling();
|
|
if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
|
|
return failure();
|
|
|
|
AffineDimList emptyDims;
|
|
auto exprs = parseAffineExprs(EagerDiscoveryMode::Symbols, emptyDims);
|
|
assert(emptyDims.empty() && "Unexpected dimension in tensor def");
|
|
AffineMap map =
|
|
AffineMap::get(/*dimCount=*/0, symbols.size(), exprs, parser.context);
|
|
|
|
auto iterBoolPair = registeredTensors.try_emplace(
|
|
tensorId, RegisteredTensor{tensorType, map, isOutput, AffineMap(),
|
|
nextRegisteredTensorIndex++});
|
|
(void)iterBoolPair;
|
|
assert(iterBoolPair.second && "Could not emplace tensor registration");
|
|
LLVM_DEBUG(llvm::dbgs() << "Recorded: " << tensorId << " "
|
|
<< "with typeString: " << tensorType << " "
|
|
<< "and shape: " << map << "\n");
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Parses a tensor use of the form:
|
|
///
|
|
/// affine-expr-list ::= affine-expr (`,` affine-expr)*
|
|
/// tensor-use ::= bare-id `(` `)`
|
|
/// | bare-id `(` affine-expr-list `)`
|
|
LogicalResult TCParser::parseTensorUse(TensorUse &result,
|
|
ComprehensionParsingState &state) {
|
|
StringRef tensorId = parser.curToken.getSpelling();
|
|
if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
|
|
return failure();
|
|
|
|
auto exprs = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims);
|
|
AffineMap map =
|
|
AffineMap::get(state.dims.size(), symbols.size(), exprs, parser.context);
|
|
LLVM_DEBUG(llvm::dbgs() << "Use of tensor: " << tensorId << " map: " << map
|
|
<< "\n");
|
|
|
|
result = TensorUse(tensorId, map);
|
|
return success();
|
|
}
|
|
|
|
/// Parse the information for an attribute def of the form:
|
|
///
|
|
/// affine-expr-list ::= affine-expr (`,` affine-expr )*
|
|
/// attr-id ::= bare-id (`?`)?
|
|
/// dim-list ::= (integer-literal 'x')+
|
|
/// attr-typedef ::= dim-list? type (`[` `]`)?
|
|
/// attr-def ::= attr-id `:` attr-typedef
|
|
LogicalResult TCParser::parseAttrDef() {
|
|
auto attrLoc = parser.curToken.getLoc();
|
|
StringRef attrName = parser.curToken.getSpelling();
|
|
if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
|
|
return failure();
|
|
bool isOptional = succeeded(parser.parseOptionalToken(Token::Kind::question));
|
|
if (failed(parser.parseToken(Token::Kind::colon, "expected colon")))
|
|
return failure();
|
|
|
|
// Parse the attribute's type. We don't expect the type to be arbitrary
|
|
// complex, so just use this ad-hoc handling here.
|
|
|
|
// Parse potential dimension list
|
|
SmallVector<uint64_t, 4> vectorDims;
|
|
while (parser.curToken.is(Token::Kind::integer)) {
|
|
uint64_t value;
|
|
if (failed(parser.parseInteger(value)))
|
|
return failure();
|
|
vectorDims.push_back(value);
|
|
|
|
StringRef spelling = parser.curToken.getSpelling();
|
|
if (spelling[0] != 'x')
|
|
return parser.emitError(parser.curToken.getLoc(),
|
|
"expected 'x' in dimension list");
|
|
|
|
// If we had a prefix of 'x', lex the next token immediately after the 'x'.
|
|
if (spelling.size() != 1)
|
|
parser.lexer.resetPointer(spelling.data() + 1);
|
|
|
|
parser.consumeToken();
|
|
}
|
|
|
|
StringRef elementType = parser.curToken.getSpelling();
|
|
if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
|
|
return failure();
|
|
|
|
bool isArray = false;
|
|
auto arrayLoc = parser.curToken.getLoc();
|
|
if (succeeded(parser.parseOptionalToken(Token::Kind::l_square))) {
|
|
isArray = true;
|
|
if (failed(parser.parseToken(Token::Kind::r_square, "expected ']'")))
|
|
return failure();
|
|
}
|
|
|
|
if (!vectorDims.empty() && isArray)
|
|
return parser.emitError(arrayLoc, "unsupported vector array attribute");
|
|
|
|
auto iterBoolPair = registeredAttrs.emplace(
|
|
attrName.str(),
|
|
RegisteredAttr{elementType, vectorDims, isArray, isOptional});
|
|
if (!iterBoolPair.second)
|
|
return parser.emitError(attrLoc,
|
|
"Failed to register attribute '" + attrName + "'");
|
|
|
|
LLVM_DEBUG(llvm::dbgs() << "Recorded: " << (isOptional ? "[optional]" : "")
|
|
<< " " << attrName << " "
|
|
<< "with type: " << elementType
|
|
<< (isArray ? "[]" : "") << "\n");
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Parses a use of a registered attribute into `result`. Scalar attributes take
/// just the name. Vector and array attributes require a `[i, j, ...]` index
/// list whose length must equal the attribute's rank: 1 for arrays, otherwise
/// the number of vector dimensions.
LogicalResult TCParser::parseAttrUse(AttrUse &result) {
  result.attrName = parser.curToken.getSpelling();
  if (failed(parser.parseToken(Token::Kind::id, "expected an id")))
    return failure();

  auto attrIt = registeredAttrs.find(result.attrName.str());
  assert(attrIt != registeredAttrs.end());
  const RegisteredAttr &attr = attrIt->second;

  // Scalar attributes carry no indices.
  bool isIndexed = attr.isArray || !attr.vectorDims.empty();
  if (!isIndexed)
    return success();

  auto indexLoc = parser.curToken.getLoc();
  if (failed(parser.parseToken(Token::Kind::l_square, "expected '['")))
    return failure();

  auto parseIndex = [&]() -> LogicalResult {
    uint64_t index;
    if (failed(parser.parseInteger(index)))
      return failure();
    result.indices.push_back(index);
    return success();
  };
  if (failed(parser.parseCommaSeparatedListUntil(
          Token::Kind::r_square, parseIndex, /*allowEmptyList=*/false)))
    return failure();

  size_t expectedRank = attr.isArray ? 1 : attr.vectorDims.size();
  if (result.indices.size() == expectedRank)
    return success();
  return parser.emitError(indexLoc,
                          "number of indices mismatch: expected " +
                              std::to_string(expectedRank) + ", but found " +
                              std::to_string(result.indices.size()));
}
|
|
|
|
/// Parses a tensor expression of the form:
|
|
///
|
|
/// op-spec ::= bare-id `<` reduction-dims-list `>`
|
|
/// | bare-id
|
|
/// op-arg ::= tensor-expr
|
|
/// | tensor-use
|
|
/// op-arg-list ::= op-arg (`,` op-arg)*
|
|
/// tensor-expr ::= op-spec `(` op-arg-list `)`
|
|
LogicalResult TCParser::parseExpression(TensorUse currentDefinition,
|
|
std::unique_ptr<Expression> &result,
|
|
ComprehensionParsingState &state) {
|
|
StringRef opOrTensor = parser.curToken.getSpelling();
|
|
if (registeredTensors.count(opOrTensor) > 0) {
|
|
TensorUse use;
|
|
auto res = parseTensorUse(use, state);
|
|
if (failed(res))
|
|
return res;
|
|
result = std::make_unique<TensorUse>(use);
|
|
return success();
|
|
}
|
|
|
|
if (failed(parser.parseToken(Token::Kind::id, "expected an operation")))
|
|
return failure();
|
|
|
|
// This is an op.
|
|
SmallVector<unsigned, 4> reductionDims;
|
|
SmallVector<std::unique_ptr<Expression>, 4> expressions;
|
|
|
|
// Check if it has a reduction set, discover dimensions eagerly.
|
|
if (parser.curToken.is(Token::Kind::lt)) {
|
|
auto iters = parseAffineExprs(EagerDiscoveryMode::Dimensions, state.dims,
|
|
Token::Kind::lt, Token::Kind::gt);
|
|
for (auto iter : iters)
|
|
reductionDims.push_back(iter.cast<AffineDimExpr>().getPosition());
|
|
}
|
|
|
|
auto parseExpr = [&]() -> LogicalResult {
|
|
std::unique_ptr<Expression> e;
|
|
if (failed(parseExpression(currentDefinition, e, state)))
|
|
return failure();
|
|
expressions.push_back(std::move(e));
|
|
return success();
|
|
};
|
|
if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")) ||
|
|
failed(parser.parseCommaSeparatedListUntil(
|
|
Token::Kind::r_paren, parseExpr, /*allowEmptyList=*/true)))
|
|
return failure();
|
|
|
|
result = std::make_unique<TensorExpr>(opOrTensor, std::move(expressions),
|
|
reductionDims);
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Parse and Emit functions.
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Parse the information for a single comprehension.
|
|
///
|
|
/// tensor-def-list ::= tensor-def (`,` tensor-def)*
|
|
/// tensor-expr-list ::= tensor-expr (`,` tensor-expr)*
|
|
/// comprehension ::= tensor-def-list `=` tensor-expr-list `;`
|
|
LogicalResult
|
|
TCParser::parseOneComprehension(StringRef cppOpName, StringRef linalgOpName,
|
|
ComprehensionParsingState &state) {
|
|
// 1. Parse LHS of `=`, these become the definitions that appear as the output
|
|
// tensors or read/write buffers.
|
|
SmallVector<TensorUse, 4> definitions;
|
|
auto parseUse = [&]() -> LogicalResult {
|
|
TensorUse use;
|
|
if (failed(parseTensorUse(use, state)))
|
|
return failure();
|
|
definitions.push_back(use);
|
|
return success();
|
|
};
|
|
if (failed(parser.parseCommaSeparatedListUntil(Token::Kind::equal, parseUse,
|
|
/*allowEmptyList=*/true)))
|
|
return failure();
|
|
|
|
// 2. Parse RHS of `=`, this becomes the expressions from which we emit
|
|
// computations.
|
|
unsigned idx = 0;
|
|
auto parseExpr = [&]() -> LogicalResult {
|
|
std::unique_ptr<Expression> expr;
|
|
if (idx >= definitions.size())
|
|
return parser.emitError("Fewer LHS definitions than RHS expressions");
|
|
if (failed(parseExpression(definitions[idx++], expr, state)))
|
|
return failure();
|
|
state.expressions.push_back(std::move(expr));
|
|
return success();
|
|
};
|
|
if (failed(parser.parseCommaSeparatedListUntil(
|
|
Token::Kind::semicolon, parseExpr, /*allowEmptyList=*/true)))
|
|
return failure();
|
|
if (idx != definitions.size())
|
|
return parser.emitError("Fewer RHS expressions than LHS definitions");
|
|
|
|
// 3. Postprocess.
|
|
// 3.a. Normalize all maps to the proper state.dims and symbols counts.
|
|
SmallVector<TensorUse, 4> allUses;
|
|
allUses.reserve(registeredTensors.size());
|
|
for (auto &def : definitions)
|
|
allUses.push_back(def);
|
|
for (auto &pExpr : state.expressions)
|
|
visitPostorder(*pExpr, [&](const Expression &e) {
|
|
if (auto *use = dyn_cast<TensorUse>(&e))
|
|
allUses.push_back(*use);
|
|
});
|
|
for (auto &use : allUses)
|
|
use.indexingMap =
|
|
AffineMap::get(state.dims.size(), symbols.size(),
|
|
use.indexingMap.getResults(), parser.context);
|
|
|
|
// 3.b. Traverse definitions
|
|
llvm::DenseSet<StringRef> seenDefs;
|
|
for (auto &def : definitions) {
|
|
if (seenDefs.count(def.tensorId) > 0)
|
|
return parser.emitError("Unexpected multi-write to a single tensor");
|
|
seenDefs.insert(def.tensorId);
|
|
auto tensorIter = registeredTensors.find(def.tensorId);
|
|
assert(tensorIter != registeredTensors.end() && "unregistered tensor");
|
|
auto &tensor = tensorIter->getValue();
|
|
tensor.indexingMap = def.indexingMap;
|
|
state.orderedTensorArgs[def] = tensor.index;
|
|
}
|
|
|
|
bool failed = false;
|
|
for (auto &pExpr : state.expressions)
|
|
visitPostorder(*pExpr, [&](const Expression &e) {
|
|
auto *pUse = dyn_cast<TensorUse>(&e);
|
|
if (failed || !pUse)
|
|
return;
|
|
auto &use = *pUse;
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
<< "\nuse: " << use.tensorId << " map: " << use.indexingMap);
|
|
auto tensorIter = registeredTensors.find(use.tensorId);
|
|
assert(tensorIter != registeredTensors.end() && "unregistered tensor");
|
|
auto &tensor = tensorIter->getValue();
|
|
if (tensor.indexingMap && state.orderedTensorArgs.count(use) == 0 &&
|
|
tensor.indexingMap.getResults() != use.indexingMap.getResults()) {
|
|
LLVM_DEBUG(llvm::dbgs() << "\nexisting: " << tensor.indexingMap);
|
|
(void)parser.emitError(
|
|
"Unexpected multi-read of a tensor with different accesses");
|
|
failed = true;
|
|
return;
|
|
}
|
|
seenDefs.insert(use.tensorId);
|
|
tensor.indexingMap = use.indexingMap;
|
|
state.orderedTensorArgs[use] = tensor.index;
|
|
});
|
|
// If more than one definitions are less. They are shaped-only operand, which
|
|
// are used to define reduction loops. For now, only accept exactly one
|
|
// shaped-only operand.
|
|
if (state.numArgs > seenDefs.size() + 1) {
|
|
failed = true;
|
|
} else if (state.numArgs == seenDefs.size() + 1) {
|
|
for (auto &tensorIter : registeredTensors) {
|
|
auto &tensor = tensorIter.getValue();
|
|
if (tensor.indexingMap)
|
|
continue;
|
|
if (auto *pTensorExpr =
|
|
dyn_cast<TensorExpr>(state.expressions[0].get())) {
|
|
SmallVector<AffineExpr, 4> exprs;
|
|
for (auto dim : pTensorExpr->reductionDimensions)
|
|
exprs.push_back(getAffineDimExpr(dim, parser.context));
|
|
tensor.indexingMap = AffineMap::get(state.dims.size(), symbols.size(),
|
|
exprs, parser.context);
|
|
}
|
|
}
|
|
}
|
|
if (failed)
|
|
return failure();
|
|
|
|
return success();
|
|
}
|
|
|
|
/// Parse and print the information for a ODS def.
|
|
///
|
|
/// tensor-def-list ::= tensor-def (`,` tensor-def )*
|
|
/// attr-def-list ::= attr-def (`,` attr-def )*
|
|
///
|
|
/// comprehension-list ::= comprehension comprehension*
|
|
///
|
|
/// tc-attr-def ::= `attr` `(` attr-def-list `)`
|
|
/// tc-def ::= `def` bare-id `(`tensor-def-list`)` `->` `(` tensor-def-list`)`
|
|
/// (tc-attr-def)?
|
|
/// `{` comprehension-list `}`
|
|
///
|
|
/// implements-interface ::=
|
|
/// `implements_interface` `<` bare-id (`,` bare-id)* `>` `:` tc-def
|
|
///
|
|
/// ods-def ::= `ods_def` `<` bare-id `>`
|
|
/// (implements-interface)? `:`
|
|
/// tc-def
|
|
///
|
|
/// All the affine-expr in a `tensor-typedef` must be dimensionless (i.e.
|
|
/// contain only expressions involving symbols and constants), but can
|
|
/// otherwise contain arbitrary affine expressions.
|
|
LogicalResult TCParser::parseAndEmitODSDef(llvm::raw_ostream &os) {
|
|
// Parse ods-def header (including C++ op name)
|
|
if (failed(parser.parseToken(Token::Kind::kw_ods_def,
|
|
"expected 'ods_def' to define a TC ODS")) ||
|
|
failed(parser.parseToken(Token::Kind::lt, "expected '<'")))
|
|
return failure();
|
|
StringRef cppOpName = parser.curToken.getSpelling();
|
|
LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing ODS: " << cppOpName << "\n");
|
|
if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
|
|
failed(parser.parseToken(Token::Kind::gt, "expected '>'")))
|
|
return failure();
|
|
|
|
// Parse optional implements-interface header (including C++ op names)
|
|
SmallVector<StringRef> interfaces;
|
|
bool implementsInterface = succeeded(
|
|
parser.parseOptionalToken(Token::Kind::kw_implements_interface));
|
|
if (implementsInterface) {
|
|
auto parseInterfaceString = [&]() -> LogicalResult {
|
|
StringRef interfaceName = parser.curToken.getSpelling();
|
|
if (failed(parser.parseToken(Token::Kind::id, "expected id")))
|
|
return failure();
|
|
interfaces.push_back(interfaceName);
|
|
return success();
|
|
};
|
|
if (failed(parser.parseToken(Token::Kind::lt, "expected '<'")) ||
|
|
failed(parser.parseCommaSeparatedListUntil(
|
|
Token::Kind::gt, parseInterfaceString, /*allowEmptyList=*/false)))
|
|
return failure();
|
|
}
|
|
|
|
// Parse column.
|
|
if (failed(parser.parseToken(Token::Kind::colon, "expected ':'")))
|
|
return failure();
|
|
|
|
// Parse TC op name.
|
|
if (failed(parser.parseToken(Token::Kind::kw_def,
|
|
"expected 'def' to define a TC")))
|
|
return failure();
|
|
StringRef tcName = parser.curToken.getSpelling();
|
|
LLVM_DEBUG(llvm::dbgs() << "\n\nStart parsing TC: " << tcName << "\n");
|
|
|
|
// Parse input/output tensor definitions
|
|
if (failed(parser.parseToken(Token::Kind::id, "expected id")) ||
|
|
failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
|
|
return failure();
|
|
|
|
auto parseInputDef = [&]() -> LogicalResult {
|
|
return parseTensorDef(/*isOutput=*/false);
|
|
};
|
|
if (failed(parser.parseCommaSeparatedListUntil(
|
|
Token::Kind::r_paren, parseInputDef, /*allowEmptyList=*/false)))
|
|
return failure();
|
|
|
|
if (failed(parser.parseToken(Token::Kind::minus, "expected '-'")) ||
|
|
failed(parser.parseToken(Token::Kind::gt, "expected '>'")) ||
|
|
failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
|
|
return failure();
|
|
auto parseOutputDef = [&]() -> LogicalResult {
|
|
return parseTensorDef(/*isOutput=*/true);
|
|
};
|
|
if (failed(parser.parseCommaSeparatedListUntil(
|
|
Token::Kind::r_paren, parseOutputDef, /*allowEmptyList=*/false)))
|
|
return failure();
|
|
|
|
// Parse optional attribute definitions
|
|
if (succeeded(parser.parseOptionalToken(Token::Kind::kw_attr_def))) {
|
|
if (failed(parser.parseToken(Token::Kind::l_paren, "expected '('")))
|
|
return failure();
|
|
if (failed(parser.parseCommaSeparatedListUntil(
|
|
Token::Kind::r_paren, std::bind(&TCParser::parseAttrDef, this),
|
|
/*allowEmptyList=*/false)))
|
|
return failure();
|
|
}
|
|
|
|
// Parse optional doc string
|
|
if (parser.curToken.is(Token::Kind::doc_str)) {
|
|
docString = parser.curToken.getSpelling();
|
|
parser.consumeToken();
|
|
LLVM_DEBUG(llvm::dbgs()
|
|
<< "parsed doc string: '''" << docString << "'''\n");
|
|
}
|
|
|
|
// Since we don't declare symbols separately, we discover them eagerly: each
|
|
// newly encountered id in a tensor shape expression is treated as a new
|
|
// symbolic. At this point, all tensors have been parsed and all the symbols
|
|
// that could be discovered eagerly are now known. Resize all AffineMaps to
|
|
// normalize the number of eagerly discovered symbols.
|
|
for (auto &tensor : registeredTensors) {
|
|
auto &map = tensor.getValue().shape;
|
|
map = AffineMap::get(/*dimCount=*/0, symbols.size(), map.getResults(),
|
|
parser.context);
|
|
}
|
|
|
|
if (failed(parser.parseToken(Token::Kind::l_brace, "expected '{'")))
|
|
return failure();
|
|
|
|
SmallVector<ComprehensionParsingState, 4> perComprehensionStates;
|
|
while (parser.curToken.isNot(Token::Kind::r_brace)) {
|
|
perComprehensionStates.push_back(ComprehensionParsingState());
|
|
perComprehensionStates.back().numArgs = registeredTensors.size();
|
|
if (failed(parseOneComprehension(cppOpName, tcName,
|
|
perComprehensionStates.back())))
|
|
return failure();
|
|
};
|
|
if (failed(parser.parseToken(Token::Kind::r_brace, "expected '}'")))
|
|
return failure();
|
|
|
|
// Print.
|
|
auto nComprehensions = perComprehensionStates.size();
|
|
if (nComprehensions != 1)
|
|
return parser.emitError("only 1 comprehension supported for now, got: " +
|
|
llvm::Twine(nComprehensions));
|
|
if (genODSDecl) {
|
|
auto &state = perComprehensionStates.back();
|
|
printODS(os, cppOpName, tcName, interfaces, state);
|
|
os << "\n";
|
|
}
|
|
if (genODSImpl) {
|
|
auto &state = perComprehensionStates.back();
|
|
std::string extraMethods;
|
|
llvm::raw_string_ostream ss(extraMethods);
|
|
printReferenceIterators(ss, cppOpName, state);
|
|
printIndexingMapRequiredAttrMethods(ss, cppOpName, state);
|
|
printReferenceIndexingMaps(ss, cppOpName, state);
|
|
printRegionBuilder(ss, cppOpName, state);
|
|
printCanonicalizersAndFolders(ss, cppOpName);
|
|
ss.flush();
|
|
os << extraMethods << "\n";
|
|
}
|
|
|
|
return success();
|
|
}
|
|
|
|
//===----------------------------------------------------------------------===//
|
|
// Printing functions
|
|
//===----------------------------------------------------------------------===//
|
|
|
|
/// Print the ODS class that defines a new `cppOpName` for a `linalgOpName`.
|
|
void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
|
|
StringRef linalgOpName, ArrayRef<StringRef> interfaces,
|
|
ComprehensionParsingState &state) {
|
|
SmallVector<std::string, 4> attributes;
|
|
for (const auto &attr : registeredAttrs) {
|
|
llvm::StringRef name = attr.first;
|
|
|
|
llvm::StringRef elementType = attr.second.elementType;
|
|
std::string odsType = llvm::StringSwitch<std::string>(elementType)
|
|
.Case("f32", "F32")
|
|
.Case("i32", "I32")
|
|
.Case("i64", "I64")
|
|
.Default("");
|
|
if (odsType.empty()) {
|
|
(void)parser.emitError(
|
|
"unimplemented support for attribute element type: " + elementType);
|
|
return;
|
|
}
|
|
|
|
const auto &dims = attr.second.vectorDims;
|
|
if (!dims.empty()) {
|
|
// Vector case
|
|
SmallVector<std::string, 4> dimStrs;
|
|
for (uint64_t dim : dims)
|
|
dimStrs.push_back(std::to_string(dim));
|
|
odsType = llvm::formatv("Ranked{0}ElementsAttr<[{1}]>", odsType,
|
|
llvm::join(dimStrs, ", "));
|
|
} else if (attr.second.isArray) {
|
|
// Array case
|
|
odsType = llvm::formatv("{0}ArrayAttr", odsType);
|
|
} else {
|
|
// Scalar case
|
|
odsType = llvm::formatv("{0}Attr", odsType);
|
|
}
|
|
|
|
if (attr.second.isOptional)
|
|
odsType = llvm::formatv("OptionalAttr<{0}>", odsType);
|
|
|
|
attributes.push_back(llvm::formatv("{0}:${1}", odsType, name));
|
|
}
|
|
|
|
std::string attrList = llvm::join(attributes, ",\n");
|
|
if (!attrList.empty())
|
|
attrList = ",\n" + attrList;
|
|
|
|
// Template for Linalg named ops' ODS definitions. Parameters:
|
|
// {0}: ODS/C++ op name
|
|
// {1}: assembly op mnemonic
|
|
// {2}: op interface list
|
|
// {3}: documentation (summary + description)
|
|
// {4}: op attribute list
|
|
// {5}: the number of arguments for the op region
|
|
// {6}: builder methods taking standalone attribute parameters
|
|
// {7}: additional methods for attributes used by indexing maps
|
|
const char *header = R"FMT( def {0} : LinalgStructuredBase_Op<"{1}", [
|
|
AttrSizedOperandSegments,
|
|
DeclareOpInterfaceMethods<MemoryEffectsOpInterface>,
|
|
SingleBlockImplicitTerminator<"YieldOp">
|
|
/*extraInterfaces=*/{2}]> {
|
|
{3}
|
|
let arguments = (ins
|
|
Variadic<AnyShaped>:$inputs,
|
|
Variadic<AnyShaped>:$outputs{4}
|
|
);
|
|
let results = (outs Variadic<AnyRankedTensor>:$result_tensors);
|
|
let regions = (region AnyRegion:$region);
|
|
|
|
let skipDefaultBuilders = 1;
|
|
let builders = [
|
|
OpBuilder<
|
|
(ins "ValueRange":$inputs, "ValueRange":$outputs),
|
|
[{{
|
|
$_state.addOperands(inputs);
|
|
$_state.addOperands(outputs);
|
|
$_state.addAttribute(
|
|
"operand_segment_sizes",
|
|
$_builder.getI32VectorAttr({{
|
|
static_cast<int32_t>(inputs.size()),
|
|
static_cast<int32_t>(outputs.size())}));
|
|
createAndFillStructuredOpRegion<{0}>(
|
|
$_builder,
|
|
$_state,
|
|
TypeRange(inputs),
|
|
TypeRange(outputs));
|
|
}]>,
|
|
OpBuilder<
|
|
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
|
"ValueRange":$outputs),
|
|
[{{
|
|
$_state.addOperands(inputs);
|
|
$_state.addOperands(outputs);
|
|
$_state.addTypes(resultTensorTypes);
|
|
$_state.addAttribute(
|
|
"operand_segment_sizes",
|
|
$_builder.getI32VectorAttr({{
|
|
static_cast<int32_t>(inputs.size()),
|
|
static_cast<int32_t>(outputs.size())}));
|
|
createAndFillStructuredOpRegion<{0}>(
|
|
$_builder,
|
|
$_state,
|
|
TypeRange(inputs),
|
|
TypeRange(outputs));
|
|
}]>,
|
|
OpBuilder<
|
|
(ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
|
|
CArg<"ArrayRef<NamedAttribute>", "{{}">:$attributes),
|
|
[{{
|
|
$_state.addOperands(operands);
|
|
$_state.addAttributes(attributes);
|
|
$_state.addTypes(resultTensorTypes);
|
|
(void)$_state.addRegion();
|
|
}]>
|
|
{6}
|
|
];
|
|
let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
|
|
let parser = [{{
|
|
return ::parseNamedStructuredOp<{0}>(parser, result);
|
|
}];
|
|
let hasFolder = 1;
|
|
|
|
let extraClassDeclaration = structuredOpsBaseDecls # [{{
|
|
// Auto-generated.
|
|
ArrayAttr iterator_types();
|
|
ArrayAttr indexing_maps();
|
|
static void regionBuilder(ImplicitLocOpBuilder &b, Block &block);
|
|
static std::function<void(ImplicitLocOpBuilder &b, Block &)>
|
|
getRegionBuilder() {{
|
|
return regionBuilder;
|
|
}
|
|
|
|
// Generic methods.
|
|
static unsigned getNumRegionArgs() {{ return {5}; }
|
|
std::string getLibraryCallName() {{
|
|
return generateLibraryCallName(getOperation());
|
|
}
|
|
|
|
{7}
|
|
}];
|
|
})FMT";
|
|
|
|
// Generate the list of extra implemented interfaces.
|
|
std::string interfaceNameList;
|
|
if (!interfaces.empty()) {
|
|
llvm::raw_string_ostream ss(interfaceNameList);
|
|
ss << ", "; // Leading comma to concat to existing list of interfaces.
|
|
llvm::interleaveComma(interfaces, ss);
|
|
ss.flush();
|
|
}
|
|
|
|
// Generate documentation.
|
|
std::string doc;
|
|
if (!docString.empty()) {
|
|
const char *docFmt = R"FMT(
|
|
let summary = [{ {0} }];
|
|
let description = [{
|
|
{1}
|
|
}];
|
|
)FMT";
|
|
|
|
StringRef summary, description;
|
|
std::tie(summary, description) = docString.trim().split('\n');
|
|
doc = llvm::formatv(docFmt, summary.trim(), description.trim());
|
|
}
|
|
|
|
// Generate an additional builder that has parameters for attributes.
|
|
std::string attrBuilder;
|
|
if (!registeredAttrs.empty()) {
|
|
SmallVector<std::string, 4> attrParams, attrStmts;
|
|
for (const auto &attr : registeredAttrs) {
|
|
llvm::StringRef name = attr.first;
|
|
attrParams.push_back(llvm::formatv("\"Attribute\":${0}", name));
|
|
attrStmts.push_back(
|
|
llvm::formatv("$_state.addAttribute(\"{0}\", {0});", name));
|
|
}
|
|
std::string attrParamsList = llvm::join(attrParams, ", ");
|
|
std::string attrStmtsList = llvm::join(attrStmts, "\n");
|
|
|
|
const char *builderFmt = R"FMT(
|
|
, OpBuilder<
|
|
(ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
|
|
"ValueRange":$outputs, {1}),
|
|
[{{
|
|
$_state.addOperands(inputs);
|
|
$_state.addOperands(outputs);
|
|
$_state.addTypes(resultTensorTypes);
|
|
$_state.addAttribute(
|
|
"operand_segment_sizes",
|
|
$_builder.getI32VectorAttr({{
|
|
static_cast<int32_t>(inputs.size()),
|
|
static_cast<int32_t>(outputs.size())}));
|
|
createAndFillStructuredOpRegion<{0}>(
|
|
$_builder,
|
|
$_state,
|
|
TypeRange(inputs),
|
|
TypeRange(outputs));
|
|
{2}
|
|
}]>
|
|
)FMT";
|
|
attrBuilder =
|
|
llvm::formatv(builderFmt, cppOpName, attrParamsList, attrStmtsList);
|
|
}
|
|
|
|
std::string attrMethods;
|
|
if (!registeredAttrs.empty()) {
|
|
attrMethods = R"(
|
|
bool hasDynamicIndexingMaps();
|
|
LogicalResult verifyIndexingMapRequiredAttributes();
|
|
)";
|
|
}
|
|
|
|
// Finally put everything together.
|
|
os << llvm::formatv(header, cppOpName, linalgOpName, interfaceNameList, doc,
|
|
attrList, state.numArgs, attrBuilder, attrMethods);
|
|
}
|
|
|
|
/// Print the C++ StructuredOpsInterface impl of `iterator_types`.
|
|
void TCParser::printReferenceIterators(llvm::raw_ostream &os,
|
|
StringRef cppOpName,
|
|
ComprehensionParsingState &state) {
|
|
const char *referenceReferenceIteratorsFmt =
|
|
R"FMT(
|
|
ArrayAttr {0}::iterator_types() {
|
|
return Builder(getContext()).getStrArrayAttr(SmallVector<StringRef, 8>{{ {1} });
|
|
})FMT";
|
|
|
|
std::string iteratorsStr;
|
|
llvm::raw_string_ostream ss(iteratorsStr);
|
|
unsigned pos = 0;
|
|
llvm::interleaveComma(
|
|
state.dims, ss, [&](std::pair<StringRef, AffineExpr> p) {
|
|
bool reduction = false;
|
|
for (auto &expr : state.expressions) {
|
|
visitPostorder(*expr, [&](const Expression &e) {
|
|
if (auto *pTensorExpr = dyn_cast<TensorExpr>(&e)) {
|
|
if (pTensorExpr->reductionDimensions.count(pos) > 0)
|
|
reduction = true;
|
|
}
|
|
});
|
|
if (reduction)
|
|
break;
|
|
}
|
|
ss << (reduction ? "getReductionIteratorTypeName()"
|
|
: "getParallelIteratorTypeName()");
|
|
pos++;
|
|
});
|
|
ss.flush();
|
|
|
|
os << llvm::formatv(referenceReferenceIteratorsFmt, cppOpName, iteratorsStr);
|
|
}
|
|
|
|
/// Print the out-of-line `fold` and `getEffects` definitions for `cppOpName`.
/// The generated `fold` delegates to `foldMemRefCast`. The generated
/// `getEffects` forwards the op's results and its input/output buffer operands
/// to `getGenericEffectsImpl`.
void TCParser::printCanonicalizersAndFolders(llvm::raw_ostream &os,
                                             StringRef cppOpName) {
  static constexpr const char kFoldAndEffectsTemplate[] = R"FMT(
LogicalResult {0}::fold(ArrayRef<Attribute>,
SmallVectorImpl<OpFoldResult> &) {{
return foldMemRefCast(*this);
}
void {0}::getEffects(SmallVectorImpl<
SideEffects::EffectInstance<MemoryEffects::Effect> >&effects) {{
SmallVector<Value> inputBuffers = getInputBufferOperands();
SmallVector<Value> outputBuffers = getOutputBufferOperands();
getGenericEffectsImpl(effects,
getOperation()->getResults(), inputBuffers, outputBuffers);
})FMT";
  os << llvm::formatv(kFoldAndEffectsTemplate, cppOpName);
}
|
|
|
|
// Prints methods for querying whether the current named op has attributes that
|
|
// are used by its indexing maps and for verifying those attributes have the
|
|
// expected type.
|
|
void TCParser::printIndexingMapRequiredAttrMethods(
|
|
llvm::raw_ostream &os, StringRef cppOpName,
|
|
ComprehensionParsingState &state) {
|
|
// If there are no attribute used by the whole definition, then we are done.
|
|
if (registeredAttrs.empty())
|
|
return;
|
|
|
|
// Otherwise, go through each attribute and generate code to verify it's
|
|
// valid per the spec.
|
|
SmallVector<std::string, 4> attributes;
|
|
for (const auto &attr : registeredAttrs) {
|
|
if (attr.second.isOptional)
|
|
continue;
|
|
|
|
llvm::StringRef name = attr.first;
|
|
llvm::StringRef elementType = attr.second.elementType;
|
|
const auto &dims = attr.second.vectorDims;
|
|
|
|
// Get the method call to check the element type is of the expected kind.
|
|
std::string elemTypeCheck = llvm::StringSwitch<std::string>(elementType)
|
|
.Case("f32", "isF32()")
|
|
.Case("i32", "isInteger(32)")
|
|
.Case("i64", "isInteger(64)")
|
|
.Default("");
|
|
if (elemTypeCheck.empty()) {
|
|
(void)parser.emitError(
|
|
"unimplemented support for attribute element type: " + elementType);
|
|
return;
|
|
}
|
|
|
|
// Scalar case.
|
|
if (dims.empty() && !attr.second.isArray) {
|
|
const char *attrFmt = R"FMT(
|
|
if (auto attr = op->getAttr("{0}")) {{
|
|
if (!attr.getType().{1}) return op->emitError(
|
|
"incorrect type for indexing map required attribute '{0}'");
|
|
} else {{
|
|
return op->emitError(
|
|
"missing indexing map required attribute '{0}'");
|
|
}
|
|
)FMT";
|
|
|
|
attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck));
|
|
continue;
|
|
}
|
|
|
|
// Vector case.
|
|
if (!dims.empty()) {
|
|
SmallVector<std::string, 4> dimStrs;
|
|
for (uint64_t dim : dims)
|
|
dimStrs.push_back(std::to_string(dim));
|
|
|
|
const char *attrFmt = R"FMT(
|
|
if (auto attr = op->getAttrOfType<DenseElementsAttr>("{0}")) {{
|
|
if (!attr.getType().getElementType().{1}) return op->emitError(
|
|
"incorrect element type for indexing map required attribute '{0}'");
|
|
if (attr.getType().getShape() != ArrayRef<int64_t>{{ {2} })
|
|
return op->emitError(
|
|
"incorrect shape for indexing map required attribute '{0}'");
|
|
} else {
|
|
return op->emitError(
|
|
"missing indexing map required attribute '{0}'");
|
|
}
|
|
)FMT";
|
|
|
|
attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck,
|
|
llvm::join(dimStrs, ", ")));
|
|
continue;
|
|
}
|
|
|
|
// Array case.
|
|
{
|
|
const char *attrFmt = R"FMT(
|
|
if (auto attr = op->getAttrOfType<ArrayAttr>("{0}")) {{
|
|
for (Attribute element : attr) {{
|
|
if (!element.getType().{1}) return emitError(
|
|
"incorrect element type for indexing map required attribute '{0}'");
|
|
}
|
|
} else {{
|
|
return op->emitError(
|
|
"missing indexing map required attribute '{0}'");
|
|
}
|
|
)FMT";
|
|
|
|
attributes.push_back(llvm::formatv(attrFmt, name, elemTypeCheck));
|
|
}
|
|
}
|
|
|
|
const char *methodFmt = R"FMT(
|
|
bool {0}::hasDynamicIndexingMaps() {{ return true; }
|
|
|
|
LogicalResult {0}::verifyIndexingMapRequiredAttributes() {{
|
|
Operation *op = getOperation();
|
|
{1}
|
|
return success();
|
|
}
|
|
)FMT";
|
|
|
|
// Print everything out.
|
|
os << llvm::formatv(methodFmt, cppOpName, llvm::join(attributes, "\n"));
|
|
}
|
|
|
|
/// Print the C++ StructuredOpsInterface impl of `referenceIndexingMaps`.
|
|
void TCParser::printReferenceIndexingMaps(llvm::raw_ostream &os,
|
|
StringRef cppOpName,
|
|
ComprehensionParsingState &state) {
|
|
// 1. Generic string template for specifying reference indexing maps.
|
|
const char *referenceIndexingMapsFmt =
|
|
R"FMT(
|
|
// This is temporary until we transition out of manually specified ops that
|
|
// should be auto-generated with linalg-ods-gen.
|
|
ArrayAttr {0}::indexing_maps() {
|
|
MLIRContext *context = getContext();
|
|
AffineExpr {1};
|
|
bindDims(context, {1});
|
|
{2}
|
|
return Builder(context).getAffineMapArrayAttr({ {3} });
|
|
})FMT";
|
|
|
|
// 2. Print a comma-separated list of identifiers for the AffineExpr in
|
|
// `state.dims`. These will replace the `{1}` placeholder in both
|
|
// `AffineExpr {1}` and `bindDims(context, {1})` ensuring the AffineExpr
|
|
// identifiers are bound in the right order to the proper AffineDimExpr.
|
|
std::string dimsStr;
|
|
llvm::raw_string_ostream ss(dimsStr);
|
|
llvm::interleaveComma(
|
|
state.dims, ss,
|
|
[&](std::pair<StringRef, AffineExpr> p) { ss << p.second; });
|
|
ss.flush();
|
|
|
|
// 3. Get the list of affine maps for each input/output. The AffineExpr use
|
|
// the common arithmetic operators on AffineExpr. These affine maps will
|
|
// replace the `{2}` placeholder.
|
|
std::string mapsStr;
|
|
llvm::raw_string_ostream mapsStringStream(mapsStr);
|
|
|
|
// Create a list of all symbols.
|
|
SmallVector<std::string, 4> symbolReplacements;
|
|
symbolReplacements.reserve(symbols.size());
|
|
for (unsigned i = 0; i < symbols.size(); ++i) {
|
|
const char *symFmt =
|
|
"\n\tauto s{0} = getAffineSymbolExpr({0}, context); (void)s{0};";
|
|
mapsStringStream << llvm::formatv(symFmt, i);
|
|
symbolReplacements.push_back(llvm::formatv("s{0}", i));
|
|
}
|
|
|
|
// Create the affine constant expressions to replace symbols for attributes.
|
|
for (auto attrUse : llvm::enumerate(attrUses)) {
|
|
StringRef attrName = attrUse.value().attrName;
|
|
auto it = registeredAttrs.find(attrName.str());
|
|
assert(it != registeredAttrs.end() && "uses should point to valid attr!");
|
|
llvm::Optional<std::string> getValueFn =
|
|
it->second.getValueFn(attrUse.value().indices);
|
|
if (!getValueFn) {
|
|
(void)parser.emitError("unimplemented getValueFn for attribute: " +
|
|
attrName);
|
|
return;
|
|
}
|
|
std::string cstVal = llvm::formatv("{0}(){1}", attrName, *getValueFn);
|
|
const char *cstFmt =
|
|
"\n\tauto cst{0} = getAffineConstantExpr({1}, context);";
|
|
mapsStringStream << llvm::formatv(cstFmt, attrUse.index(), cstVal);
|
|
|
|
unsigned position =
|
|
attrUse.value().symbol.cast<AffineSymbolExpr>().getPosition();
|
|
symbolReplacements[position] = llvm::formatv("cst{0}", attrUse.index());
|
|
}
|
|
|
|
// For each registered tensor, construct the affine map, replace symbols by
|
|
// the corresponding attribute values, and simplify the affine map.
|
|
for (auto &tensorIter : registeredTensors) {
|
|
auto &tensor = tensorIter.getValue();
|
|
auto indexingMap = tensor.indexingMap;
|
|
const char *mapFmt =
|
|
"\n\tauto map{0} = AffineMap::get({1}, {2}, {3}, context);";
|
|
|
|
std::string exprsStr;
|
|
llvm::raw_string_ostream exprsStringStream(exprsStr);
|
|
exprsStringStream << "{";
|
|
llvm::interleaveComma(indexingMap.getResults(), exprsStringStream);
|
|
exprsStringStream << "}";
|
|
exprsStringStream.flush();
|
|
mapsStringStream << llvm::formatv(mapFmt, tensor.index, state.dims.size(),
|
|
indexingMap.getNumSymbols(), exprsStr);
|
|
|
|
std::string replaceSymbolList =
|
|
llvm::formatv("{ {0} }", llvm::join(symbolReplacements, ", "));
|
|
|
|
// Note that we use `0` as the result affine map's number of symbols. All
|
|
// symbols representing attribute usages should be folded away. But there
|
|
// may exist additional symbols for tensor dimension upper bounds. Linalg
|
|
// does not handle such cases right now. This needs to be fixed once we
|
|
// need that.
|
|
const char *replaceFmt =
|
|
"\n\tmap{0} = map{0}.replaceDimsAndSymbols({{}, {1}, {2}, 0);";
|
|
mapsStringStream << llvm::formatv(replaceFmt, tensor.index,
|
|
replaceSymbolList, state.dims.size());
|
|
const char *simplifyFmt = "\n\tmap{0} = simplifyAffineMap(map{0});";
|
|
mapsStringStream << llvm::formatv(simplifyFmt, tensor.index);
|
|
}
|
|
|
|
mapsStringStream.flush();
|
|
|
|
SmallVector<std::string, 4> mapList;
|
|
mapList.reserve(state.numArgs);
|
|
for (auto i : llvm::seq<unsigned>(0, state.numArgs))
|
|
mapList.push_back(llvm::formatv("map{0}", i));
|
|
|
|
// 4. Apply format to 1. using 2. and 3.
|
|
os << llvm::formatv(referenceIndexingMapsFmt, cppOpName, dimsStr, mapsStr,
|
|
llvm::join(mapList, ", "));
|
|
}
|
|
|
|
/// Print the C++ StructuredOpsInterface impl of `regionBuilder`.
|
|
void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
|
|
ComprehensionParsingState &state) {
|
|
unsigned count = state.numArgs;
|
|
llvm::DenseMap<const TensorExpr *, unsigned> subExprsMap;
|
|
std::function<void(llvm::raw_ostream & os, const Expression &)> printExpr;
|
|
printExpr = [&](llvm::raw_ostream &os, const Expression &e) -> void {
|
|
if (auto *pUse = dyn_cast<TensorUse>(&e)) {
|
|
os << "_" << state.orderedTensorArgs.find(*pUse)->second;
|
|
return;
|
|
}
|
|
auto *pTensorExpr = cast<TensorExpr>(&e);
|
|
if (subExprsMap.count(pTensorExpr) > 0) {
|
|
os << "_" << subExprsMap[pTensorExpr];
|
|
} else {
|
|
std::string subExprs;
|
|
llvm::raw_string_ostream subExprsStringStream(subExprs);
|
|
llvm::interleaveComma(pTensorExpr->expressions, subExprsStringStream,
|
|
[&](const std::unique_ptr<Expression> &e) {
|
|
printExpr(subExprsStringStream, *e);
|
|
});
|
|
subExprsStringStream.flush();
|
|
const char *tensorExprFmt = "\n Value _{0} = b.create<{1}>({2});";
|
|
os << llvm::formatv(tensorExprFmt, ++count, pTensorExpr->operationName,
|
|
subExprs);
|
|
subExprsMap[pTensorExpr] = count;
|
|
}
|
|
};
|
|
|
|
const char *regionBuilderFmt = R"FMT(
|
|
void {0}::regionBuilder(ImplicitLocOpBuilder &b, Block &block) {
|
|
auto args = block.getArguments();
|
|
Value {1};
|
|
{2}
|
|
b.create<linalg::YieldOp>(ValueRange{ {3} });
|
|
})FMT";
|
|
|
|
std::string valueHandleStr;
|
|
llvm::raw_string_ostream valueHandleStringStream(valueHandleStr);
|
|
std::set<unsigned> usedTensorId;
|
|
for (const auto &iter : state.orderedTensorArgs)
|
|
usedTensorId.insert(iter.second);
|
|
llvm::interleaveComma(usedTensorId, valueHandleStringStream, [&](auto idx) {
|
|
valueHandleStringStream << "_" << idx << "(args[" << idx << "])";
|
|
});
|
|
|
|
std::string expressionsStr;
|
|
llvm::raw_string_ostream expressionStringStream(expressionsStr);
|
|
for (auto &expr : state.expressions)
|
|
visitPostorder(*expr, [&](const Expression &e) {
|
|
if (e.kind == Expression::Kind::TensorExpr)
|
|
printExpr(expressionStringStream, e);
|
|
});
|
|
expressionStringStream.flush();
|
|
substituteOpAliases(expressionsStr);
|
|
|
|
std::string yieldStr;
|
|
llvm::raw_string_ostream yieldStringStream(yieldStr);
|
|
llvm::interleaveComma(state.expressions, yieldStringStream,
|
|
[&](const std::unique_ptr<Expression> &e) {
|
|
printExpr(yieldStringStream, *e);
|
|
});
|
|
|
|
valueHandleStringStream.flush();
|
|
yieldStringStream.flush();
|
|
|
|
os << llvm::formatv(regionBuilderFmt, cppOpName, valueHandleStr,
|
|
expressionsStr, yieldStr);
|
|
}
|
|
|
|
/// Return the C++ accessor suffix that is appended to `<attrName>()` to read
/// the attribute's value when building an AffineConstantExpr.
///
/// - Array attributes: unsupported; returns None.
/// - Scalar attributes: `.convertToFloat()` for f32, and an empty suffix for
///   i32/i64.
/// - Vector attributes: `.getValue<T>({ i0, i1, ... })` built from `indices`,
///   with T = float / int / int64_t for f32 / i32 / i64.
///
/// Any other element type yields None.
llvm::Optional<std::string>
TCParser::RegisteredAttr::getValueFn(ArrayRef<uint64_t> indices) const {
  if (isArray)
    return llvm::None;

  if (vectorDims.empty()) {
    // Scalar attribute.
    if (elementType == "f32")
      return std::string(".convertToFloat()");
    if (elementType == "i32" || elementType == "i64")
      return std::string();
    return llvm::None;
  }

  // Vector attribute: choose the C++ element type for the `.getValue<T>` call.
  StringRef cppElemType = llvm::StringSwitch<StringRef>(elementType)
                              .Case("f32", "float")
                              .Case("i32", "int")
                              .Case("i64", "int64_t")
                              .Default(StringRef());
  if (cppElemType.empty())
    return llvm::None;

  std::string indexList;
  llvm::raw_string_ostream indexStream(indexList);
  llvm::interleaveComma(indices, indexStream);
  return llvm::formatv(".getValue<{0}>({{ {1} })", cppElemType,
                       indexStream.str())
      .str();
}
|
|
|
|
/// Iterate over each Tensor Comprehension def.
|
|
LogicalResult parseAndEmitAllTensorComprehensions(llvm::raw_ostream &os,
|
|
Parser &parser) {
|
|
while (parser.curToken.getKind() != Token::Kind::eof) {
|
|
TCParser tcParser(parser);
|
|
if (failed(tcParser.parseAndEmitODSDef(os)))
|
|
return failure();
|
|
}
|
|
return success();
|
|
}
|
|
|
|
int main(int argc, char **argv) {
|
|
llvm::cl::ParseCommandLineOptions(argc, argv, "Linalg ODS Gen");
|
|
|
|
// Set up the input file.
|
|
std::string errorMessage;
|
|
std::unique_ptr<llvm::MemoryBuffer> file =
|
|
mlir::openInputFile(inputFilename, &errorMessage);
|
|
if (!file) {
|
|
llvm::errs() << errorMessage << "\n";
|
|
return 1;
|
|
}
|
|
|
|
std::unique_ptr<llvm::ToolOutputFile> output =
|
|
openOutputFile(outputFilename, &errorMessage);
|
|
if (!output) {
|
|
llvm::errs() << errorMessage << "\n";
|
|
exit(1);
|
|
}
|
|
|
|
// Include the proper Linalg header for end-to-end tblgen testing without
|
|
// resorting to non-portable shell manipulations.
|
|
if (testEmitIncludeTdHeader)
|
|
output->os() << "include \"mlir/Dialect/Linalg/IR/LinalgStructuredOps.td\"";
|
|
|
|
MLIRContext context;
|
|
llvm::SourceMgr mgr;
|
|
mgr.AddNewSourceBuffer(std::move(file), llvm::SMLoc());
|
|
Parser parser(mgr, &context);
|
|
(void)parseAndEmitAllTensorComprehensions(output->os(), parser);
|
|
output->keep();
|
|
|
|
return 0;
|
|
}
|