forked from OSchip/llvm-project
[MLIR] Rename Shape dialect's `join` to `meet`.
For the type lattice, we (now) use the "less specialized or equal" partial order, leading to the bottom representing the empty set, and the top representing any type. This naming is more in line with the generally used conventions, where the top of the lattice is the full set, and the bottom of the lattice is the empty set. A typical example is the powerset of a finite set: generally, meet would be the intersection, and join would be the union. ``` top: {a,b,c} / | \ {a,b} {a,c} {b,c} | X X | {a} { b } {c} \ | / bottom: { } ``` This is in line with the examined lattice representations in LLVM: * lattice for `BitTracker::BitValue` in `Hexagon/BitTracker.h` * lattice for constant propagation in `HexagonConstPropagation.cpp` * lattice in `VarLocBasedImpl.cpp` * lattice for address space inference code in `InferAddressSpaces.cpp` Reviewed By: silvas, jpienaar Differential Revision: https://reviews.llvm.org/D110766
This commit is contained in:
parent
1301a8b473
commit
fd9613324d
|
@ -397,51 +397,6 @@ def Shape_IndexToSizeOp : Shape_Op<"index_to_size", [NoSideEffect]> {
|
|||
let hasCanonicalizer = 1;
|
||||
}
|
||||
|
||||
def Shape_JoinOp : Shape_Op<"join",
|
||||
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Returns the least general shape.shape of its operands";
|
||||
let description = [{
|
||||
An operation that computes the least general shape of input operands.
|
||||
This effectively asserts that corresponding static dimensions are equal.
|
||||
The behavior is to match each element of the `shape.shape` and propagate the
|
||||
most restrictive information, returning an invalid shape if there are
|
||||
contradictory requirements. E.g., using pseudo code
|
||||
|
||||
```
|
||||
shape.join([*], [*]) -> [*]
|
||||
shape.join([*], [1, ?]) -> [1, ?]
|
||||
shape.join([1, 2], [1, ?]) -> [1, 2]
|
||||
shape.join([*], [1, 2]) -> [1, 2]
|
||||
shape.join([], []) -> []
|
||||
shape.join([], [*]) -> []
|
||||
shape.join([], [?, ?]) -> [invalid]
|
||||
shape.join([1, ?], [2, ?, ?]) -> [invalid]
|
||||
```
|
||||
|
||||
`shape.join` also allows specifying an optional error string, that may be
|
||||
used to return an error to the user upon mismatch of dimensions.
|
||||
|
||||
```mlir
|
||||
%c = shape.join %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
|
||||
OptionalAttr<StrAttr>:$error);
|
||||
let results = (outs Shape_ShapeOrSizeType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
|
||||
type($arg0) `,` type($arg1) `->` type($result)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_MaxOp : Shape_Op<"max",
|
||||
[Commutative, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
|
@ -469,6 +424,51 @@ def Shape_MaxOp : Shape_Op<"max",
|
|||
}];
|
||||
}
|
||||
|
||||
def Shape_MeetOp : Shape_Op<"meet",
|
||||
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Returns the least general shape.shape of its operands";
|
||||
let description = [{
|
||||
An operation that computes the least general shape of input operands.
|
||||
This effectively asserts that corresponding static dimensions are equal.
|
||||
The behavior is to match each element of the `shape.shape` and propagate the
|
||||
most restrictive information, returning an invalid shape if there are
|
||||
contradictory requirements. E.g., using pseudo code
|
||||
|
||||
```
|
||||
shape.meet([*], [*]) -> [*]
|
||||
shape.meet([*], [1, ?]) -> [1, ?]
|
||||
shape.meet([1, 2], [1, ?]) -> [1, 2]
|
||||
shape.meet([*], [1, 2]) -> [1, 2]
|
||||
shape.meet([], []) -> []
|
||||
shape.meet([], [*]) -> []
|
||||
shape.meet([], [?, ?]) -> [invalid]
|
||||
shape.meet([1, ?], [2, ?, ?]) -> [invalid]
|
||||
```
|
||||
|
||||
`shape.meet` also allows specifying an optional error string, that may be
|
||||
used to return an error to the user upon mismatch of dimensions.
|
||||
|
||||
```mlir
|
||||
%c = shape.meet %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
|
||||
OptionalAttr<StrAttr>:$error);
|
||||
let results = (outs Shape_ShapeOrSizeType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
|
||||
type($arg0) `,` type($arg1) `->` type($result)
|
||||
}];
|
||||
|
||||
let extraClassDeclaration = [{
|
||||
// Returns when two result types are compatible for this op; method used by
|
||||
// InferTypeOpInterface
|
||||
static bool isCompatibleReturnTypes(TypeRange l, TypeRange r);
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_MinOp : Shape_Op<"min",
|
||||
[Commutative, NoSideEffect,
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
|
|
|
@ -1177,10 +1177,10 @@ OpFoldResult IsBroadcastableOp::fold(ArrayRef<Attribute> operands) {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// JoinOp
|
||||
// MeetOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult mlir::shape::JoinOp::inferReturnTypes(
|
||||
LogicalResult mlir::shape::MeetOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
|
@ -1188,7 +1188,7 @@ LogicalResult mlir::shape::JoinOp::inferReturnTypes(
|
|||
return success();
|
||||
}
|
||||
|
||||
bool mlir::shape::JoinOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||
if (l.size() != 1 || r.size() != 1)
|
||||
return false;
|
||||
if (l == r)
|
||||
|
|
|
@ -65,7 +65,7 @@ func @test_broadcast_extents() -> tensor<?xindex> {
|
|||
func @test_shape_any_fixed() {
|
||||
%0 = shape.const_shape [4, 57, 92] : !shape.shape
|
||||
%1 = shape.const_shape [4, 57, 92] : !shape.shape
|
||||
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
||||
return
|
||||
}
|
||||
|
@ -73,7 +73,7 @@ func @test_shape_any_fixed() {
|
|||
func @test_shape_any_unknown() {
|
||||
%0 = shape.const_shape [4, -1, 92] : !shape.shape
|
||||
%1 = shape.const_shape [-1, 57, 92] : !shape.shape
|
||||
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
||||
return
|
||||
}
|
||||
|
@ -81,7 +81,7 @@ func @test_shape_any_unknown() {
|
|||
func @test_shape_any_fixed_mismatch() {
|
||||
%0 = shape.const_shape [4, 57, 92] : !shape.shape
|
||||
%1 = shape.const_shape [2, 57, 92] : !shape.shape
|
||||
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
|
||||
return
|
||||
}
|
||||
|
@ -243,7 +243,7 @@ func @num_elements_shape(%arg : !shape.shape) -> !shape.size {
|
|||
func @shape_equal_shapes(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
|
||||
%0 = shape.shape_of %a : !shape.value_shape -> !shape.shape
|
||||
%1 = shape.shape_of %b : !shape.value_shape -> !shape.shape
|
||||
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
%2 = "shape.meet"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
return %2 : !shape.shape
|
||||
}
|
||||
func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) -> !shape.shape {
|
||||
|
@ -293,7 +293,7 @@ func @is_broadcastable_on_shapes(%a : !shape.shape,
|
|||
func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
|
||||
%0 = shape.const_shape [4, 57, 92] : !shape.shape
|
||||
%1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape
|
||||
%2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
|
||||
%2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
|
||||
!shape.shape, !shape.shape -> !shape.shape
|
||||
return %2 : !shape.shape
|
||||
}
|
||||
|
@ -301,7 +301,7 @@ func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
|
|||
func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
|
||||
%0 = shape.const_shape [4, 57, 92] : !shape.shape
|
||||
%1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape
|
||||
%2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
|
||||
%2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
|
||||
!shape.shape, !shape.shape -> !shape.shape
|
||||
return %2 : !shape.shape
|
||||
}
|
||||
|
@ -309,7 +309,7 @@ func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
|
|||
func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
|
||||
%0 = shape.const_size 5
|
||||
%1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size
|
||||
%2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
|
||||
%2 = shape.meet %0, %1, error="exceeded element-wise upper bound" :
|
||||
!shape.size, !shape.size -> !shape.size
|
||||
return %2 : !shape.size
|
||||
}
|
||||
|
@ -317,7 +317,7 @@ func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
|
|||
func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
|
||||
%0 = shape.const_size 9
|
||||
%1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size
|
||||
%2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
|
||||
%2 = shape.meet %0, %1, error="lower bound element-wise exceeded" :
|
||||
!shape.size, !shape.size -> !shape.size
|
||||
return %2 : !shape.size
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue