forked from OSchip/llvm-project
[shape] Add min and max ops
These are element-wise operations that operates on shapes with equal ranks. Also add missing printer/parser for join operator. Differential Revision: https://reviews.llvm.org/D99986
This commit is contained in:
parent
86175d5fed
commit
e74e6afcf1
|
@ -387,13 +387,52 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> {
|
|||
used to return an error to the user upon mismatch of dimensions.
|
||||
|
||||
```mlir
|
||||
%c = shape.join %a, %b, error="<reason>" : !shape.shape
|
||||
%c = shape.join %a, %b, error="<reason>" : !shape.shape, !shape.shape -> !shape.shape
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
|
||||
OptionalAttr<StrAttr>:$error);
|
||||
let results = (outs Shape_ShapeOrSizeType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
|
||||
type($arg0) `,` type($arg1) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_MaxOp : Shape_Op<"max", [Commutative, NoSideEffect]> {
|
||||
let summary = "Elementwise maximum";
|
||||
let description = [{
|
||||
Computes the elementwise maximum of two shapes with equal ranks. If either
|
||||
operand is an error, then an error will be propagated to the result. If the
|
||||
input types mismatch or the ranks do not match, then the result is an
|
||||
error.
|
||||
}];
|
||||
|
||||
let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
|
||||
let results = (outs Shape_ShapeOrSizeType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_MinOp : Shape_Op<"min", [Commutative, NoSideEffect]> {
|
||||
let summary = "Elementwise minimum";
|
||||
let description = [{
|
||||
Computes the elementwise maximum of two shapes with equal ranks. If either
|
||||
operand is an error, then an error will be propagated to the result. If the
|
||||
input types mismatch or the ranks do not match, then the result is an
|
||||
error.
|
||||
}];
|
||||
|
||||
let arguments = (ins Shape_ShapeOrSizeType:$lhs, Shape_ShapeOrSizeType:$rhs);
|
||||
let results = (outs Shape_ShapeOrSizeType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
$lhs `,` $rhs attr-dict `:` type($lhs) `,` type($rhs) `->` type($result)
|
||||
}];
|
||||
}
|
||||
|
||||
def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> {
|
||||
|
|
|
@ -115,7 +115,7 @@ func @test_constraints() {
|
|||
}
|
||||
|
||||
func @eq_on_extent_tensors(%lhs : tensor<?xindex>,
|
||||
%rhs : tensor<?xindex>) {
|
||||
%rhs : tensor<?xindex>) {
|
||||
%w0 = shape.cstr_eq %lhs, %rhs : tensor<?xindex>, tensor<?xindex>
|
||||
return
|
||||
}
|
||||
|
@ -183,7 +183,6 @@ func @rank_on_extent_tensor(%shape : tensor<?xindex>) -> index {
|
|||
return %rank : index
|
||||
}
|
||||
|
||||
|
||||
func @shape_eq_on_shapes(%a : !shape.shape, %b : !shape.shape) -> i1 {
|
||||
%result = shape.shape_eq %a, %b : !shape.shape, !shape.shape
|
||||
return %result : i1
|
||||
|
@ -289,3 +288,35 @@ func @is_broadcastable_on_shapes(%a : !shape.shape,
|
|||
: !shape.shape, !shape.shape
|
||||
return %result : i1
|
||||
}
|
||||
|
||||
func @shape_upper_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
|
||||
%0 = shape.const_shape [4, 57, 92] : !shape.shape
|
||||
%1 = shape.max %a, %0 : !shape.shape, !shape.shape -> !shape.shape
|
||||
%2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
|
||||
!shape.shape, !shape.shape -> !shape.shape
|
||||
return %2 : !shape.shape
|
||||
}
|
||||
|
||||
func @shape_lower_bounded_by_constant(%a: !shape.shape) -> !shape.shape {
|
||||
%0 = shape.const_shape [4, 57, 92] : !shape.shape
|
||||
%1 = shape.min %a, %0 : !shape.shape, !shape.shape -> !shape.shape
|
||||
%2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
|
||||
!shape.shape, !shape.shape -> !shape.shape
|
||||
return %2 : !shape.shape
|
||||
}
|
||||
|
||||
func @size_upper_bounded_by_constant(%a: !shape.size) -> !shape.size {
|
||||
%0 = shape.const_size 5
|
||||
%1 = shape.max %a, %0 : !shape.size, !shape.size -> !shape.size
|
||||
%2 = shape.join %0, %1, error="exceeded element-wise upper bound" :
|
||||
!shape.size, !shape.size -> !shape.size
|
||||
return %2 : !shape.size
|
||||
}
|
||||
|
||||
func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
|
||||
%0 = shape.const_size 9
|
||||
%1 = shape.min %a, %0 : !shape.size, !shape.size -> !shape.size
|
||||
%2 = shape.join %0, %1, error="lower bound element-wise exceeded" :
|
||||
!shape.size, !shape.size -> !shape.size
|
||||
return %2 : !shape.size
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue