forked from OSchip/llvm-project
[MLIR] Add async.value type to Async dialect
Return values from async regions as !async.value<...>. Reviewed By: mehdi_amini, csigg Differential Revision: https://reviews.llvm.org/D88510
This commit is contained in:
parent
c3193e464c
commit
655af658c9
|
@ -22,12 +22,28 @@
|
|||
namespace mlir {
|
||||
namespace async {
|
||||
|
||||
namespace detail {
|
||||
struct ValueTypeStorage;
|
||||
} // namespace detail
|
||||
|
||||
/// The token type to represent asynchronous operation completion.
|
||||
class TokenType : public Type::TypeBase<TokenType, Type, TypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
};
|
||||
|
||||
/// The value type to represent values returned from asynchronous operations.
|
||||
class ValueType
|
||||
: public Type::TypeBase<ValueType, Type, detail::ValueTypeStorage> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
/// Get or create an async ValueType with the provided value type.
|
||||
static ValueType get(Type valueType);
|
||||
|
||||
Type getValueType();
|
||||
};
|
||||
|
||||
} // namespace async
|
||||
} // namespace mlir
|
||||
|
||||
|
|
|
@ -39,4 +39,24 @@ def Async_TokenType : DialectType<AsyncDialect,
|
|||
}];
|
||||
}
|
||||
|
||||
class Async_ValueType<Type type>
|
||||
: DialectType<AsyncDialect,
|
||||
And<[
|
||||
CPred<"$_self.isa<::mlir::async::ValueType>()">,
|
||||
SubstLeaves<"$_self",
|
||||
"$_self.cast<::mlir::async::ValueType>().getValueType()",
|
||||
type.predicate>
|
||||
]>, "async value type with " # type.description # " underlying type"> {
|
||||
let typeDescription = [{
|
||||
`async.value` represents a value returned by asynchronous operations,
|
||||
which may or may not be available currently, but will be available at some
|
||||
point in the future.
|
||||
}];
|
||||
|
||||
Type valueType = type;
|
||||
}
|
||||
|
||||
def Async_AnyValueType : Type<CPred<"$_self.isa<::mlir::async::ValueType>()">,
|
||||
"async value type">;
|
||||
|
||||
#endif // ASYNC_BASE_TD
|
||||
|
|
|
@ -40,24 +40,24 @@ def Async_ExecuteOp : Async_Op<"execute"> {
|
|||
state). All dependencies must be made explicit with async execute arguments
|
||||
(`async.token` or `async.value`).
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0 = async.execute {
|
||||
"compute0"(...)
|
||||
async.yield
|
||||
} : !async.token
|
||||
%done, %values = async.execute {
|
||||
%0 = "compute0"(...) : !some.type
|
||||
async.yield %1 : f32
|
||||
} : !async.token, !async.value<!some.type>
|
||||
|
||||
%1 = "compute1"(...)
|
||||
%1 = "compute1"(...) : !some.type
|
||||
```
|
||||
}];
|
||||
|
||||
// TODO: Take async.tokens/async.values as arguments.
|
||||
let arguments = (ins );
|
||||
let results = (outs Async_TokenType:$done);
|
||||
let results = (outs Async_TokenType:$done,
|
||||
Variadic<Async_AnyValueType>:$values);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let assemblyFormat = "$body attr-dict `:` type($done)";
|
||||
let printer = [{ return ::mlir::async::print(p, *this); }];
|
||||
let parser = [{ return ::mlir::async::parse$cppClass(parser, result); }];
|
||||
}
|
||||
|
||||
def Async_YieldOp :
|
||||
|
@ -71,6 +71,8 @@ def Async_YieldOp :
|
|||
let arguments = (ins Variadic<AnyType>:$operands);
|
||||
|
||||
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||
|
||||
let verifier = [{ return ::mlir::async::verify(*this); }];
|
||||
}
|
||||
|
||||
#endif // ASYNC_OPS
|
||||
|
|
|
@ -19,8 +19,8 @@
|
|||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
using namespace mlir;
|
||||
using namespace mlir::async;
|
||||
namespace mlir {
|
||||
namespace async {
|
||||
|
||||
void AsyncDialect::initialize() {
|
||||
addOperations<
|
||||
|
@ -28,6 +28,7 @@ void AsyncDialect::initialize() {
|
|||
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
|
||||
>();
|
||||
addTypes<TokenType>();
|
||||
addTypes<ValueType>();
|
||||
}
|
||||
|
||||
/// Parse a type registered to this dialect.
|
||||
|
@ -39,6 +40,15 @@ Type AsyncDialect::parseType(DialectAsmParser &parser) const {
|
|||
if (keyword == "token")
|
||||
return TokenType::get(getContext());
|
||||
|
||||
if (keyword == "value") {
|
||||
Type ty;
|
||||
if (parser.parseLess() || parser.parseType(ty) || parser.parseGreater()) {
|
||||
parser.emitError(parser.getNameLoc(), "failed to parse async value type");
|
||||
return Type();
|
||||
}
|
||||
return ValueType::get(ty);
|
||||
}
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "unknown async type: ") << keyword;
|
||||
return Type();
|
||||
}
|
||||
|
@ -46,9 +56,113 @@ Type AsyncDialect::parseType(DialectAsmParser &parser) const {
|
|||
/// Print a type registered to this dialect.
|
||||
void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
|
||||
TypeSwitch<Type>(type)
|
||||
.Case<TokenType>([&](Type) { os << "token"; })
|
||||
.Case<TokenType>([&](TokenType) { os << "token"; })
|
||||
.Case<ValueType>([&](ValueType valueTy) {
|
||||
os << "value<";
|
||||
os.printType(valueTy.getValueType());
|
||||
os << '>';
|
||||
})
|
||||
.Default([](Type) { llvm_unreachable("unexpected 'async' type kind"); });
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// ValueType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace detail {
|
||||
|
||||
// Storage for `async.value<T>` type, the only member is the wrapped type.
|
||||
struct ValueTypeStorage : public TypeStorage {
|
||||
ValueTypeStorage(Type valueType) : valueType(valueType) {}
|
||||
|
||||
/// The hash key used for uniquing.
|
||||
using KeyTy = Type;
|
||||
bool operator==(const KeyTy &key) const { return key == valueType; }
|
||||
|
||||
/// Construction.
|
||||
static ValueTypeStorage *construct(TypeStorageAllocator &allocator,
|
||||
Type valueType) {
|
||||
return new (allocator.allocate<ValueTypeStorage>())
|
||||
ValueTypeStorage(valueType);
|
||||
}
|
||||
|
||||
Type valueType;
|
||||
};
|
||||
|
||||
} // namespace detail
|
||||
|
||||
ValueType ValueType::get(Type valueType) {
|
||||
return Base::get(valueType.getContext(), valueType);
|
||||
}
|
||||
|
||||
Type ValueType::getValueType() { return getImpl()->valueType; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// YieldOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static LogicalResult verify(YieldOp op) {
|
||||
// Get the underlying value types from async values returned from the
|
||||
// parent `async.execute` operation.
|
||||
auto executeOp = op.getParentOfType<ExecuteOp>();
|
||||
auto types = llvm::map_range(executeOp.values(), [](const OpResult &result) {
|
||||
return result.getType().cast<ValueType>().getValueType();
|
||||
});
|
||||
|
||||
if (!std::equal(types.begin(), types.end(), op.getOperandTypes().begin()))
|
||||
return op.emitOpError("Operand types do not match the types returned from "
|
||||
"the parent ExecuteOp");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
/// ExecuteOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
static void print(OpAsmPrinter &p, ExecuteOp op) {
|
||||
p << "async.execute ";
|
||||
p.printRegion(op.body());
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : ";
|
||||
p.printType(op.done().getType());
|
||||
if (!op.values().empty())
|
||||
p << ", ";
|
||||
llvm::interleaveComma(op.values(), p, [&](const OpResult &result) {
|
||||
p.printType(result.getType());
|
||||
});
|
||||
}
|
||||
|
||||
static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
|
||||
MLIRContext *ctx = result.getContext();
|
||||
|
||||
// Parse asynchronous region.
|
||||
Region *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{},
|
||||
/*enableNameShadowing=*/false))
|
||||
return failure();
|
||||
|
||||
// Parse operation attributes.
|
||||
NamedAttrList attrs;
|
||||
if (parser.parseOptionalAttrDict(attrs))
|
||||
return failure();
|
||||
result.addAttributes(attrs);
|
||||
|
||||
// Parse result types.
|
||||
SmallVector<Type, 4> resultTypes;
|
||||
if (parser.parseColonTypeList(resultTypes))
|
||||
return failure();
|
||||
|
||||
// First result type must be an async token type.
|
||||
if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx))
|
||||
return failure();
|
||||
parser.addTypesToList(resultTypes, result.types);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace async
|
||||
} // namespace mlir
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
|
||||
|
|
|
@ -1,16 +1,46 @@
|
|||
// RUN: mlir-opt %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @identity
|
||||
func @identity(%arg0 : !async.token) -> !async.token {
|
||||
// CHECK-LABEL: @identity_token
|
||||
func @identity_token(%arg0 : !async.token) -> !async.token {
|
||||
// CHECK: return %arg0 : !async.token
|
||||
return %arg0 : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @identity_value
|
||||
func @identity_value(%arg0 : !async.value<f32>) -> !async.value<f32> {
|
||||
// CHECK: return %arg0 : !async.value<f32>
|
||||
return %arg0 : !async.value<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @empty_async_execute
|
||||
func @empty_async_execute() -> !async.token {
|
||||
%0 = async.execute {
|
||||
%done = async.execute {
|
||||
async.yield
|
||||
} : !async.token
|
||||
|
||||
return %0 : !async.token
|
||||
// CHECK: return %done : !async.token
|
||||
return %done : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @return_async_value
|
||||
func @return_async_value() -> !async.value<f32> {
|
||||
%done, %values = async.execute {
|
||||
%cst = constant 1.000000e+00 : f32
|
||||
async.yield %cst : f32
|
||||
} : !async.token, !async.value<f32>
|
||||
|
||||
// CHECK: return %values : !async.value<f32>
|
||||
return %values : !async.value<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @return_async_values
|
||||
func @return_async_values() -> (!async.value<f32>, !async.value<f32>) {
|
||||
%done, %values:2 = async.execute {
|
||||
%cst1 = constant 1.000000e+00 : f32
|
||||
%cst2 = constant 2.000000e+00 : f32
|
||||
async.yield %cst1, %cst2 : f32, f32
|
||||
} : !async.token, !async.value<f32>, !async.value<f32>
|
||||
|
||||
// CHECK: return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
|
||||
return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue