forked from OSchip/llvm-project
[mlir:async] Use ODS to define async types
Depends On D94923 Migrate Async dialect to ODS `TypeDef` Reviewed By: ftynse, rriddle Differential Revision: https://reviews.llvm.org/D95000
This commit is contained in:
parent
d5e48f1347
commit
2f7baffdc1
|
@ -14,6 +14,7 @@
|
|||
#ifndef MLIR_DIALECT_ASYNC_IR_ASYNC_H
|
||||
#define MLIR_DIALECT_ASYNC_IR_ASYNC_H
|
||||
|
||||
#include "mlir/Dialect/Async/IR/AsyncTypes.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/BuiltinTypes.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
|
@ -22,70 +23,27 @@
|
|||
#include "mlir/Interfaces/ControlFlowInterfaces.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace async {
|
||||
|
||||
namespace detail {
|
||||
struct ValueTypeStorage;
|
||||
} // namespace detail
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Async dialect types.
|
||||
// Async Dialect
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// The token type to represent asynchronous operation completion.
|
||||
class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
};
|
||||
|
||||
/// The value type to represent values returned from asynchronous operations.
|
||||
class ValueType
|
||||
: public Type::TypeBase<ValueType, Type, detail::ValueTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Get or create an async ValueType with the provided value type.
|
||||
static ValueType get(Type valueType);
|
||||
|
||||
Type getValueType();
|
||||
};
|
||||
|
||||
/// The group type to represent async tokens or values grouped together.
|
||||
class GroupType : public Type::TypeBase<GroupType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
};
|
||||
#include "mlir/Dialect/Async/IR/AsyncOpsDialect.h.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// LLVM coroutines types.
|
||||
// Async Dialect Operations
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// The type identifying a switched-resume coroutine.
|
||||
class CoroIdType : public Type::TypeBase<CoroIdType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
};
|
||||
|
||||
/// The coroutine handle type which is a pointer to the coroutine frame.
|
||||
class CoroHandleType
|
||||
: public Type::TypeBase<CoroHandleType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
};
|
||||
|
||||
/// The coroutine saved state type.
|
||||
class CoroStateType : public Type::TypeBase<CoroStateType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
};
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Async/IR/AsyncOps.h.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Helper functions of Async dialect transformations.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Returns true if the type is reference counted. All async dialect types are
|
||||
/// reference counted at runtime.
|
||||
namespace mlir {
|
||||
namespace async {
|
||||
|
||||
/// Returns true if the type is reference counted at runtime.
|
||||
inline bool isRefCounted(Type type) {
|
||||
return type.isa<TokenType, ValueType, GroupType>();
|
||||
}
|
||||
|
@ -93,9 +51,4 @@ inline bool isRefCounted(Type type) {
|
|||
} // namespace async
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Async/IR/AsyncOps.h.inc"
|
||||
|
||||
#include "mlir/Dialect/Async/IR/AsyncOpsDialect.h.inc"
|
||||
|
||||
#endif // MLIR_DIALECT_ASYNC_IR_ASYNC_H
|
||||
|
|
|
@ -0,0 +1,33 @@
|
|||
//===- AsyncDialect.td -------------------------------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Async dialect definition.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ASYNC_DIALECT_TD
|
||||
#define ASYNC_DIALECT_TD
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Async dialect definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AsyncDialect : Dialect {
|
||||
let name = "async";
|
||||
|
||||
let summary = "Types and operations for async dialect";
|
||||
let description = [{
|
||||
This dialect contains operations for modeling asynchronous execution.
|
||||
}];
|
||||
|
||||
let cppNamespace = "::mlir::async";
|
||||
}
|
||||
|
||||
#endif // ASYNC_DIALECT_TD
|
|
@ -13,7 +13,8 @@
|
|||
#ifndef ASYNC_OPS
|
||||
#define ASYNC_OPS
|
||||
|
||||
include "mlir/Dialect/Async/IR/AsyncBase.td"
|
||||
include "mlir/Dialect/Async/IR/AsyncDialect.td"
|
||||
include "mlir/Dialect/Async/IR/AsyncTypes.td"
|
||||
include "mlir/Interfaces/ControlFlowInterfaces.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
|
@ -75,7 +76,7 @@ def Async_ExecuteOp :
|
|||
Variadic<Async_AnyValueOrTokenType>:$operands);
|
||||
|
||||
let results = (outs Async_TokenType:$token,
|
||||
Variadic<Async_AnyValueType>:$results);
|
||||
Variadic<Async_ValueType>:$results);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
|
@ -398,7 +399,7 @@ def Async_RuntimeStoreOp : Async_Op<"runtime.store",
|
|||
}];
|
||||
|
||||
let arguments = (ins AnyType:$value,
|
||||
Async_AnyValueType:$storage);
|
||||
Async_ValueType:$storage);
|
||||
let assemblyFormat = "$value `,` $storage attr-dict `:` type($storage)";
|
||||
}
|
||||
|
||||
|
@ -412,7 +413,7 @@ def Async_RuntimeLoadOp : Async_Op<"runtime.load",
|
|||
async.value storage.
|
||||
}];
|
||||
|
||||
let arguments = (ins Async_AnyValueType:$storage);
|
||||
let arguments = (ins Async_ValueType:$storage);
|
||||
let results = (outs AnyType:$result);
|
||||
let assemblyFormat = "$storage attr-dict `:` type($storage)";
|
||||
}
|
||||
|
|
|
@ -0,0 +1,25 @@
|
|||
//===- AsyncTypes.h - Async Dialect Types -----------------------*- C++ -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// This file defines the types for the Async dialect.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef MLIR_DIALECT_ASYNC_IR_ASYNCTYPES_H_
|
||||
#define MLIR_DIALECT_ASYNC_IR_ASYNCTYPES_H_
|
||||
|
||||
#include "mlir/IR/Types.h"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Async Dialect Types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.h.inc"
|
||||
|
||||
#endif // MLIR_DIALECT_ASYNC_IR_ASYNCTYPES_H_
|
|
@ -1,4 +1,4 @@
|
|||
//===- AsyncBase.td ----------------------------------------*- tablegen -*-===//
|
||||
//===- AsyncTypes.td - Async dialect types -----------------*- tablegen -*-===//
|
||||
//
|
||||
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
|
||||
// See https://llvm.org/LICENSE.txt for license information.
|
||||
|
@ -6,59 +6,53 @@
|
|||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
//
|
||||
// Base definitions for the `async` dialect.
|
||||
// This file declares the Async dialect types.
|
||||
//
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#ifndef ASYNC_BASE_TD
|
||||
#define ASYNC_BASE_TD
|
||||
#ifndef MLIR_DIALECT_ASYNC_IR_ASYNCTYPES
|
||||
#define MLIR_DIALECT_ASYNC_IR_ASYNCTYPES
|
||||
|
||||
include "mlir/IR/OpBase.td"
|
||||
include "mlir/Dialect/Async/IR/AsyncDialect.td"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Async dialect definitions
|
||||
// Async Types
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def AsyncDialect : Dialect {
|
||||
let name = "async";
|
||||
|
||||
let summary = "Types and operations for async dialect";
|
||||
let description = [{
|
||||
This dialect contains operations for modeling asynchronous execution.
|
||||
}];
|
||||
|
||||
let cppNamespace = "::mlir::async";
|
||||
class Async_Type<string name, string typeMnemonic> : TypeDef<AsyncDialect,
|
||||
name> {
|
||||
let mnemonic = typeMnemonic;
|
||||
}
|
||||
|
||||
def Async_TokenType : DialectType<AsyncDialect,
|
||||
CPred<"$_self.isa<::mlir::async::TokenType>()">, "token type">,
|
||||
BuildableType<"$_builder.getType<::mlir::async::TokenType>()"> {
|
||||
def Async_TokenType : Async_Type<"Token", "token"> {
|
||||
let summary = "async token type";
|
||||
let description = [{
|
||||
`async.token` is a type returned by asynchronous operations, and it becomes
|
||||
`ready` when the asynchronous operations that created it is completed.
|
||||
`available` when the asynchronous operations that created it is completed.
|
||||
}];
|
||||
}
|
||||
|
||||
class Async_ValueType<Type type>
|
||||
: DialectType<AsyncDialect,
|
||||
And<[
|
||||
CPred<"$_self.isa<::mlir::async::ValueType>()">,
|
||||
SubstLeaves<"$_self",
|
||||
"$_self.cast<::mlir::async::ValueType>().getValueType()",
|
||||
type.predicate>
|
||||
]>, "async value type with " # type.summary # " underlying type"> {
|
||||
def Async_ValueType : Async_Type<"Value", "value"> {
|
||||
let summary = "async value type";
|
||||
let description = [{
|
||||
`async.value` represents a value returned by asynchronous operations,
|
||||
which may or may not be available currently, but will be available at some
|
||||
point in the future.
|
||||
}];
|
||||
|
||||
Type valueType = type;
|
||||
let parameters = (ins "Type":$valueType);
|
||||
let builders = [
|
||||
TypeBuilderWithInferredContext<(ins "Type":$valueType), [{
|
||||
return Base::get(valueType.getContext(), valueType);
|
||||
}], [{
|
||||
return Base::getChecked($_loc, valueType);
|
||||
}]>
|
||||
];
|
||||
let skipDefaultBuilders = 1;
|
||||
}
|
||||
|
||||
def Async_GroupType : DialectType<AsyncDialect,
|
||||
CPred<"$_self.isa<::mlir::async::GroupType>()">, "group type">,
|
||||
BuildableType<"$_builder.getType<::mlir::async::GroupType>()"> {
|
||||
def Async_GroupType : Async_Type<"Group", "group"> {
|
||||
let summary = "async group type";
|
||||
let description = [{
|
||||
`async.group` represent a set of async tokens or values and allows to
|
||||
execute async operations on all of them together (e.g. wait for the
|
||||
|
@ -66,14 +60,10 @@ def Async_GroupType : DialectType<AsyncDialect,
|
|||
}];
|
||||
}
|
||||
|
||||
def Async_AnyValueType : DialectType<AsyncDialect,
|
||||
CPred<"$_self.isa<::mlir::async::ValueType>()">,
|
||||
"async value type">;
|
||||
|
||||
def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType,
|
||||
def Async_AnyValueOrTokenType : AnyTypeOf<[Async_ValueType,
|
||||
Async_TokenType]>;
|
||||
|
||||
def Async_AnyAsyncType : AnyTypeOf<[Async_AnyValueType,
|
||||
def Async_AnyAsyncType : AnyTypeOf<[Async_ValueType,
|
||||
Async_TokenType,
|
||||
Async_GroupType]>;
|
||||
|
||||
|
@ -86,30 +76,27 @@ def Async_AnyAsyncType : AnyTypeOf<[Async_AnyValueType,
|
|||
// build a properly typed intermediate IR during the Async to LLVM lowering we
|
||||
// define a separate types for values that can be produced by LLVM intrinsics.
|
||||
|
||||
def Async_CoroIdType : DialectType<AsyncDialect,
|
||||
CPred<"$_self.isa<::mlir::async::CoroIdType>()">, "coro.id type">,
|
||||
BuildableType<"$_builder.getType<::mlir::async::CoroIdType>()"> {
|
||||
def Async_CoroIdType : Async_Type<"CoroId", "coro.id"> {
|
||||
let summary = "switched-resume coroutine identifier";
|
||||
let description = [{
|
||||
`async.coro.id` is a type identifying a switched-resume coroutine.
|
||||
}];
|
||||
}
|
||||
|
||||
def Async_CoroHandleType : DialectType<AsyncDialect,
|
||||
CPred<"$_self.isa<::mlir::async::CoroHandleType>()">, "coro.handle type">,
|
||||
BuildableType<"$_builder.getType<::mlir::async::CoroHandleType>()"> {
|
||||
def Async_CoroHandleType : Async_Type<"CoroHandle", "coro.handle"> {
|
||||
let summary = "coroutine handle";
|
||||
let description = [{
|
||||
`async.coro.handle` is a handle to the coroutine (pointer to the coroutine
|
||||
frame) that can be passed around to resume or destroy the coroutine.
|
||||
}];
|
||||
}
|
||||
|
||||
def Async_CoroStateType : DialectType<AsyncDialect,
|
||||
CPred<"$_self.isa<::mlir::async::CoroStateType>()">, "coro.state type">,
|
||||
BuildableType<"$_builder.getType<::mlir::async::CoroStateType>()"> {
|
||||
def Async_CoroStateType : Async_Type<"CoroState", "coro.state"> {
|
||||
let summary = "saved coroutine state";
|
||||
let description = [{
|
||||
`async.coro.state` is a saved coroutine state that should be passed to the
|
||||
coroutine suspension operation.
|
||||
}];
|
||||
}
|
||||
|
||||
#endif // ASYNC_BASE_TD
|
||||
#endif // MLIR_DIALECT_ASYNC_IR_ASYNCTYPES
|
|
@ -19,96 +19,12 @@ void AsyncDialect::initialize() {
|
|||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
|
||||
>();
|
||||
addTypes<TokenType, ValueType, GroupType>(); // async types
|
||||
addTypes<CoroIdType, CoroHandleType, CoroStateType>(); // coro types
|
||||
addTypes<
|
||||
#define GET_TYPEDEF_LIST
|
||||
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
|
||||
>();
|
||||
}
|
||||
|
||||
/// Parse a type registered to this dialect.
|
||||
Type AsyncDialect::parseType(DialectAsmParser &parser) const {
|
||||
StringRef keyword;
|
||||
if (parser.parseKeyword(&keyword))
|
||||
return Type();
|
||||
|
||||
if (keyword == "token")
|
||||
return TokenType::get(getContext());
|
||||
|
||||
if (keyword == "group")
|
||||
return GroupType::get(getContext());
|
||||
|
||||
if (keyword == "value") {
|
||||
Type ty;
|
||||
if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
|
||||
parser.emitError(parser.getNameLoc(), "failed to parse async value type");
|
||||
return Type();
|
||||
}
|
||||
return ValueType::get(ty);
|
||||
}
|
||||
|
||||
if (keyword == "coro.id")
|
||||
return CoroIdType::get(getContext());
|
||||
|
||||
if (keyword == "coro.handle")
|
||||
return CoroHandleType::get(getContext());
|
||||
|
||||
if (keyword == "coro.state")
|
||||
return CoroStateType::get(getContext());
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
|
||||
return Type();
|
||||
}
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
TypeSwitch<Type>(type)
|
||||
.Case<TokenType>([&](TokenType) { os << "token"; })
|
||||
.Case<ValueType>([&](ValueType valueTy) {
|
||||
os << "value<";
|
||||
os.printType(valueTy.getValueType());
|
||||
os << '>';
|
||||
})
|
||||
.Case<GroupType>([&](GroupType) { os << "group"; })
|
||||
.Case<CoroIdType>([&](CoroIdType) { os << "coro.id"; })
|
||||
.Case<CoroHandleType>([&](CoroHandleType) { os << "coro.handle"; })
|
||||
.Case<CoroStateType>([&](CoroStateType) { os << "coro.state"; })
|
||||
.Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// ValueType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
namespace async {
|
||||
namespace detail {
|
||||
|
||||
// Storage for `async.value<T>` type, the only member is the wrapped type.
|
||||
struct ValueTypeStorage : public TypeStorage {
|
||||
ValueTypeStorage(Type valueType) : valueType(valueType) {}
|
||||
|
||||
/// The hash key used for uniquing.
|
||||
using KeyTy = Type;
|
||||
bool operator==(const KeyTy &key) const { return key == valueType; }
|
||||
|
||||
/// Construction.
|
||||
static ValueTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
Type valueType) {
|
||||
return new (allocator.allocate<ValueTypeStorage>())
|
||||
ValueTypeStorage(valueType);
|
||||
}
|
||||
|
||||
Type valueType;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace async
|
||||
} // namespace mlir
|
||||
|
||||
ValueType ValueType::get(Type valueType) {
|
||||
return Base::get(valueType.getContext(), valueType);
|
||||
}
|
||||
|
||||
Type ValueType::getValueType() { return getImpl()->valueType; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// YieldOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -376,5 +292,47 @@ static LogicalResult verify(AwaitOp op) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd op method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TableGen'd type method definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
#define GET_TYPEDEF_CLASSES
|
||||
#include "mlir/Dialect/Async/IR/AsyncOpsTypes.cpp.inc"
|
||||
|
||||
void ValueType::print(DialectAsmPrinter &printer) const {
|
||||
printer << getMnemonic();
|
||||
printer << "<";
|
||||
printer.printType(getValueType());
|
||||
printer << '>';
|
||||
}
|
||||
|
||||
Type ValueType::parse(mlir::MLIRContext *, mlir::DialectAsmParser &parser) {
|
||||
Type ty;
|
||||
if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
|
||||
parser.emitError(parser.getNameLoc(), "failed to parse async value type");
|
||||
return Type();
|
||||
}
|
||||
return ValueType::get(ty);
|
||||
}
|
||||
|
||||
/// Print a type registered to this dialect.
|
||||
void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
if (failed(generatedTypePrinter(type, os)))
|
||||
llvm_unreachable("unexpected 'async' type kind");
|
||||
}
|
||||
|
||||
/// Parse a type registered to this dialect.
|
||||
Type AsyncDialect::parseType(DialectAsmParser &parser) const {
|
||||
StringRef mnemonic;
|
||||
if (parser.parseKeyword(&mnemonic))
|
||||
return Type();
|
||||
|
||||
return generatedTypeParser(getContext(), parser, mnemonic);
|
||||
}
|
||||
|
|
|
@ -9,7 +9,7 @@ func @no_op(%arg0: !async.token) {
|
|||
// -----
|
||||
|
||||
func @wrong_async_await_arg_type(%arg0: f32) {
|
||||
// expected-error @+1 {{'async.await' op operand #0 must be async value type or token type, but got 'f32'}}
|
||||
// expected-error @+1 {{'async.await' op operand #0 must be async value type or async token type, but got 'f32'}}
|
||||
async.await %arg0 : f32
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue