[shape] Add min and max ops

These are element-wise operations that operates on shapes with equal ranks.
Also add missing printer/parser for join operator.

Differential Revision: https://reviews.llvm.org/D99986
This commit is contained in:
Jacques Pienaar 2021-04-06 17:58:12 -07:00
parent 86175d5fed
commit e74e6afcf1
2 changed files with 73 additions and 3 deletions

View File

@ -387,13 +387,52 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
used to return an error to the user upon mismatch of dimensions.
```mlir
%c = shape.join %a, %b, error="<reason>" : !shape.shape
%c = shape.join %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
```
}];
let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeOrSizeType:$result);
let assemblyFormat = [{
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
type($arg0) `,` type($arg1) `->` type($result)
}];
}
def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
let summary = "Elementwise maximum";
let description = [{
Computes the elementwise maximum of two shapes with equal ranks. If either
operand is an error, then an error will be propagated to the result. If the
input types mismatch or the ranks do not match, then the result is an
error.
}];
let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
let results = (outs Shape_ShapeOrSizeType:$result);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
}];
}
def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
let summary = "Elementwise minimum";
let description = [{
Computes the elementwise maximum of two shapes with equal ranks. If either
operand is an error, then an error will be propagated to the result. If the
input types mismatch or the ranks do not match, then the result is an
error.
}];
let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
let results = (outs Shape_ShapeOrSizeType:$result);
let assemblyFormat = [{
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
}];
}
def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {

View File

@ -115,7 +115,7 @@ func @test_constraints() {
}
func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
%rhs : tensor<?xindex>) {
%rhs : tensor<?xindex>) {
%w0 = shape.cstr_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
return
}
@ -183,7 +183,6 @@ func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
return %rank : index
}
func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
return %result : i1
@ -289,3 +288,35 @@ func @is_broadcastable_on_shapes(%a : !shape.shape,
: !shape.shape, !shape.shape
return %result : i1
}
func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape
%2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
!shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}
func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
%0 = shape.const_shape [4, 57, 92] : !shape.shape
%1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape
%2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
!shape.shape, !shape.shape -> !shape.shape
return %2 : !shape.shape
}
func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
%0 = shape.const_size 5
%1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size
%2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}
func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
%0 = shape.const_size 9
%1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size
%2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}