[mlir:async] Use ODS to define async types

Depends On D94923

Migrate Async dialect to ODS `TypeDef`

Reviewed By: ftynse, rriddle

Differential Revision: https://reviews.llvm.org/D95000
This commit is contained in:
Eugene Zhulenev 2021-01-25 14:14:12 -08:00
parent d5e48f1347
commit 2f7baffdc1
7 changed files with 154 additions and 197 deletions

View File

@ -14,6 +14,7 @@
#ifndef MLIR_DIALECT_ASYNC_IR_ASYNC_H
#define MLIR_DIALECT_ASYNC_IR_ASYNC_H
#include "mlir/Dialect/Async/IR/AsyncTypes.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Dialect.h"
@ -22,70 +23,27 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
namespace mlir {
namespace async {
namespace detail {
struct ValueTypeStorage;
} // namespace detail
//===----------------------------------------------------------------------===//
// Async dialect types.
// Async Dialect
//===----------------------------------------------------------------------===//
/// The token type to represent asynchronous operation completion.
class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
public:
using Base::Base;
};
/// The value type to represent values returned from asynchronous operations.
class ValueType
: public Type::TypeBase<ValueType, Type, detail::ValueTypeStorage> {
public:
using Base::Base;
/// Get or create an async ValueType with the provided value type.
static ValueType get(Type valueType);
Type getValueType();
};
/// The group type to represent async tokens or values grouped together.
class GroupType : public Type::TypeBase<GroupType, Type, TypeStorage> {
public:
using Base::Base;
};
#include "mlir/Dialect/Async/IR/AsyncOpsDialect.h.inc"
//===----------------------------------------------------------------------===//
// LLVM coroutines types.
// Async Dialect Operations
//===----------------------------------------------------------------------===//
/// The type identifying a switched-resume coroutine.
class CoroIdType : public Type::TypeBase<CoroIdType, Type, TypeStorage> {
public:
using Base::Base;
};
/// The coroutine handle type which is a pointer to the coroutine frame.
class CoroHandleType
: public Type::TypeBase<CoroHandleType, Type, TypeStorage> {
public:
using Base::Base;
};
/// The coroutine saved state type.
class CoroStateType : public Type::TypeBase<CoroStateType, Type, TypeStorage> {
public:
using Base::Base;
};
#define GET_OP_CLASSES
#include "mlir/Dialect/Async/IR/AsyncOps.h.inc"
//===----------------------------------------------------------------------===//
// Helper functions of Async dialect transformations.
//===----------------------------------------------------------------------===//
/// Returns true if the type is reference counted. All async dialect types are
/// reference counted at runtime.
namespace mlir {
namespace async {
/// Returns true if the type is reference counted at runtime.
inline bool isRefCounted(Type type) {
return type.isa<TokenType, ValueType, GroupType>();
}
@ -93,9 +51,4 @@ inline bool isRefCounted(Type type) {
} // namespace async
} // namespace mlir
#define GET_OP_CLASSES
#include "mlir/Dialect/Async/IR/AsyncOps.h.inc"
#include "mlir/Dialect/Async/IR/AsyncOpsDialect.h.inc"
#endif // MLIR_DIALECT_ASYNC_IR_ASYNC_H

View File

@ -0,0 +1,33 @@
//===- AsyncDialect.td -------------------------------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Async dialect definition.
//
//===----------------------------------------------------------------------===//
#ifndef ASYNC_DIALECT_TD
#define ASYNC_DIALECT_TD
include "mlir/IR/OpBase.td"
//===----------------------------------------------------------------------===//
// Async dialect definitions
//===----------------------------------------------------------------------===//
def AsyncDialect : Dialect {
let name = "async";
let summary = "Types and operations for async dialect";
let description = [{
This dialect contains operations for modeling asynchronous execution.
}];
let cppNamespace = "::mlir::async";
}
#endif // ASYNC_DIALECT_TD

View File

@ -13,7 +13,8 @@
#ifndef ASYNC_OPS
#define ASYNC_OPS
include "mlir/Dialect/Async/IR/AsyncBase.td"
include "mlir/Dialect/Async/IR/AsyncDialect.td"
include "mlir/Dialect/Async/IR/AsyncTypes.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
@ -75,7 +76,7 @@ def Async_ExecuteOp :
Variadic<Async_AnyValueOrTokenType>:$operands);
let results = (outs Async_TokenType:$token,
Variadic<Async_AnyValueType>:$results);
Variadic<Async_ValueType>:$results);
let regions = (region SizedRegion<1>:$body);
let printer = [{ return ::print(p, *this); }];
@ -398,7 +399,7 @@ def Async_RuntimeStoreOp : Async_Op<"runtime.store",
}];
let arguments = (ins AnyType:$value,
Async_AnyValueType:$storage);
Async_ValueType:$storage);
let assemblyFormat = "$value `,` $storage attr-dict `:` type($storage)";
}
@ -412,7 +413,7 @@ def Async_RuntimeLoadOp : Async_Op<"runtime.load",
async.value storage.
}];
let arguments = (ins Async_AnyValueType:$storage);
let arguments = (ins Async_ValueType:$storage);
let results = (outs AnyType:$result);
let assemblyFormat = "$storage attr-dict `:` type($storage)";
}

View File

@ -0,0 +1,25 @@
//===- AsyncTypes.h - Async Dialect Types -----------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file defines the types for the Async dialect.
//
//===----------------------------------------------------------------------===//
#ifndef MLIR_DIALECT_ASYNC_IR_ASYNCTYPES_H_
#define MLIR_DIALECT_ASYNC_IR_ASYNCTYPES_H_
#include "mlir/IR/Types.h"
//===----------------------------------------------------------------------===//
// Async Dialect Types
//===----------------------------------------------------------------------===//
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.h.inc"
#endif // MLIR_DIALECT_ASYNC_IR_ASYNCTYPES_H_

View File

@ -1,4 +1,4 @@
//===- AsyncBase.td ----------------------------------------*- tablegen -*-===//
//===- AsyncTypes.td - Async dialect types -----------------*- tablegen -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
@ -6,59 +6,53 @@
//
//===----------------------------------------------------------------------===//
//
// Base definitions for the `async` dialect.
// This file declares the Async dialect types.
//
//===----------------------------------------------------------------------===//
#ifndef ASYNC_BASE_TD
#define ASYNC_BASE_TD
#ifndef MLIR_DIALECT_ASYNC_IR_ASYNCTYPES
#define MLIR_DIALECT_ASYNC_IR_ASYNCTYPES
include "mlir/IR/OpBase.td"
include "mlir/Dialect/Async/IR/AsyncDialect.td"
//===----------------------------------------------------------------------===//
// Async dialect definitions
// Async Types
//===----------------------------------------------------------------------===//
def AsyncDialect : Dialect {
let name = "async";
let summary = "Types and operations for async dialect";
let description = [{
This dialect contains operations for modeling asynchronous execution.
}];
let cppNamespace = "::mlir::async";
class Async_Type<string name, string typeMnemonic> : TypeDef<AsyncDialect,
name> {
let mnemonic = typeMnemonic;
}
def Async_TokenType : DialectType<AsyncDialect,
CPred<"$_self.isa<::mlir::async::TokenType>()">, "token type">,
BuildableType<"$_builder.getType<::mlir::async::TokenType>()"> {
def Async_TokenType : Async_Type<"Token", "token"> {
let summary = "async token type";
let description = [{
`async.token` is a type returned by asynchronous operations, and it becomes
`ready` when the asynchronous operations that created it is completed.
`available` when the asynchronous operations that created it is completed.
}];
}
class Async_ValueType<Type type>
: DialectType<AsyncDialect,
And<[
CPred<"$_self.isa<::mlir::async::ValueType>()">,
SubstLeaves<"$_self",
"$_self.cast<::mlir::async::ValueType>().getValueType()",
type.predicate>
]>, "async value type with " # type.summary # " underlying type"> {
def Async_ValueType : Async_Type<"Value", "value"> {
let summary = "async value type";
let description = [{
`async.value` represents a value returned by asynchronous operations,
which may or may not be available currently, but will be available at some
point in the future.
}];
Type valueType = type;
let parameters = (ins "Type":$valueType);
let builders = [
TypeBuilderWithInferredContext<(ins "Type":$valueType), [{
return Base::get(valueType.getContext(), valueType);
}], [{
return Base::getChecked($_loc, valueType);
}]>
];
let skipDefaultBuilders = 1;
}
def Async_GroupType : DialectType<AsyncDialect,
CPred<"$_self.isa<::mlir::async::GroupType>()">, "group type">,
BuildableType<"$_builder.getType<::mlir::async::GroupType>()"> {
def Async_GroupType : Async_Type<"Group", "group"> {
let summary = "async group type";
let description = [{
`async.group` represent a set of async tokens or values and allows to
execute async operations on all of them together (e.g. wait for the
@ -66,14 +60,10 @@ def Async_GroupType : DialectType<AsyncDialect,
}];
}
def Async_AnyValueType : DialectType<AsyncDialect,
CPred<"$_self.isa<::mlir::async::ValueType>()">,
"async value type">;
def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType,
def Async_AnyValueOrTokenType : AnyTypeOf<[Async_ValueType,
Async_TokenType]>;
def Async_AnyAsyncType : AnyTypeOf<[Async_AnyValueType,
def Async_AnyAsyncType : AnyTypeOf<[Async_ValueType,
Async_TokenType,
Async_GroupType]>;
@ -86,30 +76,27 @@ def Async_AnyAsyncType : AnyTypeOf<[Async_AnyValueType,
// build a properly typed intermediate IR during the Async to LLVM lowering we
// define a separate types for values that can be produced by LLVM intrinsics.
def Async_CoroIdType : DialectType<AsyncDialect,
CPred<"$_self.isa<::mlir::async::CoroIdType>()">, "coro.id type">,
BuildableType<"$_builder.getType<::mlir::async::CoroIdType>()"> {
def Async_CoroIdType : Async_Type<"CoroId", "coro.id"> {
let summary = "switched-resume coroutine identifier";
let description = [{
`async.coro.id` is a type identifying a switched-resume coroutine.
}];
}
def Async_CoroHandleType : DialectType<AsyncDialect,
CPred<"$_self.isa<::mlir::async::CoroHandleType>()">, "coro.handle type">,
BuildableType<"$_builder.getType<::mlir::async::CoroHandleType>()"> {
def Async_CoroHandleType : Async_Type<"CoroHandle", "coro.handle"> {
let summary = "coroutine handle";
let description = [{
`async.coro.handle` is a handle to the coroutine (pointer to the coroutine
frame) that can be passed around to resume or destroy the coroutine.
}];
}
def Async_CoroStateType : DialectType<AsyncDialect,
CPred<"$_self.isa<::mlir::async::CoroStateType>()">, "coro.state type">,
BuildableType<"$_builder.getType<::mlir::async::CoroStateType>()"> {
def Async_CoroStateType : Async_Type<"CoroState", "coro.state"> {
let summary = "saved coroutine state";
let description = [{
`async.coro.state` is a saved coroutine state that should be passed to the
coroutine suspension operation.
}];
}
#endif // ASYNC_BASE_TD
#endif // MLIR_DIALECT_ASYNC_IR_ASYNCTYPES

View File

@ -19,96 +19,12 @@ void AsyncDialect::initialize() {
#define GET_OP_LIST
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
>();
addTypes<TokenType, ValueType, GroupType>(); // async types
addTypes<CoroIdType, CoroHandleType, CoroStateType>(); // coro types
addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
>();
}
/// Parse a type registered to this dialect.
Type AsyncDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
if (parser.parseKeyword(&keyword))
return Type();
if (keyword == "token")
return TokenType::get(getContext());
if (keyword == "group")
return GroupType::get(getContext());
if (keyword == "value") {
Type ty;
if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
parser.emitError(parser.getNameLoc(), "failed to parse async value type");
return Type();
}
return ValueType::get(ty);
}
if (keyword == "coro.id")
return CoroIdType::get(getContext());
if (keyword == "coro.handle")
return CoroHandleType::get(getContext());
if (keyword == "coro.state")
return CoroStateType::get(getContext());
parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
return Type();
}
/// Print a type registered to this dialect.
void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
TypeSwitch<Type>(type)
.Case<TokenType>([&](TokenType) { os << "token"; })
.Case<ValueType>([&](ValueType valueTy) {
os << "value<";
os.printType(valueTy.getValueType());
os << '>';
})
.Case<GroupType>([&](GroupType) { os << "group"; })
.Case<CoroIdType>([&](CoroIdType) { os << "coro.id"; })
.Case<CoroHandleType>([&](CoroHandleType) { os << "coro.handle"; })
.Case<CoroStateType>([&](CoroStateType) { os << "coro.state"; })
.Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
}
//===----------------------------------------------------------------------===//
/// ValueType
//===----------------------------------------------------------------------===//
namespace mlir {
namespace async {
namespace detail {
// Storage for `async.value<T>` type, the only member is the wrapped type.
struct ValueTypeStorage : public TypeStorage {
ValueTypeStorage(Type valueType) : valueType(valueType) {}
/// The hash key used for uniquing.
using KeyTy = Type;
bool operator==(const KeyTy &key) const { return key == valueType; }
/// Construction.
static ValueTypeStorage *construct(TypeStorageAllocator &allocator,
Type valueType) {
return new (allocator.allocate<ValueTypeStorage>())
ValueTypeStorage(valueType);
}
Type valueType;
};
} // namespace detail
} // namespace async
} // namespace mlir
ValueType ValueType::get(Type valueType) {
return Base::get(valueType.getContext(), valueType);
}
Type ValueType::getValueType() { return getImpl()->valueType; }
//===----------------------------------------------------------------------===//
// YieldOp
//===----------------------------------------------------------------------===//
@ -376,5 +292,47 @@ static LogicalResult verify(AwaitOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//
#define GET_OP_CLASSES
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
//===----------------------------------------------------------------------===//
// TableGen'd type method definitions
//===----------------------------------------------------------------------===//
#define GET_TYPEDEF_CLASSES
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
void ValueType::print(DialectAsmPrinter &printer) const {
printer << getMnemonic();
printer << "<";
printer.printType(getValueType());
printer << '>';
}
Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) {
Type ty;
if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
parser.emitError(parser.getNameLoc(), "failed to parse async value type");
return Type();
}
return ValueType::get(ty);
}
/// Print a type registered to this dialect.
void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
if (failed(generatedTypePrinter(type, os)))
llvm_unreachable("unexpected 'async' type kind");
}
/// Parse a type registered to this dialect.
Type AsyncDialect::parseType(DialectAsmParser &parser) const {
StringRef mnemonic;
if (parser.parseKeyword(&mnemonic))
return Type();
return generatedTypeParser(getContext(), parser, mnemonic);
}

View File

@ -9,7 +9,7 @@ func @no_op(%arg0: !async.token) {
// -----
func @wrong_async_await_arg_type(%arg0: f32) {
// expected-error @+1 {{'async.await' op operand #0 must be async value type or token type, but got 'f32'}}
// expected-error @+1 {{'async.await' op operand #0 must be async value type or async token type, but got 'f32'}}
async.await %arg0 : f32
}