[MLIR] Rename Shape dialect's `join` to `meet`.

For the type lattice, we (now) use the "less specialized or equal" partial
order, leading to the bottom representing the empty set, and the top
representing any type.

This naming is more in line with the generally used conventions, where the top
of the lattice is the full set, and the bottom of the lattice is the empty set.
A typical example is the powerset of a finite set: generally, meet would be the
intersection, and join would be the union.

```
top:  {a,b,c}
     /   |   \
 {a,b} {a,c} {b,c}
   |  X     X  |
   {a} { b } {c}
      \  |  /
bottom: { }
```

This is in line with the examined lattice representations in LLVM:
* lattice for `BitTracker::BitValue` in `Hexagon/BitTracker.h`
* lattice for constant propagation in `HexagonConstPropagation.cpp`
* lattice in `VarLocBasedImpl.cpp`
* lattice for address space inference code in `InferAddressSpaces.cpp`

Reviewed By: silvas, jpienaar

Differential Revision: https://reviews.llvm.org/D110766
This commit is contained in:
Alexandre Rames 2021-10-05 10:53:02 -07:00
parent 1301a8b473
commit fd9613324d
3 changed files with 56 additions and 56 deletions

View File

@ -397,51 +397,6 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
let hasCanonicalizer = 1;
}
def Shape_JoinOp : Shape_Op<"join",
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns the least general shape.shape of its operands";
let description = [{
An operation that computes the least general shape of input operands.
This effectively asserts that corresponding static dimensions are equal.
The behavior is to match each element of the `shape.shape` and propagate the
most restrictive information, returning an invalid shape if there are
contradictory requirements. E.g., using pseudo code
```
shape.join([*], [*]) -> [*]
shape.join([*], [1, ?]) -> [1, ?]
shape.join([1, 2], [1, ?]) -> [1, 2]
shape.join([*], [1, 2]) -> [1, 2]
shape.join([], []) -> []
shape.join([], [*]) -> []
shape.join([], [?, ?]) -> [invalid]
shape.join([1, ?], [2, ?, ?]) -> [invalid]
```
`shape.join` also allows specifying an optional error string, that may be
used to return an error to the user upon mismatch of dimensions.
```mlir
%c = shape.join %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
```
}];
let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeOrSizeType:$result);
let assemblyFormat = [{
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
type($arg0) `,` type($arg1) `->` type($result)
}];
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}
def Shape_MaxOp : Shape_Op<"max",
[Commutative, NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
@ -469,6 +424,51 @@ def Shape_MaxOp : Shape_Op<"max",
}];
}
def Shape_MeetOp : Shape_Op<"meet",
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns the least general shape.shape of its operands";
let description = [{
An operation that computes the least general shape of input operands.
This effectively asserts that corresponding static dimensions are equal.
The behavior is to match each element of the `shape.shape` and propagate the
most restrictive information, returning an invalid shape if there are
contradictory requirements. E.g., using pseudo code
```
shape.meet([*], [*]) -> [*]
shape.meet([*], [1, ?]) -> [1, ?]
shape.meet([1, 2], [1, ?]) -> [1, 2]
shape.meet([*], [1, 2]) -> [1, 2]
shape.meet([], []) -> []
shape.meet([], [*]) -> []
shape.meet([], [?, ?]) -> [invalid]
shape.meet([1, ?], [2, ?, ?]) -> [invalid]
```
`shape.meet` also allows specifying an optional error string, that may be
used to return an error to the user upon mismatch of dimensions.
```mlir
%c = shape.meet %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
```
}];
let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeOrSizeType:$result);
let assemblyFormat = [{
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
type($arg0) `,` type($arg1) `->` type($result)
}];
let extraClassDeclaration = [{
// Returns when two result types are compatible for this op; method used by
// InferTypeOpInterface
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
}];
}
def Shape_MinOp : Shape_Op<"min",
[Commutative, NoSideEffect,
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {

View File

@ -1177,10 +1177,10 @@ OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
}
//===----------------------------------------------------------------------===//
// JoinOp
// MeetOp
//===----------------------------------------------------------------------===//
LogicalResult mlir::shape::JoinOp::inferReturnTypes(
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
@ -1188,7 +1188,7 @@ LogicalResult mlir::shape::JoinOp::inferReturnTypes(
return success();
}
bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
if (l.size() != 1 || r.size() != 1)
return false;
if (l == r)

View File

@ -65,7 +65,7 @@ func @test_broadcast_extents() -> tensor<?xindex> {
func @test_shape_any_fixed() {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.const_shape [4, 57, 92] : !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
@ -73,7 +73,7 @@ func @test_shape_any_fixed() {
func @test_shape_any_unknown() {
%0 = shape.const_shape [4, -1, 92] : !shape.shape
%1 = shape.const_shape [-1, 57, 92] : !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
@ -81,7 +81,7 @@ func @test_shape_any_unknown() {
func @test_shape_any_fixed_mismatch() {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.const_shape [2, 57, 92] : !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
@ -243,7 +243,7 @@ func @num_elements_shape(%arg : !shape.shape) -> !shape.size {
func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
%0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
%1 = shape.shape_of %b : !shape.value_shape -> !shape.shape
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
return %2 : !shape.shape
}
func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
@ -293,7 +293,7 @@ func @is_broadcastable_on_shapes(%a : !shape.shape,
func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape
%2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
%2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
!shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}
@ -301,7 +301,7 @@ func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape
%2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
%2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
!shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}
@ -309,7 +309,7 @@ func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
%0 = shape.const_size 5
%1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size
%2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
%2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}
@ -317,7 +317,7 @@ func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
%0 = shape.const_size 9
%1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size
%2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
%2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}