[mlir] Improve documentation of shape dialect

Add a small example of usage (brief; it will be refined further).

# 'shape' Dialect

Description of operations & types within the Shape dialect as well as their
[usage](#different-stages-of-lowering-shape-dialect).

[include "Dialects/ShapeDialect.md"]
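
As a small taste of the dialect before looking at lowering, below is a minimal
sketch of querying a shape: the function `@rank_of` is a hypothetical example
(not part of the dialect) that computes the rank of an arbitrarily shaped
tensor using `shape.shape_of` and `shape.rank`.

```mlir
// Hypothetical example function: compute the rank of an unranked tensor.
func @rank_of(%arg : tensor<*xf32>) -> !shape.size {
  // Materialize the (runtime) shape of %arg as a !shape.shape value.
  %shape = shape.shape_of %arg : tensor<*xf32> -> !shape.shape
  // Count the dimensions of that shape.
  %rank = shape.rank %shape : !shape.shape -> !shape.size
  return %rank : !shape.size
}
```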

## Different stages of lowering Shape dialect

In this section we give a brief overview of the different uses of the shape
dialect and the lowering between these uses. Currently there are three worlds /
stages of lowering for shape functions:

1.  _Error monadic/error carrying/user specification_:
    This "input" form carries both the shape and the error state as a value.
    Hence at this level all operations are pure operations producing and
    consuming values, where a value may represent an error.
2.  _Constrained_:
    This form uses a variant of explicit evidence passing to allow leveraging
    existing compiler infrastructure to preserve safety information during
    optimization.
3.  _Side-effecting/asserting_:
    This final lowered form is an imperative form with side-effecting ops
    (e.g., assert) for final codegen.
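
As a compact sketch before the walkthrough, here is how the same equality
constraint might surface in each of the three forms, using only ops that appear
in the matmul example below; `%a`, `%b`, `%inner`, and the error messages are
hypothetical placeholders.

```mlir
// Form 1 (error monadic): `meet` produces a value that may itself carry an
// error; all consumers remain pure.
%r = shape.meet %a, %b, error="mismatch" : !shape.size, !shape.size -> !shape.size

// Form 2 (constrained): an explicit witness guards the computation that
// relies on the constraint.
%w = shape.cstr_eq %a, %b : !shape.witness
%res = shape.assuming %w -> tensor<?xindex> {
  // ... computation that may assume %a equals %b, producing %inner ...
  shape.assuming_yield %inner
}

// Form 3 (side-effecting): the check becomes an imperative assert.
%eq = shape.eq %a, %b
assert %eq, "mismatch"
```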

We are going to do a quick step-through of the lowering using the example of a
matmul.

Starting from the shape function of matmul in the error monadic form
below[^wip_form1]:

```mlir
shape.function_library @shplib {

builtin.func @matmul(%lhs: !shape.value_shape, %rhs: !shape.value_shape) -> !shape.shape {
  %c1 = shape.const_size 1
  %c2 = shape.const_size 2
  // We could also allow rank etc. operations directly on value_shape too; that
  // would make it nicer as an "input" language, but we keep it explicit inside
  // the IR and could instead have helper methods in the front-end language.
  %lhs_shape = shape.shape_of %lhs : !shape.value_shape -> !shape.shape
  %rhs_shape = shape.shape_of %rhs : !shape.value_shape -> !shape.shape
  %lhs_rank = shape.rank %lhs_shape : !shape.shape -> !shape.size
  %rhs_rank = shape.rank %rhs_shape : !shape.shape -> !shape.size
  // This is not minimal, as one could ensure the ranks are the same below; a
  // variadic meet would also make it more concise.
  %r = "shape.meet"(%lhs_rank, %rhs_rank) : (!shape.size, !shape.size) -> !shape.size
  %rank = shape.meet %c2, %r, error="requires rank 2 operands" :
    !shape.size, !shape.size -> !shape.size
  %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
    (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
  %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
    (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
  %c = shape.meet %l1, %r0, error="inner dimensions required to match" :
    !shape.shape, !shape.shape -> !shape.shape
  %res = shape.concat %l0, %r1
  // Should have `shape.return %res requires %c, %rank` to enable
  return %res : !shape.shape
}

} mapping {
  foo.matmul = @matmul
}
```

*   We are using the default builtin func and return here. Preferably we'd use
    `shape_func` as a special function op that allows passing multiple results
    back that affect correct execution (e.g., it serves as an error join).
    *   This also means one can't reify it inside a regular function without
        handling the `shape.return` - that is a feature here, as these are
        more of a template.
    *   Currently we have also not marked `meet` as having no side effects, to
        avoid DCE until we have `shape.return`; at that point computing the
        meet could be treated as purely computational, returning an error.
*   Meet represents a constraint that should hold, so it should not be used to
    test *if* something is equal. E.g., this means `meet` can't be used to
    represent

    ```
    either(meet(x, y), meet(y, z))
    ```

*   This could have been written more concisely as something like

    ```
    concat(lhs[0], rhs[1]) if rank(lhs) == 2 &&
      rank(rhs) == 2 && lhs[1] == rhs[0]
    ```

    but we are not focusing on the front end proper here.

We are going to lower to the "most" nested form directly (see
[test](https://github.com/tensorflow/tensorflow/blob/64062b5c51e04e370df26551d247496787d3f5c2/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir#L3088)
for an example reification along with legalization). Above, this was in a
separate shape function library; here we would normally reify it as part of
lowering, but for simplicity we show it as a standalone shape function.

```mlir
func @matmul_shape1(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
  %c1 = shape.const_size 1
  %c2 = shape.const_size 2
  // We allow `shape.shape_of` to return either a `!shape.shape` or a
  // `tensor<?xindex>` type. In the case where the input is a tensor, the most
  // refined type is a tensor of `index`, but that is not required.
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> !shape.shape
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> !shape.shape
  %lhs_rank = shape.rank %lhs_shape : !shape.shape -> !shape.size
  %rhs_rank = shape.rank %rhs_shape : !shape.shape -> !shape.size
  %w1 = shape.cstr_eq %lhs_rank, %rhs_rank : !shape.witness
  %res = shape.assuming %w1 -> tensor<?xindex> {
    %r = shape.any %lhs_rank, %rhs_rank : (!shape.size, !shape.size) -> !shape.size
    // The error message needs an addition; currently it is only on
    // cstr_require.
    %w2 = shape.cstr_eq %c2, %r, error="requires rank 2 operands"
    %res_1 = shape.assuming %w2 -> tensor<?xindex> {
      // Here the lowered
      //   %rank = shape.any %c2, %r : (!shape.size, !shape.size) -> !shape.size
      // is dead and so elided. But if `%rank` were actually consumed, it
      // could have been folded into `shape.any`.
      %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
        (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
      %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
        (!shape.shape, !shape.size) -> (!shape.shape, !shape.shape)
      %c = shape.meet %l1, %r0, error="inner dimensions required to match" :
        !shape.shape, !shape.shape -> !shape.shape
      %res = concat(%l0, %r1)
      shape.assuming_yield %res
    }
    shape.assuming_yield %res_1
  }
  return %res : tensor<?xindex>
}
```

We can now hoist computations of constraints where possible (which in the case
below is not too many, as we need to verify the rank before we can split):

```mlir
func @matmul_shape2(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
  %c1 = shape.const_size 1
  %c2 = shape.const_size 2
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
  %lhs_rank = shape.rank %lhs_shape : tensor<?xindex> -> tensor<index>
  %rhs_rank = shape.rank %rhs_shape : tensor<?xindex> -> tensor<index>
  %w1 = shape.cstr_eq %c2, %lhs_rank, error="requires rank 2 operands"
  %w2 = shape.cstr_eq %c2, %rhs_rank, error="requires rank 2 operands"
  %w = shape.assuming_all %w1, %w2
  %res = shape.assuming %w -> tensor<?xindex> {
    %l0, %l1 = "shape.split_at"(%lhs_shape, %c1) :
      (tensor<?xindex>, !shape.size) -> (tensor<?xindex>, tensor<?xindex>)
    %r0, %r1 = "shape.split_at"(%rhs_shape, %c1) :
      (tensor<?xindex>, !shape.size) -> (tensor<?xindex>, tensor<?xindex>)
    %w3 = shape.cstr_eq %l1, %r0, error="inner dimensions required to match"
    %res_1 = shape.assuming %w3 -> tensor<?xindex> {
      %res = concat(%l0, %r1)
      shape.assuming_yield %res
    }
    shape.assuming_yield %res_1
  }
  return %res : tensor<?xindex>
}
```

The above form can now be lowered to the fully imperative form (see
[test](https://github.com/tensorflow/mlir-hlo/blob/af14e1ded33c3164d4418c5d234b5b346b6d017c/tests/rank-specialization.mlir#L22)
for an example):

```mlir
func @matmul_shape3(%lhs: tensor<*xf32>, %rhs: tensor<*xf32>) -> tensor<?xindex> {
  %c1 = arith.constant 1 : index
  %c2 = arith.constant 2 : index
  %lhs_shape = shape.shape_of %lhs : tensor<*xf32> -> tensor<?xindex>
  %rhs_shape = shape.shape_of %rhs : tensor<*xf32> -> tensor<?xindex>
  %lhs_rank = shape.rank %lhs_shape : tensor<?xindex> -> tensor<index>
  %rhs_rank = shape.rank %rhs_shape : tensor<?xindex> -> tensor<index>
  %w1 = shape.shape_eq %lhs_rank, %rhs_rank
  %w2 = shape.shape_eq %c2, %lhs_rank
  %w3 = and %w1, %w2
  assert %w3, "requires rank 2 operands"
  %l0, %l1 = shape.split_at(%lhs_shape, %c1) : tensor<?xindex>
  %r0, %r1 = shape.split_at(%rhs_shape, %c1) : tensor<?xindex>
  %w4 = shape.eq %l1, %r0
  assert %w4, "inner dimensions required to match"
  %res = concat(%l0, %r1)
  return %res
}
```

*   In this case form 3 is as easy as, and closer to, form 1 (but only because
    no reordering was required). So it is a good question whether the front-end
    authoring language could be more similar to the imperative form (this is
    under discussion).
*   The form presented here is an intermediate form during a lowering pass. If
    used as input we would need to restrict the optimizations on it, as the
    `shape` dialect operations are no longer connected by producer-consumer
    relationships that enforce guard checking.

The above could be further lowered by using `tensor.dim`,
`tensor.from_elements`, etc. (or one could even lower these by way of, say, the
MHLO or TOSA dialect).
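
As an illustration, below is a minimal sketch of such a further lowering for a
ranked 2-D value. The function `@shape_via_tensor` is hypothetical; it assumes
only the `tensor.dim` and `tensor.from_elements` ops named above.

```mlir
// Hypothetical sketch: materialize the shape of a ranked tensor with tensor
// ops instead of `shape.shape_of`.
func @shape_via_tensor(%arg : tensor<?x?xf32>) -> tensor<2xindex> {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  // Query each extent individually.
  %d0 = tensor.dim %arg, %c0 : tensor<?x?xf32>
  %d1 = tensor.dim %arg, %c1 : tensor<?x?xf32>
  // Pack the extents into a shape tensor.
  %shape = tensor.from_elements %d0, %d1 : tensor<2xindex>
  return %shape : tensor<2xindex>
}
```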

[^wip_form1]: This form is the least used inside the current workflows and needs more work. In particular, in the example we use `shape_func` where in the code we instead use the standard func, as form 1 isn't used explicitly yet.