forked from OSchip/llvm-project
[PDLL] Add support for `op` Operation expressions
An operation expression in PDLL represents an MLIR operation. In the match section of a pattern, this expression models one of the input operations to the pattern. In the rewrite section of a pattern, this expression models one of the operations to create. The general structure of the operation expression is very similar to that of the "generic form" of textual MLIR assembly: ``` let root = op<my_dialect.foo>(operands: ValueRange) {attr = attr: Attr} -> (resultTypes: TypeRange); ``` For now we only model the components that are within PDL, as PDL gains support for blocks and regions so will this expression. Differential Revision: https://reviews.llvm.org/D115296
This commit is contained in:
parent
d7e7fdf3aa
commit
02670c3f38
|
@ -23,6 +23,7 @@ namespace ast {
|
|||
class Context;
|
||||
class Decl;
|
||||
class Expr;
|
||||
class NamedAttributeDecl;
|
||||
class OpNameDecl;
|
||||
class VariableDecl;
|
||||
|
||||
|
@ -342,6 +343,105 @@ private:
|
|||
StringRef memberName;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AllResultsMemberAccessExpr
|
||||
|
||||
/// This class represents an instance of MemberAccessExpr that references all
|
||||
/// results of an operation.
|
||||
class AllResultsMemberAccessExpr : public MemberAccessExpr {
|
||||
public:
|
||||
/// Return the member name used for the "all-results" access.
|
||||
static StringRef getMemberName() { return "$results"; }
|
||||
|
||||
static AllResultsMemberAccessExpr *create(Context &ctx, llvm::SMRange loc,
|
||||
const Expr *parentExpr, Type type) {
|
||||
return cast<AllResultsMemberAccessExpr>(
|
||||
MemberAccessExpr::create(ctx, loc, parentExpr, getMemberName(), type));
|
||||
}
|
||||
|
||||
/// Provide type casting support.
|
||||
static bool classof(const Node *node) {
|
||||
const MemberAccessExpr *memAccess = dyn_cast<MemberAccessExpr>(node);
|
||||
return memAccess && memAccess->getMemberName() == getMemberName();
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OperationExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This expression represents the structural form of an MLIR Operation. It
|
||||
/// represents either an input operation to match, or an operation to create
|
||||
/// within a rewrite.
|
||||
class OperationExpr final
|
||||
: public Node::NodeBase<OperationExpr, Expr>,
|
||||
private llvm::TrailingObjects<OperationExpr, Expr *,
|
||||
NamedAttributeDecl *> {
|
||||
public:
|
||||
static OperationExpr *create(Context &ctx, llvm::SMRange loc,
|
||||
const OpNameDecl *nameDecl,
|
||||
ArrayRef<Expr *> operands,
|
||||
ArrayRef<Expr *> resultTypes,
|
||||
ArrayRef<NamedAttributeDecl *> attributes);
|
||||
|
||||
/// Return the name of the operation, or None if there isn't one.
|
||||
Optional<StringRef> getName() const;
|
||||
|
||||
/// Return the declaration of the operation name.
|
||||
const OpNameDecl *getNameDecl() const { return nameDecl; }
|
||||
|
||||
/// Return the location of the name of the operation expression, or an invalid
|
||||
/// location if there isn't a name.
|
||||
llvm::SMRange getNameLoc() const { return nameLoc; }
|
||||
|
||||
/// Return the operands of this operation.
|
||||
MutableArrayRef<Expr *> getOperands() {
|
||||
return {getTrailingObjects<Expr *>(), numOperands};
|
||||
}
|
||||
ArrayRef<Expr *> getOperands() const {
|
||||
return const_cast<OperationExpr *>(this)->getOperands();
|
||||
}
|
||||
|
||||
/// Return the result types of this operation.
|
||||
MutableArrayRef<Expr *> getResultTypes() {
|
||||
return {getTrailingObjects<Expr *>() + numOperands, numResultTypes};
|
||||
}
|
||||
MutableArrayRef<Expr *> getResultTypes() const {
|
||||
return const_cast<OperationExpr *>(this)->getResultTypes();
|
||||
}
|
||||
|
||||
/// Return the attributes of this operation.
|
||||
MutableArrayRef<NamedAttributeDecl *> getAttributes() {
|
||||
return {getTrailingObjects<NamedAttributeDecl *>(), numAttributes};
|
||||
}
|
||||
MutableArrayRef<NamedAttributeDecl *> getAttributes() const {
|
||||
return const_cast<OperationExpr *>(this)->getAttributes();
|
||||
}
|
||||
|
||||
private:
|
||||
OperationExpr(llvm::SMRange loc, Type type, const OpNameDecl *nameDecl,
|
||||
unsigned numOperands, unsigned numResultTypes,
|
||||
unsigned numAttributes, llvm::SMRange nameLoc)
|
||||
: Base(loc, type), nameDecl(nameDecl), numOperands(numOperands),
|
||||
numResultTypes(numResultTypes), numAttributes(numAttributes),
|
||||
nameLoc(nameLoc) {}
|
||||
|
||||
/// The name decl of this expression.
|
||||
const OpNameDecl *nameDecl;
|
||||
|
||||
/// The number of operands, result types, and attributes of the operation.
|
||||
unsigned numOperands, numResultTypes, numAttributes;
|
||||
|
||||
/// The location of the operation name in the expression if it has a name.
|
||||
llvm::SMRange nameLoc;
|
||||
|
||||
/// TrailingObject utilities.
|
||||
friend llvm::TrailingObjects<OperationExpr, Expr *, NamedAttributeDecl *>;
|
||||
size_t numTrailingObjects(OverloadToken<Expr *>) const {
|
||||
return numOperands + numResultTypes;
|
||||
}
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -555,6 +655,31 @@ protected:
|
|||
Expr *typeExpr;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NamedAttributeDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// This Decl represents a NamedAttribute, and contains a string name and
|
||||
/// attribute value.
|
||||
class NamedAttributeDecl : public Node::NodeBase<NamedAttributeDecl, Decl> {
|
||||
public:
|
||||
static NamedAttributeDecl *create(Context &ctx, const Name &name,
|
||||
Expr *value);
|
||||
|
||||
/// Return the name of the attribute.
|
||||
const Name &getName() const { return *Decl::getName(); }
|
||||
|
||||
/// Return value of the attribute.
|
||||
Expr *getValue() const { return value; }
|
||||
|
||||
private:
|
||||
NamedAttributeDecl(const Name &name, Expr *value)
|
||||
: Base(name.getLoc(), &name), value(value) {}
|
||||
|
||||
/// The value of the attribute.
|
||||
Expr *value;
|
||||
};
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpNameDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -703,7 +828,8 @@ private:
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
inline bool Decl::classof(const Node *node) {
|
||||
return isa<ConstraintDecl, OpNameDecl, PatternDecl, VariableDecl>(node);
|
||||
return isa<ConstraintDecl, NamedAttributeDecl, OpNameDecl, PatternDecl,
|
||||
VariableDecl>(node);
|
||||
}
|
||||
|
||||
inline bool ConstraintDecl::classof(const Node *node) {
|
||||
|
@ -717,7 +843,8 @@ inline bool CoreConstraintDecl::classof(const Node *node) {
|
|||
}
|
||||
|
||||
inline bool Expr::classof(const Node *node) {
|
||||
return isa<AttributeExpr, DeclRefExpr, MemberAccessExpr, TypeExpr>(node);
|
||||
return isa<AttributeExpr, DeclRefExpr, MemberAccessExpr, OperationExpr,
|
||||
TypeExpr>(node);
|
||||
}
|
||||
|
||||
inline bool OpRewriteStmt::classof(const Node *node) {
|
||||
|
|
|
@ -79,6 +79,7 @@ private:
|
|||
void printImpl(const AttributeExpr *expr);
|
||||
void printImpl(const DeclRefExpr *expr);
|
||||
void printImpl(const MemberAccessExpr *expr);
|
||||
void printImpl(const OperationExpr *expr);
|
||||
void printImpl(const TypeExpr *expr);
|
||||
|
||||
void printImpl(const AttrConstraintDecl *decl);
|
||||
|
@ -87,6 +88,7 @@ private:
|
|||
void printImpl(const TypeRangeConstraintDecl *decl);
|
||||
void printImpl(const ValueConstraintDecl *decl);
|
||||
void printImpl(const ValueRangeConstraintDecl *decl);
|
||||
void printImpl(const NamedAttributeDecl *decl);
|
||||
void printImpl(const OpNameDecl *decl);
|
||||
void printImpl(const PatternDecl *decl);
|
||||
void printImpl(const VariableDecl *decl);
|
||||
|
@ -147,13 +149,14 @@ void NodePrinter::print(const Node *node) {
|
|||
|
||||
// Expressions.
|
||||
const AttributeExpr, const DeclRefExpr, const MemberAccessExpr,
|
||||
const TypeExpr,
|
||||
const OperationExpr, const TypeExpr,
|
||||
|
||||
// Decls.
|
||||
const AttrConstraintDecl, const OpConstraintDecl,
|
||||
const TypeConstraintDecl, const TypeRangeConstraintDecl,
|
||||
const ValueConstraintDecl, const ValueRangeConstraintDecl,
|
||||
const OpNameDecl, const PatternDecl, const VariableDecl,
|
||||
const NamedAttributeDecl, const OpNameDecl, const PatternDecl,
|
||||
const VariableDecl,
|
||||
|
||||
const Module>([&](auto derivedNode) { this->printImpl(derivedNode); })
|
||||
.Default([](const Node *) { llvm_unreachable("unknown AST node"); });
|
||||
|
@ -194,6 +197,17 @@ void NodePrinter::printImpl(const MemberAccessExpr *expr) {
|
|||
printChildren(expr->getParentExpr());
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const OperationExpr *expr) {
|
||||
os << "OperationExpr " << expr << " Type<";
|
||||
print(expr->getType());
|
||||
os << ">\n";
|
||||
|
||||
printChildren(expr->getNameDecl());
|
||||
printChildren("Operands", expr->getOperands());
|
||||
printChildren("Result Types", expr->getResultTypes());
|
||||
printChildren("Attributes", expr->getAttributes());
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const TypeExpr *expr) {
|
||||
os << "TypeExpr " << expr << " Value<\"" << expr->getValue() << "\">\n";
|
||||
}
|
||||
|
@ -229,6 +243,12 @@ void NodePrinter::printImpl(const ValueRangeConstraintDecl *decl) {
|
|||
printChildren(typeExpr);
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const NamedAttributeDecl *decl) {
|
||||
os << "NamedAttributeDecl " << decl << " Name<" << decl->getName().getName()
|
||||
<< ">\n";
|
||||
printChildren(decl->getValue());
|
||||
}
|
||||
|
||||
void NodePrinter::printImpl(const OpNameDecl *decl) {
|
||||
os << "OpNameDecl " << decl;
|
||||
if (Optional<StringRef> name = decl->getName())
|
||||
|
|
|
@ -116,6 +116,37 @@ MemberAccessExpr *MemberAccessExpr::create(Context &ctx, llvm::SMRange loc,
|
|||
loc, parentExpr, memberName.copy(ctx.getAllocator()), type);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OperationExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OperationExpr *OperationExpr::create(
|
||||
Context &ctx, llvm::SMRange loc, const OpNameDecl *name,
|
||||
ArrayRef<Expr *> operands, ArrayRef<Expr *> resultTypes,
|
||||
ArrayRef<NamedAttributeDecl *> attributes) {
|
||||
unsigned allocSize =
|
||||
OperationExpr::totalSizeToAlloc<Expr *, NamedAttributeDecl *>(
|
||||
operands.size() + resultTypes.size(), attributes.size());
|
||||
void *rawData =
|
||||
ctx.getAllocator().Allocate(allocSize, alignof(OperationExpr));
|
||||
|
||||
Type resultType = OperationType::get(ctx, name->getName());
|
||||
OperationExpr *opExpr = new (rawData)
|
||||
OperationExpr(loc, resultType, name, operands.size(), resultTypes.size(),
|
||||
attributes.size(), name->getLoc());
|
||||
std::uninitialized_copy(operands.begin(), operands.end(),
|
||||
opExpr->getOperands().begin());
|
||||
std::uninitialized_copy(resultTypes.begin(), resultTypes.end(),
|
||||
opExpr->getResultTypes().begin());
|
||||
std::uninitialized_copy(attributes.begin(), attributes.end(),
|
||||
opExpr->getAttributes().begin());
|
||||
return opExpr;
|
||||
}
|
||||
|
||||
Optional<StringRef> OperationExpr::getName() const {
|
||||
return getNameDecl()->getName();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -193,6 +224,16 @@ ValueRangeConstraintDecl *ValueRangeConstraintDecl::create(Context &ctx,
|
|||
ValueRangeConstraintDecl(loc, typeExpr);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NamedAttributeDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
NamedAttributeDecl *NamedAttributeDecl::create(Context &ctx, const Name &name,
|
||||
Expr *value) {
|
||||
return new (ctx.getAllocator().Allocate<NamedAttributeDecl>())
|
||||
NamedAttributeDecl(name, value);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OpNameDecl
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -80,6 +80,10 @@ private:
|
|||
ast::Expr *&expr, ast::Type type,
|
||||
function_ref<void(ast::Diagnostic &diag)> noteAttachFn = {});
|
||||
|
||||
/// Given an operation expression, convert it to a Value or ValueRange
|
||||
/// typed expression.
|
||||
ast::Expr *convertOpToValue(const ast::Expr *opExpr);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Directives
|
||||
|
||||
|
@ -96,6 +100,7 @@ private:
|
|||
};
|
||||
|
||||
FailureOr<ast::Decl *> parseTopLevelDecl();
|
||||
FailureOr<ast::NamedAttributeDecl *> parseNamedAttributeDecl();
|
||||
FailureOr<ast::Decl *> parsePatternDecl();
|
||||
LogicalResult parsePatternDeclMetadata(ParsedPatternMetadata &metadata);
|
||||
|
||||
|
@ -141,6 +146,7 @@ private:
|
|||
FailureOr<ast::Expr *> parseMemberAccessExpr(ast::Expr *parentExpr);
|
||||
FailureOr<ast::OpNameDecl *> parseOperationName();
|
||||
FailureOr<ast::OpNameDecl *> parseWrappedOperationName();
|
||||
FailureOr<ast::Expr *> parseOperationExpr();
|
||||
FailureOr<ast::Expr *> parseTypeExpr();
|
||||
FailureOr<ast::Expr *> parseUnderscoreExpr();
|
||||
|
||||
|
@ -205,6 +211,22 @@ private:
|
|||
/// success, this also returns the type of the member accessed.
|
||||
FailureOr<ast::Type> validateMemberAccess(ast::Expr *parentExpr,
|
||||
StringRef name, llvm::SMRange loc);
|
||||
FailureOr<ast::OperationExpr *>
|
||||
createOperationExpr(llvm::SMRange loc, const ast::OpNameDecl *name,
|
||||
MutableArrayRef<ast::Expr *> operands,
|
||||
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
|
||||
MutableArrayRef<ast::Expr *> results);
|
||||
LogicalResult
|
||||
validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name,
|
||||
MutableArrayRef<ast::Expr *> operands);
|
||||
LogicalResult validateOperationResults(llvm::SMRange loc,
|
||||
Optional<StringRef> name,
|
||||
MutableArrayRef<ast::Expr *> results);
|
||||
LogicalResult
|
||||
validateOperationOperandsOrResults(llvm::SMRange loc,
|
||||
Optional<StringRef> name,
|
||||
MutableArrayRef<ast::Expr *> values,
|
||||
ast::Type singleTy, ast::Type rangeTy);
|
||||
|
||||
//===--------------------------------------------------------------------===//
|
||||
// Stmts
|
||||
|
@ -322,6 +344,11 @@ LogicalResult Parser::parseModuleBody(SmallVector<ast::Decl *> &decls) {
|
|||
return success();
|
||||
}
|
||||
|
||||
ast::Expr *Parser::convertOpToValue(const ast::Expr *opExpr) {
|
||||
return ast::AllResultsMemberAccessExpr::create(ctx, opExpr->getLoc(), opExpr,
|
||||
valueRangeTy);
|
||||
}
|
||||
|
||||
LogicalResult Parser::convertExpressionTo(
|
||||
ast::Expr *&expr, ast::Type type,
|
||||
function_ref<void(ast::Diagnostic &diag)> noteAttachFn) {
|
||||
|
@ -351,15 +378,15 @@ LogicalResult Parser::convertExpressionTo(
|
|||
|
||||
// An operation can always convert to a ValueRange.
|
||||
if (type == valueRangeTy) {
|
||||
expr = ast::MemberAccessExpr::create(ctx, expr->getLoc(), expr,
|
||||
"$results", valueRangeTy);
|
||||
expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
|
||||
valueRangeTy);
|
||||
return success();
|
||||
}
|
||||
|
||||
// Allow conversion to a single value by constraining the result range.
|
||||
if (type == valueTy) {
|
||||
expr = ast::MemberAccessExpr::create(ctx, expr->getLoc(), expr,
|
||||
"$results", valueTy);
|
||||
expr = ast::AllResultsMemberAccessExpr::create(ctx, expr->getLoc(), expr,
|
||||
valueTy);
|
||||
return success();
|
||||
}
|
||||
return emitConvertError();
|
||||
|
@ -447,6 +474,33 @@ FailureOr<ast::Decl *> Parser::parseTopLevelDecl() {
|
|||
return decl;
|
||||
}
|
||||
|
||||
FailureOr<ast::NamedAttributeDecl *> Parser::parseNamedAttributeDecl() {
|
||||
std::string attrNameStr;
|
||||
if (curToken.isString())
|
||||
attrNameStr = curToken.getStringValue();
|
||||
else if (curToken.is(Token::identifier) || curToken.isKeyword())
|
||||
attrNameStr = curToken.getSpelling().str();
|
||||
else
|
||||
return emitError("expected identifier or string attribute name");
|
||||
const auto &name = ast::Name::create(ctx, attrNameStr, curToken.getLoc());
|
||||
consumeToken();
|
||||
|
||||
// Check for a value of the attribute.
|
||||
ast::Expr *attrValue = nullptr;
|
||||
if (consumeIf(Token::equal)) {
|
||||
FailureOr<ast::Expr *> attrExpr = parseExpr();
|
||||
if (failed(attrExpr))
|
||||
return failure();
|
||||
attrValue = *attrExpr;
|
||||
} else {
|
||||
// If there isn't a concrete value, create an expression representing a
|
||||
// UnitAttr.
|
||||
attrValue = ast::AttributeExpr::create(ctx, name.getLoc(), "unit");
|
||||
}
|
||||
|
||||
return ast::NamedAttributeDecl::create(ctx, name, attrValue);
|
||||
}
|
||||
|
||||
FailureOr<ast::Decl *> Parser::parsePatternDecl() {
|
||||
llvm::SMRange loc = curToken.getLoc();
|
||||
consumeToken(Token::kw_Pattern);
|
||||
|
@ -739,6 +793,9 @@ FailureOr<ast::Expr *> Parser::parseExpr() {
|
|||
case Token::identifier:
|
||||
lhsExpr = parseIdentifierExpr();
|
||||
break;
|
||||
case Token::kw_op:
|
||||
lhsExpr = parseOperationExpr();
|
||||
break;
|
||||
case Token::kw_type:
|
||||
lhsExpr = parseTypeExpr();
|
||||
break;
|
||||
|
@ -868,6 +925,77 @@ FailureOr<ast::OpNameDecl *> Parser::parseWrappedOperationName() {
|
|||
return opNameDecl;
|
||||
}
|
||||
|
||||
FailureOr<ast::Expr *> Parser::parseOperationExpr() {
|
||||
llvm::SMRange loc = curToken.getLoc();
|
||||
consumeToken(Token::kw_op);
|
||||
|
||||
// If it isn't followed by a `<`, the `op` keyword is treated as a normal
|
||||
// identifier.
|
||||
if (curToken.isNot(Token::less)) {
|
||||
resetToken(loc);
|
||||
return parseIdentifierExpr();
|
||||
}
|
||||
|
||||
// Parse the operation name. The name may be elided, in which case the
|
||||
// operation refers to "any" operation(i.e. a difference between `MyOp` and
|
||||
// `Operation*`).
|
||||
FailureOr<ast::OpNameDecl *> opNameDecl = parseWrappedOperationName();
|
||||
if (failed(opNameDecl))
|
||||
return failure();
|
||||
|
||||
// Check for the optional list of operands.
|
||||
SmallVector<ast::Expr *> operands;
|
||||
if (consumeIf(Token::l_paren)) {
|
||||
do {
|
||||
FailureOr<ast::Expr *> operand = parseExpr();
|
||||
if (failed(operand))
|
||||
return failure();
|
||||
operands.push_back(*operand);
|
||||
} while (consumeIf(Token::comma));
|
||||
|
||||
if (failed(parseToken(Token::r_paren,
|
||||
"expected `)` after operation operand list")))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Check for the optional list of attributes.
|
||||
SmallVector<ast::NamedAttributeDecl *> attributes;
|
||||
if (consumeIf(Token::l_brace)) {
|
||||
do {
|
||||
FailureOr<ast::NamedAttributeDecl *> decl = parseNamedAttributeDecl();
|
||||
if (failed(decl))
|
||||
return failure();
|
||||
attributes.emplace_back(*decl);
|
||||
} while (consumeIf(Token::comma));
|
||||
|
||||
if (failed(parseToken(Token::r_brace,
|
||||
"expected `}` after operation attribute list")))
|
||||
return failure();
|
||||
}
|
||||
|
||||
// Check for the optional list of result types.
|
||||
SmallVector<ast::Expr *> resultTypes;
|
||||
if (consumeIf(Token::arrow)) {
|
||||
if (failed(parseToken(Token::l_paren,
|
||||
"expected `(` before operation result type list")))
|
||||
return failure();
|
||||
|
||||
do {
|
||||
FailureOr<ast::Expr *> resultTypeExpr = parseExpr();
|
||||
if (failed(resultTypeExpr))
|
||||
return failure();
|
||||
resultTypes.push_back(*resultTypeExpr);
|
||||
} while (consumeIf(Token::comma));
|
||||
|
||||
if (failed(parseToken(Token::r_paren,
|
||||
"expected `)` after operation result type list")))
|
||||
return failure();
|
||||
}
|
||||
|
||||
return createOperationExpr(loc, *opNameDecl, operands, attributes,
|
||||
resultTypes);
|
||||
}
|
||||
|
||||
FailureOr<ast::Expr *> Parser::parseTypeExpr() {
|
||||
llvm::SMRange loc = curToken.getLoc();
|
||||
consumeToken(Token::kw_type);
|
||||
|
@ -1198,11 +1326,8 @@ FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
|
|||
StringRef name,
|
||||
llvm::SMRange loc) {
|
||||
ast::Type parentType = parentExpr->getType();
|
||||
if (ast::OperationType opType = parentType.dyn_cast<ast::OperationType>()) {
|
||||
// $results is a special member access representing all of the results.
|
||||
// TODO: Should we have special AST expressions for these? How does the
|
||||
// user reference these in the language itself?
|
||||
if (name == "$results")
|
||||
if (parentType.isa<ast::OperationType>()) {
|
||||
if (name == ast::AllResultsMemberAccessExpr::getMemberName())
|
||||
return valueRangeTy;
|
||||
}
|
||||
return emitError(
|
||||
|
@ -1211,6 +1336,89 @@ FailureOr<ast::Type> Parser::validateMemberAccess(ast::Expr *parentExpr,
|
|||
name, parentType));
|
||||
}
|
||||
|
||||
FailureOr<ast::OperationExpr *> Parser::createOperationExpr(
|
||||
llvm::SMRange loc, const ast::OpNameDecl *name,
|
||||
MutableArrayRef<ast::Expr *> operands,
|
||||
MutableArrayRef<ast::NamedAttributeDecl *> attributes,
|
||||
MutableArrayRef<ast::Expr *> results) {
|
||||
Optional<StringRef> opNameRef = name->getName();
|
||||
|
||||
// Verify the inputs operands.
|
||||
if (failed(validateOperationOperands(loc, opNameRef, operands)))
|
||||
return failure();
|
||||
|
||||
// Verify the attribute list.
|
||||
for (ast::NamedAttributeDecl *attr : attributes) {
|
||||
// Check for an attribute type, or a type awaiting resolution.
|
||||
ast::Type attrType = attr->getValue()->getType();
|
||||
if (!attrType.isa<ast::AttributeType>()) {
|
||||
return emitError(
|
||||
attr->getValue()->getLoc(),
|
||||
llvm::formatv("expected `Attr` expression, but got `{0}`", attrType));
|
||||
}
|
||||
}
|
||||
|
||||
// Verify the result types.
|
||||
if (failed(validateOperationResults(loc, opNameRef, results)))
|
||||
return failure();
|
||||
|
||||
return ast::OperationExpr::create(ctx, loc, name, operands, results,
|
||||
attributes);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
Parser::validateOperationOperands(llvm::SMRange loc, Optional<StringRef> name,
|
||||
MutableArrayRef<ast::Expr *> operands) {
|
||||
return validateOperationOperandsOrResults(loc, name, operands, valueTy,
|
||||
valueRangeTy);
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
Parser::validateOperationResults(llvm::SMRange loc, Optional<StringRef> name,
|
||||
MutableArrayRef<ast::Expr *> results) {
|
||||
return validateOperationOperandsOrResults(loc, name, results, typeTy,
|
||||
typeRangeTy);
|
||||
}
|
||||
|
||||
LogicalResult Parser::validateOperationOperandsOrResults(
|
||||
llvm::SMRange loc, Optional<StringRef> name,
|
||||
MutableArrayRef<ast::Expr *> values, ast::Type singleTy,
|
||||
ast::Type rangeTy) {
|
||||
// All operation types accept a single range parameter.
|
||||
if (values.size() == 1) {
|
||||
if (failed(convertExpressionTo(values[0], rangeTy)))
|
||||
return failure();
|
||||
return success();
|
||||
}
|
||||
|
||||
// Otherwise, accept the value groups as they have been defined and just
|
||||
// ensure they are one of the expected types.
|
||||
for (ast::Expr *&valueExpr : values) {
|
||||
ast::Type valueExprType = valueExpr->getType();
|
||||
|
||||
// Check if this is one of the expected types.
|
||||
if (valueExprType == rangeTy || valueExprType == singleTy)
|
||||
continue;
|
||||
|
||||
// If the operand is an Operation, allow converting to a Value or
|
||||
// ValueRange. This situations arises quite often with nested operation
|
||||
// expressions: `op<my_dialect.foo>(op<my_dialect.bar>)`
|
||||
if (singleTy == valueTy) {
|
||||
if (valueExprType.isa<ast::OperationType>()) {
|
||||
valueExpr = convertOpToValue(valueExpr);
|
||||
continue;
|
||||
}
|
||||
}
|
||||
|
||||
return emitError(
|
||||
valueExpr->getLoc(),
|
||||
llvm::formatv(
|
||||
"expected `{0}` or `{1}` convertible expression, but got `{2}`",
|
||||
singleTy, rangeTy, valueExprType));
|
||||
}
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Stmts
|
||||
|
||||
|
|
|
@ -81,6 +81,83 @@ Pattern {
|
|||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// `op` Expr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `)` after operation operand list
|
||||
let value: Value;
|
||||
let foo = op<builtin.func>(value<;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: unable to convert expression of type `Attr` to the expected type of `ValueRange`
|
||||
let attr: Attr;
|
||||
let foo = op<builtin.func>(attr);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `Value` or `ValueRange` convertible expression, but got `Type`
|
||||
let foo = op<>(_: Type, _: TypeRange);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected identifier or string attribute name
|
||||
let foo = op<> { 10;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `Attr` expression, but got `Value`
|
||||
let foo = op<> { foo = _: Value };
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `}` after operation attribute list
|
||||
let foo = op<> { "foo" {;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `(` before operation result type list
|
||||
let foo = op<> -> );
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: unable to convert expression of type `ValueRange` to the expected type of `TypeRange`
|
||||
let foo = op<> -> (_: ValueRange);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `Type` or `TypeRange` convertible expression, but got `Value`
|
||||
let foo = op<> -> (_: Value, _: ValueRange);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
Pattern {
|
||||
// CHECK: expected `)` after operation result type list
|
||||
let value: TypeRange;
|
||||
let foo = op<> -> (value<;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// `type` Expr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -14,6 +14,82 @@ Pattern {
|
|||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// OperationExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: `-OperationExpr {{.*}} Type<Op>
|
||||
// CHECK: `-OpNameDecl
|
||||
Pattern {
|
||||
erase op<>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: `-OperationExpr {{.*}} Type<Op<my_dialect.foo>>
|
||||
// CHECK: `-OpNameDecl {{.*}} Name<my_dialect.foo>
|
||||
Pattern {
|
||||
erase op<my_dialect.foo>;
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: `-OperationExpr {{.*}} Type<Op>
|
||||
// CHECK: `-OpNameDecl
|
||||
// CHECK: `Operands`
|
||||
// CHECK: |-DeclRefExpr {{.*}} Type<Value>
|
||||
// CHECK: |-DeclRefExpr {{.*}} Type<ValueRange>
|
||||
// CHECK: `-DeclRefExpr {{.*}} Type<Value>
|
||||
Pattern {
|
||||
erase op<>(_: Value, _: ValueRange, _: Value);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: `-OperationExpr {{.*}} Type<Op>
|
||||
// CHECK: `-OpNameDecl
|
||||
// CHECK: `Operands`
|
||||
// CHECK: `-MemberAccessExpr {{.*}} Member<$results> Type<ValueRange>
|
||||
// CHECK: `-OperationExpr {{.*}} Type<Op<my_dialect.bar>>
|
||||
// CHECK: `-OpNameDecl {{.*}} Name<my_dialect.bar>
|
||||
Pattern {
|
||||
erase op<>(op<my_dialect.bar>);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: `-OperationExpr {{.*}} Type<Op>
|
||||
// CHECK: `-OpNameDecl
|
||||
// CHECK: `Attributes`
|
||||
// CHECK: |-NamedAttributeDecl {{.*}} Name<unitAttr>
|
||||
// CHECK: `-AttributeExpr {{.*}} Value<"unit">
|
||||
// CHECK: `-NamedAttributeDecl {{.*}} Name<normal$Attr>
|
||||
// CHECK: `-DeclRefExpr {{.*}} Type<Attr>
|
||||
|
||||
Pattern {
|
||||
erase op<> {unitAttr, "normal$Attr" = _: Attr};
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK: Module
|
||||
// CHECK: `-OperationExpr {{.*}} Type<Op>
|
||||
// CHECK: `-OpNameDecl
|
||||
// CHECK: `Result Types`
|
||||
// CHECK: |-DeclRefExpr {{.*}} Type<Type>
|
||||
// CHECK: |-DeclRefExpr {{.*}} Type<TypeRange>
|
||||
// CHECK: `-DeclRefExpr {{.*}} Type<Type>
|
||||
Pattern {
|
||||
erase op<> -> (_: Type, _: TypeRange, _: Type);
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TypeExpr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
Loading…
Reference in New Issue