forked from OSchip/llvm-project
[mlir] Split MLProgram global load and store to Graph variants
* Split ops into X_graph variants as discussed; * Remove tokens from non-Graph region variants and rely on side-effect modelling there while removing side-effect modelling from Graph variants and relying on explicit ordering there; * Make tokens required to be produced by Graph variants - but kept explicit token type specification given previous discussion on this potentially being configurable in future; This results in duplicating some code. I considered adding helper functions but decided against adding an abstraction there early given size of duplication and creating accidental coupling. Differential Revision: https://reviews.llvm.org/D127813
This commit is contained in:
parent
f2bcf33058
commit
d30c0221cf
|
@ -171,7 +171,8 @@ def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
|
|||
advanced cases.
|
||||
|
||||
This op is side effecting and may not be valid to use in graph regions
|
||||
without additional consideration to evaluation order constraints.
|
||||
without additional consideration to evaluation order constraints. See
|
||||
`global_load_graph` for op which allows for explicit ordering constraints.
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -181,16 +182,14 @@ def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
|
|||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<SymbolRefAttr, "", [MemRead]>:$global,
|
||||
Variadic<MLProgram_TokenType>:$consumeTokens
|
||||
Arg<SymbolRefAttr, "", [MemRead]>:$global
|
||||
);
|
||||
let results = (outs
|
||||
AnyType:$result,
|
||||
Optional<MLProgram_TokenType>:$produceToken
|
||||
AnyType:$result
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$global `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($result) attr-dict
|
||||
$global `:` type($result) attr-dict
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
|
@ -238,6 +237,52 @@ def MLProgram_GlobalLoadConstOp : MLProgram_Op<"global_load_const", [
|
|||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalLoadGraphOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MLProgram_GlobalLoadGraphOp : MLProgram_Op<"global_load_graph", [
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>
|
||||
]> {
|
||||
let summary = "Direct load of a mutable value from a global in Graph region";
|
||||
let description = [{
|
||||
Performs a non-atomic, non-volatile, non-synchronized load from a global
|
||||
that may be mutable.
|
||||
|
||||
It is fully expected that these constraints are not suitable for all
|
||||
situations, and alternative ops should be defined and used for more advanced
|
||||
cases.
|
||||
|
||||
This op is side effecting and may not be valid to use in graph regions
|
||||
without additional consideration to evaluation order constraints.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
%0, %cstr = ml_program.global_load_graph @foobar
|
||||
ordering (%token -> !ml_program.token) : tensor<?xi32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<SymbolRefAttr, "", [MemRead]>:$global,
|
||||
Variadic<MLProgram_TokenType>:$consumeTokens
|
||||
);
|
||||
let results = (outs
|
||||
AnyType:$result,
|
||||
MLProgram_TokenType:$produceToken
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$global `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($result) attr-dict
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Gets the corresponding GlobalOp (or nullptr).
|
||||
GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalStoreOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -254,23 +299,66 @@ def MLProgram_GlobalStoreOp : MLProgram_Op<"global_store", [
|
|||
all situations, and alternative ops should be defined and used for more
|
||||
advanced cases.
|
||||
|
||||
This op is side effecting and may not be valid to use in graph regions
|
||||
without additional consideration to evaluation order constraints. See
|
||||
`global_store_graph` for op which allows for explicit ordering constraints.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
ml_program.global_store @foobar = %0 : tensor<?xi32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<SymbolRefAttr, "", [MemWrite]>:$global,
|
||||
AnyType:$value
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$global `=` $value `:` type($value) attr-dict
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
/// Gets the corresponding GlobalOp (or nullptr).
|
||||
GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalStoreGraphOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
def MLProgram_GlobalStoreGraphOp : MLProgram_Op<"global_store_graph", [
|
||||
DeclareOpInterfaceMethods<SymbolUserOpInterface>
|
||||
]> {
|
||||
let summary = "Direct store of a value into a mutable global";
|
||||
let description = [{
|
||||
Performs a non-atomic, non-volatile, non-synchronized store to a mutable
|
||||
global.
|
||||
|
||||
It is fully expected that these constraints are not suitable for
|
||||
all situations, and alternative ops should be defined and used for more
|
||||
advanced cases.
|
||||
|
||||
This op is side effecting and may not be valid to use in graph regions
|
||||
without additional consideration to evaluation order constraints.
|
||||
|
||||
Example:
|
||||
|
||||
```mlir
|
||||
ml_program.global_store @foobar = %0 : tensor<?xi32>
|
||||
%token = ml_program.global_store @foobar = %0 : tensor<?xi32>
|
||||
ordering (%in_token -> !ml_program.token) : tensor<?xi32>
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins
|
||||
Arg<SymbolRefAttr, "", [MemWrite]>:$global,
|
||||
Arg<SymbolRefAttr, "", [MemRead]>:$global,
|
||||
AnyType:$value,
|
||||
Variadic<MLProgram_TokenType>:$consumeTokens
|
||||
);
|
||||
let results = (outs
|
||||
Optional<MLProgram_TokenType>:$produceToken
|
||||
MLProgram_TokenType:$produceToken
|
||||
);
|
||||
|
||||
let assemblyFormat = [{
|
||||
|
|
|
@ -18,12 +18,11 @@ using namespace mlir::ml_program;
|
|||
//===----------------------------------------------------------------------===//
|
||||
|
||||
/// Parse and print an ordering clause for a variadic of consuming tokens
|
||||
/// and an optional producing token.
|
||||
/// and an producing token.
|
||||
///
|
||||
/// Syntax:
|
||||
/// ordering(%0, %1 -> !ml_program.token)
|
||||
/// ordering(() -> !ml_program.token)
|
||||
/// ordering(%0, %1)
|
||||
///
|
||||
/// If both the consuming and producing token are not present on the op, then
|
||||
/// the clause prints nothing.
|
||||
|
@ -46,10 +45,11 @@ static ParseResult parseTokenOrdering(
|
|||
return failure();
|
||||
}
|
||||
|
||||
// Parse optional producer token.
|
||||
if (succeeded(parser.parseOptionalArrow()))
|
||||
if (failed(parser.parseType(produceTokenType)))
|
||||
return failure();
|
||||
// Parse producer token.
|
||||
if (failed(parser.parseArrow()))
|
||||
return failure();
|
||||
if (failed(parser.parseType(produceTokenType)))
|
||||
return failure();
|
||||
|
||||
if (failed(parser.parseRParen()))
|
||||
return failure();
|
||||
|
@ -220,6 +220,30 @@ GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalLoadGraphOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
|
||||
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
|
||||
getOperation()->getParentOp(), getGlobalAttr());
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
GlobalOp referrent = getGlobalOp(symbolTable);
|
||||
if (!referrent)
|
||||
return emitOpError() << "undefined global: " << getGlobal();
|
||||
|
||||
if (referrent.getType() != getResult().getType()) {
|
||||
return emitOpError() << "cannot load from global typed "
|
||||
<< referrent.getType() << " as "
|
||||
<< getResult().getType();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalStoreOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -249,6 +273,35 @@ GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
|||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// GlobalStoreGraphOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
|
||||
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
|
||||
getOperation()->getParentOp(), getGlobalAttr());
|
||||
}
|
||||
|
||||
LogicalResult
|
||||
GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
|
||||
GlobalOp referrent = getGlobalOp(symbolTable);
|
||||
if (!referrent)
|
||||
return emitOpError() << "undefined global: " << getGlobal();
|
||||
|
||||
if (!referrent.getIsMutable()) {
|
||||
return emitOpError() << "cannot store to an immutable global "
|
||||
<< getGlobal();
|
||||
}
|
||||
|
||||
if (referrent.getType() != getValue().getType()) {
|
||||
return emitOpError() << "cannot store to a global typed "
|
||||
<< referrent.getType() << " from "
|
||||
<< getValue().getType();
|
||||
}
|
||||
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SubgraphOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -96,3 +96,17 @@ ml_program.func @store_immutable(%arg0: i64) {
|
|||
ml_program.global_store @var = %arg0 : i64
|
||||
ml_program.return
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
ml_program.global private mutable @global_mutable_undef : tensor<?xi32>
|
||||
ml_program.subgraph @global_load_store_tokens() -> (tensor<?xi32>, !ml_program.token) {
|
||||
%token1 = ml_program.token
|
||||
%0, %token2 = ml_program.global_load_graph @global_mutable_undef
|
||||
ordering(() -> !ml_program.token) : tensor<?xi32>
|
||||
%token3 = ml_program.global_store_graph @global_mutable_undef = %0
|
||||
// expected-error @+1 {{expected '->'}}
|
||||
ordering(%token1, %token2) : tensor<?xi32>
|
||||
|
||||
ml_program.output %0, %token3 : tensor<?xi32>, !ml_program.token
|
||||
}
|
||||
|
|
|
@ -45,12 +45,12 @@ ml_program.func @global_load_store() {
|
|||
// CHECK-LABEL: @global_load_store_tokens
|
||||
ml_program.subgraph @global_load_store_tokens() -> (tensor<?xi32>, !ml_program.token) {
|
||||
%token1 = ml_program.token
|
||||
%0, %token2 = ml_program.global_load @global_mutable_undef
|
||||
%0, %token2 = ml_program.global_load_graph @global_mutable_undef
|
||||
ordering(() -> !ml_program.token) : tensor<?xi32>
|
||||
%token3 = ml_program.global_store @global_mutable_undef = %0
|
||||
%token3 = ml_program.global_store_graph @global_mutable_undef = %0
|
||||
ordering(%token1, %token2 -> !ml_program.token) : tensor<?xi32>
|
||||
ml_program.global_store @global_mutable_undef = %0
|
||||
ordering(%token3) : tensor<?xi32>
|
||||
%token4 = ml_program.global_store_graph @global_mutable_undef = %0
|
||||
ordering(%token3 -> !ml_program.token) : tensor<?xi32>
|
||||
|
||||
ml_program.output %0, %token3 : tensor<?xi32>, !ml_program.token
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue