[mlir] Split MLProgram global load and store to Graph variants

* Split ops into X_graph variants as discussed;
* Remove tokens from the non-Graph region variants and rely on side-effect
  modelling there, while removing side-effect modelling from the Graph
  variants and relying on explicit ordering there;
* Make the Graph variants required to produce a token, but keep the
  explicit token type specification given the previous discussion about
  this potentially being made configurable in the future;

This results in some duplicated code. I considered adding helper
functions but decided against introducing an abstraction this early,
given the size of the duplication and the risk of creating accidental
coupling.
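
For reference, a rough sketch of the intended usage after this change; the
op syntax is taken from the diff below, while `@foobar` and the surrounding
subgraph/function names are purely illustrative:

```mlir
ml_program.global private mutable @foobar : tensor<?xi32>

// Graph region: ordering is expressed explicitly through required tokens.
ml_program.subgraph @graph_usage() -> (tensor<?xi32>, !ml_program.token) {
  %0, %t0 = ml_program.global_load_graph @foobar
    ordering(() -> !ml_program.token) : tensor<?xi32>
  %t1 = ml_program.global_store_graph @foobar = %0
    ordering(%t0 -> !ml_program.token) : tensor<?xi32>
  ml_program.output %0, %t1 : tensor<?xi32>, !ml_program.token
}

// Non-graph region: no tokens; ordering relies on side-effect modelling.
ml_program.func @func_usage() {
  %0 = ml_program.global_load @foobar : tensor<?xi32>
  ml_program.global_store @foobar = %0 : tensor<?xi32>
  ml_program.return
}
```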

Differential Revision: https://reviews.llvm.org/D127813
Author: Jacques Pienaar
Date:   2022-06-16 20:01:54 -07:00
Commit: d30c0221cf (parent f2bcf33058)

4 changed files with 174 additions and 19 deletions

@@ -171,7 +171,8 @@ def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
advanced cases.
This op is side effecting and may not be valid to use in graph regions
without additional consideration to evaluation order constraints.
without additional consideration to evaluation order constraints. See
`global_load_graph` for an op which allows for explicit ordering constraints.
Example:
@@ -181,16 +182,14 @@ def MLProgram_GlobalLoadOp : MLProgram_Op<"global_load", [
}];
let arguments = (ins
Arg<SymbolRefAttr, "", [MemRead]>:$global,
Variadic<MLProgram_TokenType>:$consumeTokens
Arg<SymbolRefAttr, "", [MemRead]>:$global
);
let results = (outs
AnyType:$result,
Optional<MLProgram_TokenType>:$produceToken
AnyType:$result
);
let assemblyFormat = [{
$global `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($result) attr-dict
$global `:` type($result) attr-dict
}];
let extraClassDeclaration = [{
@@ -238,6 +237,52 @@ def MLProgram_GlobalLoadConstOp : MLProgram_Op<"global_load_const", [
}];
}
//===----------------------------------------------------------------------===//
// GlobalLoadGraphOp
//===----------------------------------------------------------------------===//
def MLProgram_GlobalLoadGraphOp : MLProgram_Op<"global_load_graph", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Direct load of a mutable value from a global in Graph region";
let description = [{
Performs a non-atomic, non-volatile, non-synchronized load from a global
that may be mutable.
It is fully expected that these constraints are not suitable for all
situations, and alternative ops should be defined and used for more advanced
cases.
This op is side effecting and may not be valid to use in graph regions
without additional consideration to evaluation order constraints.
Example:
```mlir
%0, %cstr = ml_program.global_load_graph @foobar
ordering (%token -> !ml_program.token) : tensor<?xi32>
```
}];
let arguments = (ins
Arg<SymbolRefAttr, "", [MemRead]>:$global,
Variadic<MLProgram_TokenType>:$consumeTokens
);
let results = (outs
AnyType:$result,
MLProgram_TokenType:$produceToken
);
let assemblyFormat = [{
$global `` custom<TokenOrdering>($consumeTokens, type($produceToken)) `:` type($result) attr-dict
}];
let extraClassDeclaration = [{
/// Gets the corresponding GlobalOp (or nullptr).
GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
}];
}
//===----------------------------------------------------------------------===//
// GlobalStoreOp
//===----------------------------------------------------------------------===//
@@ -254,23 +299,66 @@ def MLProgram_GlobalStoreOp : MLProgram_Op<"global_store", [
all situations, and alternative ops should be defined and used for more
advanced cases.
This op is side effecting and may not be valid to use in graph regions
without additional consideration to evaluation order constraints. See
`global_store_graph` for an op which allows for explicit ordering constraints.
Example:
```mlir
ml_program.global_store @foobar = %0 : tensor<?xi32>
```
}];
let arguments = (ins
Arg<SymbolRefAttr, "", [MemWrite]>:$global,
AnyType:$value
);
let assemblyFormat = [{
$global `=` $value `:` type($value) attr-dict
}];
let extraClassDeclaration = [{
/// Gets the corresponding GlobalOp (or nullptr).
GlobalOp getGlobalOp(SymbolTableCollection &symbolTable);
}];
}
//===----------------------------------------------------------------------===//
// GlobalStoreGraphOp
//===----------------------------------------------------------------------===//
def MLProgram_GlobalStoreGraphOp : MLProgram_Op<"global_store_graph", [
DeclareOpInterfaceMethods<SymbolUserOpInterface>
]> {
let summary = "Direct store of a value into a mutable global";
let description = [{
Performs a non-atomic, non-volatile, non-synchronized store to a mutable
global.
It is fully expected that these constraints are not suitable for
all situations, and alternative ops should be defined and used for more
advanced cases.
This op is side effecting and may not be valid to use in graph regions
without additional consideration to evaluation order constraints.
Example:
```mlir
ml_program.global_store @foobar = %0 : tensor<?xi32>
%token = ml_program.global_store_graph @foobar = %0 : tensor<?xi32>
ordering (%in_token -> !ml_program.token) : tensor<?xi32>
```
}];
let arguments = (ins
Arg<SymbolRefAttr, "", [MemWrite]>:$global,
Arg<SymbolRefAttr, "", [MemRead]>:$global,
AnyType:$value,
Variadic<MLProgram_TokenType>:$consumeTokens
);
let results = (outs
Optional<MLProgram_TokenType>:$produceToken
MLProgram_TokenType:$produceToken
);
let assemblyFormat = [{

@@ -18,12 +18,11 @@ using namespace mlir::ml_program;
//===----------------------------------------------------------------------===//
/// Parse and print an ordering clause for a variadic of consuming tokens
/// and an optional producing token.
/// and a producing token.
///
/// Syntax:
/// ordering(%0, %1 -> !ml_program.token)
/// ordering(() -> !ml_program.token)
/// ordering(%0, %1)
///
/// If both the consuming and producing token are not present on the op, then
/// the clause prints nothing.
@@ -46,10 +45,11 @@ static ParseResult parseTokenOrdering(
return failure();
}
// Parse optional producer token.
if (succeeded(parser.parseOptionalArrow()))
if (failed(parser.parseType(produceTokenType)))
return failure();
// Parse producer token.
if (failed(parser.parseArrow()))
return failure();
if (failed(parser.parseType(produceTokenType)))
return failure();
if (failed(parser.parseRParen()))
return failure();
@@ -220,6 +220,30 @@ GlobalLoadConstOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
//===----------------------------------------------------------------------===//
// GlobalLoadGraphOp
//===----------------------------------------------------------------------===//
GlobalOp GlobalLoadGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
getOperation()->getParentOp(), getGlobalAttr());
}
LogicalResult
GlobalLoadGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
GlobalOp referrent = getGlobalOp(symbolTable);
if (!referrent)
return emitOpError() << "undefined global: " << getGlobal();
if (referrent.getType() != getResult().getType()) {
return emitOpError() << "cannot load from global typed "
<< referrent.getType() << " as "
<< getResult().getType();
}
return success();
}
//===----------------------------------------------------------------------===//
// GlobalStoreOp
//===----------------------------------------------------------------------===//
@@ -249,6 +273,35 @@ GlobalStoreOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
return success();
}
//===----------------------------------------------------------------------===//
// GlobalStoreGraphOp
//===----------------------------------------------------------------------===//
GlobalOp GlobalStoreGraphOp::getGlobalOp(SymbolTableCollection &symbolTable) {
return symbolTable.lookupNearestSymbolFrom<GlobalOp>(
getOperation()->getParentOp(), getGlobalAttr());
}
LogicalResult
GlobalStoreGraphOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
GlobalOp referrent = getGlobalOp(symbolTable);
if (!referrent)
return emitOpError() << "undefined global: " << getGlobal();
if (!referrent.getIsMutable()) {
return emitOpError() << "cannot store to an immutable global "
<< getGlobal();
}
if (referrent.getType() != getValue().getType()) {
return emitOpError() << "cannot store to a global typed "
<< referrent.getType() << " from "
<< getValue().getType();
}
return success();
}
//===----------------------------------------------------------------------===//
// SubgraphOp
//===----------------------------------------------------------------------===//
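
To illustrate the new GlobalLoadGraphOp verifier above, a hypothetical module
it would reject; the global and subgraph names are made up, only the
diagnostic wording comes from the code:

```mlir
ml_program.global private mutable @g : tensor<4xi32>

ml_program.subgraph @bad_load() -> (tensor<?xi32>, !ml_program.token) {
  // Rejected by GlobalLoadGraphOp::verifySymbolUses: the result type
  // tensor<?xi32> is not equal to the global's type tensor<4xi32>, so it
  // emits "cannot load from global typed ... as ...".
  %0, %t = ml_program.global_load_graph @g
    ordering(() -> !ml_program.token) : tensor<?xi32>
  ml_program.output %0, %t : tensor<?xi32>, !ml_program.token
}
```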

@@ -96,3 +96,17 @@ ml_program.func @store_immutable(%arg0: i64) {
ml_program.global_store @var = %arg0 : i64
ml_program.return
}
// -----
ml_program.global private mutable @global_mutable_undef : tensor<?xi32>
ml_program.subgraph @global_load_store_tokens() -> (tensor<?xi32>, !ml_program.token) {
%token1 = ml_program.token
%0, %token2 = ml_program.global_load_graph @global_mutable_undef
ordering(() -> !ml_program.token) : tensor<?xi32>
%token3 = ml_program.global_store_graph @global_mutable_undef = %0
// expected-error @+1 {{expected '->'}}
ordering(%token1, %token2) : tensor<?xi32>
ml_program.output %0, %token3 : tensor<?xi32>, !ml_program.token
}
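
Along the same lines as the existing `@store_immutable` case above, a sketch
of an input that the new GlobalStoreGraphOp verifier would reject (the global
and subgraph names are illustrative, not taken from the test file):

```mlir
ml_program.global private @frozen(dense<0> : tensor<4xi32>) : tensor<4xi32>

ml_program.subgraph @bad_store(%arg0: tensor<4xi32>) -> !ml_program.token {
  // Rejected: @frozen is not declared `mutable`, so verifySymbolUses emits
  // "cannot store to an immutable global".
  %t = ml_program.global_store_graph @frozen = %arg0
    ordering(() -> !ml_program.token) : tensor<4xi32>
  ml_program.output %t : !ml_program.token
}
```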

@@ -45,12 +45,12 @@ ml_program.func @global_load_store() {
// CHECK-LABEL: @global_load_store_tokens
ml_program.subgraph @global_load_store_tokens() -> (tensor<?xi32>, !ml_program.token) {
%token1 = ml_program.token
%0, %token2 = ml_program.global_load @global_mutable_undef
%0, %token2 = ml_program.global_load_graph @global_mutable_undef
ordering(() -> !ml_program.token) : tensor<?xi32>
%token3 = ml_program.global_store @global_mutable_undef = %0
%token3 = ml_program.global_store_graph @global_mutable_undef = %0
ordering(%token1, %token2 -> !ml_program.token) : tensor<?xi32>
ml_program.global_store @global_mutable_undef = %0
ordering(%token3) : tensor<?xi32>
%token4 = ml_program.global_store_graph @global_mutable_undef = %0
ordering(%token3 -> !ml_program.token) : tensor<?xi32>
ml_program.output %0, %token3 : tensor<?xi32>, !ml_program.token
}