forked from OSchip/llvm-project
[MLIR] Add async token/value arguments to async.execute op
Async execute operation can take async arguments as dependencies. Change `async.execute` custom parser/printer format to use `%value as %unwrapped: !async.value<!type>` sytax. Reviewed By: mehdi_amini, herhut Differential Revision: https://reviews.llvm.org/D88601
This commit is contained in:
parent
dcd9be43e5
commit
4e69a52952
|
@ -14,9 +14,11 @@
|
|||
#ifndef MLIR_DIALECT_ASYNC_IR_ASYNC_H
|
||||
#define MLIR_DIALECT_ASYNC_IR_ASYNC_H
|
||||
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/Dialect.h"
|
||||
#include "mlir/IR/OpDefinition.h"
|
||||
#include "mlir/IR/OpImplementation.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Interfaces/SideEffectInterfaces.h"
|
||||
|
||||
namespace mlir {
|
||||
|
|
|
@ -56,7 +56,11 @@ class Async_ValueType<Type type>
|
|||
Type valueType = type;
|
||||
}
|
||||
|
||||
def Async_AnyValueType : Type<CPred<"$_self.isa<::mlir::async::ValueType>()">,
|
||||
"async value type">;
|
||||
def Async_AnyValueType : DialectType<AsyncDialect,
|
||||
CPred<"$_self.isa<::mlir::async::ValueType>()">,
|
||||
"async value type">;
|
||||
|
||||
def Async_AnyValueOrTokenType : AnyTypeOf<[Async_AnyValueType,
|
||||
Async_TokenType]>;
|
||||
|
||||
#endif // ASYNC_BASE_TD
|
||||
|
|
|
@ -24,7 +24,7 @@ include "mlir/Interfaces/SideEffectInterfaces.td"
|
|||
class Async_Op<string mnemonic, list<OpTrait> traits = []> :
|
||||
Op<AsyncDialect, mnemonic, traits>;
|
||||
|
||||
def Async_ExecuteOp : Async_Op<"execute"> {
|
||||
def Async_ExecuteOp : Async_Op<"execute", [AttrSizedOperandSegments]> {
|
||||
let summary = "Asynchronous execute operation";
|
||||
let description = [{
|
||||
The `body` region attached to the `async.execute` operation semantically
|
||||
|
@ -40,24 +40,43 @@ def Async_ExecuteOp : Async_Op<"execute"> {
|
|||
state). All dependencies must be made explicit with async execute arguments
|
||||
(`async.token` or `async.value`).
|
||||
|
||||
`async.execute` operation takes `async.token` dependencies and `async.value`
|
||||
operands separatly, and starts execution of the attached body region only
|
||||
when all tokens and values become ready.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%done, %values = async.execute {
|
||||
%0 = "compute0"(...) : !some.type
|
||||
async.yield %1 : f32
|
||||
} : !async.token, !async.value<!some.type>
|
||||
%dependency = ... : !async.token
|
||||
%value = ... : !async.value<f32>
|
||||
|
||||
%token, %results =
|
||||
async.execute [%dependency](%value as %unwrapped: !async.value<f32>)
|
||||
-> !async.value<!some.type>
|
||||
{
|
||||
%0 = "compute0"(%unwrapped): (f32) -> !some.type
|
||||
async.yield %0 : !some.type
|
||||
}
|
||||
|
||||
%1 = "compute1"(...) : !some.type
|
||||
```
|
||||
|
||||
In the example above asynchronous execution starts only after dependency
|
||||
token and value argument become ready. Unwrapped value passed to the
|
||||
attached body region as an %unwrapped value of f32 type.
|
||||
}];
|
||||
|
||||
// TODO: Take async.tokens/async.values as arguments.
|
||||
let arguments = (ins );
|
||||
let results = (outs Async_TokenType:$done,
|
||||
Variadic<Async_AnyValueType>:$values);
|
||||
let arguments = (ins Variadic<Async_TokenType>:$dependencies,
|
||||
Variadic<Async_AnyValueOrTokenType>:$operands);
|
||||
|
||||
let results = (outs Async_TokenType:$token,
|
||||
Variadic<Async_AnyValueType>:$results);
|
||||
let regions = (region SizedRegion<1>:$body);
|
||||
|
||||
let printer = [{ return ::mlir::async::print(p, *this); }];
|
||||
let parser = [{ return ::mlir::async::parse$cppClass(parser, result); }];
|
||||
let printer = [{ return ::print(p, *this); }];
|
||||
let parser = [{ return ::parse$cppClass(parser, result); }];
|
||||
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
def Async_YieldOp :
|
||||
|
@ -72,7 +91,7 @@ def Async_YieldOp :
|
|||
|
||||
let assemblyFormat = "attr-dict ($operands^ `:` type($operands))?";
|
||||
|
||||
let verifier = [{ return ::mlir::async::verify(*this); }];
|
||||
let verifier = [{ return ::verify(*this); }];
|
||||
}
|
||||
|
||||
#endif // ASYNC_OPS
|
||||
|
|
|
@ -8,19 +8,11 @@
|
|||
|
||||
#include "mlir/Dialect/Async/IR/Async.h"
|
||||
|
||||
#include "mlir/Dialect/StandardOps/IR/Ops.h"
|
||||
#include "mlir/Dialect/Traits.h"
|
||||
#include "mlir/IR/Builders.h"
|
||||
#include "mlir/IR/DialectImplementation.h"
|
||||
#include "mlir/IR/PatternMatch.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Transforms/InliningUtils.h"
|
||||
#include "llvm/ADT/SmallString.h"
|
||||
#include "llvm/ADT/TypeSwitch.h"
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace async {
|
||||
using namespace mlir;
|
||||
using namespace mlir::async;
|
||||
|
||||
void AsyncDialect::initialize() {
|
||||
addOperations<
|
||||
|
@ -69,6 +61,8 @@ void AsyncDialect::printType(Type type, DialectAsmPrinter &os) const {
|
|||
/// ValueType
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
namespace mlir {
|
||||
namespace async {
|
||||
namespace detail {
|
||||
|
||||
// Storage for `async.value<T>` type, the only member is the wrapped type.
|
||||
|
@ -90,6 +84,8 @@ struct ValueTypeStorage : public TypeStorage {
|
|||
};
|
||||
|
||||
} // namespace detail
|
||||
} // namespace async
|
||||
} // namespace mlir
|
||||
|
||||
ValueType ValueType::get(Type valueType) {
|
||||
return Base::get(valueType.getContext(), valueType);
|
||||
|
@ -105,7 +101,7 @@ static LogicalResult verify(YieldOp op) {
|
|||
// Get the underlying value types from async values returned from the
|
||||
// parent `async.execute` operation.
|
||||
auto executeOp = op.getParentOfType<ExecuteOp>();
|
||||
auto types = llvm::map_range(executeOp.values(), [](const OpResult &result) {
|
||||
auto types = llvm::map_range(executeOp.results(), [](const OpResult &result) {
|
||||
return result.getType().cast<ValueType>().getValueType();
|
||||
});
|
||||
|
||||
|
@ -120,49 +116,139 @@ static LogicalResult verify(YieldOp op) {
|
|||
/// ExecuteOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
constexpr char kOperandSegmentSizesAttr[] = "operand_segment_sizes";
|
||||
|
||||
static void print(OpAsmPrinter &p, ExecuteOp op) {
|
||||
p << "async.execute ";
|
||||
p.printRegion(op.body());
|
||||
p.printOptionalAttrDict(op.getAttrs());
|
||||
p << " : ";
|
||||
p.printType(op.done().getType());
|
||||
if (!op.values().empty())
|
||||
p << ", ";
|
||||
llvm::interleaveComma(op.values(), p, [&](const OpResult &result) {
|
||||
p.printType(result.getType());
|
||||
});
|
||||
p << op.getOperationName();
|
||||
|
||||
// [%tokens,...]
|
||||
if (!op.dependencies().empty())
|
||||
p << " [" << op.dependencies() << "]";
|
||||
|
||||
// (%value as %unwrapped: !async.value<!arg.type>, ...)
|
||||
if (!op.operands().empty()) {
|
||||
p << " (";
|
||||
llvm::interleaveComma(op.operands(), p, [&, n = 0](Value operand) mutable {
|
||||
p << operand << " as " << op.body().front().getArgument(n++) << ": "
|
||||
<< operand.getType();
|
||||
});
|
||||
p << ")";
|
||||
}
|
||||
|
||||
// -> (!async.value<!return.type>, ...)
|
||||
p.printOptionalArrowTypeList(op.getResultTypes().drop_front(1));
|
||||
p.printOptionalAttrDictWithKeyword(op.getAttrs(), {kOperandSegmentSizesAttr});
|
||||
p.printRegion(op.body(), /*printEntryBlockArgs=*/false);
|
||||
}
|
||||
|
||||
static ParseResult parseExecuteOp(OpAsmParser &parser, OperationState &result) {
|
||||
MLIRContext *ctx = result.getContext();
|
||||
|
||||
// Parse asynchronous region.
|
||||
Region *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{},
|
||||
/*enableNameShadowing=*/false))
|
||||
// Sizes of parsed variadic operands, will be updated below after parsing.
|
||||
int32_t numDependencies = 0;
|
||||
int32_t numOperands = 0;
|
||||
|
||||
auto tokenTy = TokenType::get(ctx);
|
||||
|
||||
// Parse dependency tokens.
|
||||
if (succeeded(parser.parseOptionalLSquare())) {
|
||||
SmallVector<OpAsmParser::OperandType, 4> tokenArgs;
|
||||
if (parser.parseOperandList(tokenArgs) ||
|
||||
parser.resolveOperands(tokenArgs, tokenTy, result.operands) ||
|
||||
parser.parseRSquare())
|
||||
return failure();
|
||||
|
||||
numDependencies = tokenArgs.size();
|
||||
}
|
||||
|
||||
// Parse async value operands (%value as %unwrapped : !async.value<!type>).
|
||||
SmallVector<OpAsmParser::OperandType, 4> valueArgs;
|
||||
SmallVector<OpAsmParser::OperandType, 4> unwrappedArgs;
|
||||
SmallVector<Type, 4> valueTypes;
|
||||
SmallVector<Type, 4> unwrappedTypes;
|
||||
|
||||
if (succeeded(parser.parseOptionalLParen())) {
|
||||
auto argsLoc = parser.getCurrentLocation();
|
||||
|
||||
// Parse a single instance of `%value as %unwrapped : !async.value<!type>`.
|
||||
auto parseAsyncValueArg = [&]() -> ParseResult {
|
||||
if (parser.parseOperand(valueArgs.emplace_back()) ||
|
||||
parser.parseKeyword("as") ||
|
||||
parser.parseOperand(unwrappedArgs.emplace_back()) ||
|
||||
parser.parseColonType(valueTypes.emplace_back()))
|
||||
return failure();
|
||||
|
||||
auto valueTy = valueTypes.back().dyn_cast<ValueType>();
|
||||
unwrappedTypes.emplace_back(valueTy ? valueTy.getValueType() : Type());
|
||||
|
||||
return success();
|
||||
};
|
||||
|
||||
// If the next token is `)` skip async value arguments parsing.
|
||||
if (failed(parser.parseOptionalRParen())) {
|
||||
do {
|
||||
if (parseAsyncValueArg())
|
||||
return failure();
|
||||
} while (succeeded(parser.parseOptionalComma()));
|
||||
|
||||
if (parser.parseRParen() ||
|
||||
parser.resolveOperands(valueArgs, valueTypes, argsLoc,
|
||||
result.operands))
|
||||
return failure();
|
||||
}
|
||||
|
||||
numOperands = valueArgs.size();
|
||||
}
|
||||
|
||||
// Add derived `operand_segment_sizes` attribute based on parsed operands.
|
||||
auto operandSegmentSizes = DenseIntElementsAttr::get(
|
||||
VectorType::get({2}, parser.getBuilder().getI32Type()),
|
||||
{numDependencies, numOperands});
|
||||
result.addAttribute(kOperandSegmentSizesAttr, operandSegmentSizes);
|
||||
|
||||
// Parse the types of results returned from the async execute op.
|
||||
SmallVector<Type, 4> resultTypes;
|
||||
if (parser.parseOptionalArrowTypeList(resultTypes))
|
||||
return failure();
|
||||
|
||||
// Async execute first result is always a completion token.
|
||||
parser.addTypeToList(tokenTy, result.types);
|
||||
parser.addTypesToList(resultTypes, result.types);
|
||||
|
||||
// Parse operation attributes.
|
||||
NamedAttrList attrs;
|
||||
if (parser.parseOptionalAttrDict(attrs))
|
||||
if (parser.parseOptionalAttrDictWithKeyword(attrs))
|
||||
return failure();
|
||||
result.addAttributes(attrs);
|
||||
|
||||
// Parse result types.
|
||||
SmallVector<Type, 4> resultTypes;
|
||||
if (parser.parseColonTypeList(resultTypes))
|
||||
// Parse asynchronous region.
|
||||
Region *body = result.addRegion();
|
||||
if (parser.parseRegion(*body, /*arguments=*/{unwrappedArgs},
|
||||
/*argTypes=*/{unwrappedTypes},
|
||||
/*enableNameShadowing=*/false))
|
||||
return failure();
|
||||
|
||||
// First result type must be an async token type.
|
||||
if (resultTypes.empty() || resultTypes.front() != TokenType::get(ctx))
|
||||
return failure();
|
||||
parser.addTypesToList(resultTypes, result.types);
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
} // namespace async
|
||||
} // namespace mlir
|
||||
static LogicalResult verify(ExecuteOp op) {
|
||||
// Unwrap async.execute value operands types.
|
||||
auto unwrappedTypes = llvm::map_range(op.operands(), [](Value operand) {
|
||||
return operand.getType().cast<ValueType>().getValueType();
|
||||
});
|
||||
|
||||
// Verify that unwrapped argument types matches the body region arguments.
|
||||
if (llvm::size(unwrappedTypes) != llvm::size(op.body().getArgumentTypes()))
|
||||
return op.emitOpError("the number of async body region arguments does not "
|
||||
"match the number of execute operation arguments");
|
||||
|
||||
if (!std::equal(unwrappedTypes.begin(), unwrappedTypes.end(),
|
||||
op.body().getArgumentTypes().begin()))
|
||||
return op.emitOpError("async body region argument types do not match the "
|
||||
"execute operation arguments types");
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Async/IR/AsyncOps.cpp.inc"
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
// RUN: mlir-opt %s | FileCheck %s
|
||||
// RUN: mlir-opt %s | FileCheck %s
|
||||
|
||||
// CHECK-LABEL: @identity_token
|
||||
func @identity_token(%arg0 : !async.token) -> !async.token {
|
||||
func @identity_token(%arg0: !async.token) -> !async.token {
|
||||
// CHECK: return %arg0 : !async.token
|
||||
return %arg0 : !async.token
|
||||
}
|
||||
|
@ -14,33 +14,95 @@ func @identity_value(%arg0 : !async.value<f32>) -> !async.value<f32> {
|
|||
|
||||
// CHECK-LABEL: @empty_async_execute
|
||||
func @empty_async_execute() -> !async.token {
|
||||
%done = async.execute {
|
||||
// CHECK: async.execute
|
||||
%token = async.execute {
|
||||
async.yield
|
||||
} : !async.token
|
||||
}
|
||||
|
||||
// CHECK: return %done : !async.token
|
||||
return %done : !async.token
|
||||
// CHECK: return %token : !async.token
|
||||
return %token : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @return_async_value
|
||||
func @return_async_value() -> !async.value<f32> {
|
||||
%done, %values = async.execute {
|
||||
// CHECK: async.execute -> !async.value<f32>
|
||||
%token, %results = async.execute -> !async.value<f32> {
|
||||
%cst = constant 1.000000e+00 : f32
|
||||
async.yield %cst : f32
|
||||
} : !async.token, !async.value<f32>
|
||||
}
|
||||
|
||||
// CHECK: return %values : !async.value<f32>
|
||||
return %values : !async.value<f32>
|
||||
// CHECK: return %results : !async.value<f32>
|
||||
return %results : !async.value<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @return_captured_value
|
||||
func @return_captured_value() -> !async.token {
|
||||
%cst = constant 1.000000e+00 : f32
|
||||
// CHECK: async.execute -> !async.value<f32>
|
||||
%token, %results = async.execute -> !async.value<f32> {
|
||||
async.yield %cst : f32
|
||||
}
|
||||
|
||||
// CHECK: return %token : !async.token
|
||||
return %token : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @return_async_values
|
||||
func @return_async_values() -> (!async.value<f32>, !async.value<f32>) {
|
||||
%done, %values:2 = async.execute {
|
||||
%token, %results:2 = async.execute -> (!async.value<f32>, !async.value<f32>) {
|
||||
%cst1 = constant 1.000000e+00 : f32
|
||||
%cst2 = constant 2.000000e+00 : f32
|
||||
async.yield %cst1, %cst2 : f32, f32
|
||||
} : !async.token, !async.value<f32>, !async.value<f32>
|
||||
}
|
||||
|
||||
// CHECK: return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
|
||||
return %values#0, %values#1 : !async.value<f32>, !async.value<f32>
|
||||
// CHECK: return %results#0, %results#1 : !async.value<f32>, !async.value<f32>
|
||||
return %results#0, %results#1 : !async.value<f32>, !async.value<f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @async_token_dependencies
|
||||
func @async_token_dependencies(%arg0: !async.token) -> !async.token {
|
||||
// CHECK: async.execute [%arg0]
|
||||
%token = async.execute [%arg0] {
|
||||
async.yield
|
||||
}
|
||||
|
||||
// CHECK: return %token : !async.token
|
||||
return %token : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @async_value_operands
|
||||
func @async_value_operands(%arg0: !async.value<f32>) -> !async.token {
|
||||
// CHECK: async.execute (%arg0 as %arg1: !async.value<f32>) -> !async.value<f32>
|
||||
%token, %results = async.execute (%arg0 as %arg1: !async.value<f32>) -> !async.value<f32> {
|
||||
async.yield %arg1 : f32
|
||||
}
|
||||
|
||||
// CHECK: return %token : !async.token
|
||||
return %token : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @async_token_and_value_operands
|
||||
func @async_token_and_value_operands(%arg0: !async.token, %arg1: !async.value<f32>) -> !async.token {
|
||||
// CHECK: async.execute [%arg0] (%arg1 as %arg2: !async.value<f32>) -> !async.value<f32>
|
||||
%token, %results = async.execute [%arg0] (%arg1 as %arg2: !async.value<f32>) -> !async.value<f32> {
|
||||
async.yield %arg2 : f32
|
||||
}
|
||||
|
||||
// CHECK: return %token : !async.token
|
||||
return %token : !async.token
|
||||
}
|
||||
|
||||
// CHECK-LABEL: @empty_tokens_or_values_operands
|
||||
func @empty_tokens_or_values_operands() {
|
||||
// CHECK: async.execute {
|
||||
%token0 = async.execute [] () -> () { async.yield }
|
||||
// CHECK: async.execute {
|
||||
%token1 = async.execute () -> () { async.yield }
|
||||
// CHECK: async.execute {
|
||||
%token2 = async.execute -> () { async.yield }
|
||||
// CHECK: async.execute {
|
||||
%token3 = async.execute () { async.yield }
|
||||
// CHECK: async.execute {
|
||||
%token4 = async.execute [] { async.yield }
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue