forked from OSchip/llvm-project
[mlir][shape] Add `shape.cstr_require %bool`
This op is a catch-all for creating witnesses from various random kinds of constraints. In particular, I when dealing with extents directly, which are of `index` type, one can directly use std ops for calculating the predicates, and then use cstr_require for the final conversion to a witness. Differential Revision: https://reviews.llvm.org/D87871
This commit is contained in:
parent
4926a5ee63
commit
bae6374205
|
@ -738,5 +738,27 @@ def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Shape_CstrRequireOp : Shape_Op<"cstr_require", []> {
|
||||
let summary = "Represents a runtime assertion that an i1 is `true`";
|
||||
let description = [{
|
||||
Represents a runtime assretion that an i1 is true. It returns a
|
||||
!shape.witness to order this assertion.
|
||||
|
||||
For simplicity, prefer using other cstr_* ops if they are available for a
|
||||
given constraint.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
%bool = ...
|
||||
%w0 = shape.cstr_require %bool // Passing if `%bool` is true.
|
||||
```
|
||||
}];
|
||||
let arguments = (ins I1:$pred);
|
||||
let results = (outs Shape_WitnessType:$result);
|
||||
|
||||
let assemblyFormat = "$pred attr-dict";
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
#endif // SHAPE_OPS
|
||||
|
|
|
@ -490,6 +490,14 @@ void ConstSizeOp::getAsmResultNames(
|
|||
|
||||
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// CstrRequireOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
|
||||
return operands[0];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ShapeEqOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -386,7 +386,31 @@ func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
|
|||
}
|
||||
|
||||
// -----
|
||||
// cstr_require with constant can be folded
|
||||
// CHECK-LABEL: func @cstr_require_fold
|
||||
func @cstr_require_fold() {
|
||||
// CHECK-NEXT: shape.const_witness true
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%true = constant true
|
||||
%0 = shape.cstr_require %true
|
||||
"consume.witness"(%0) : (!shape.witness) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// cstr_require without constant cannot be folded
|
||||
// CHECK-LABEL: func @cstr_require_no_fold
|
||||
func @cstr_require_no_fold(%arg0: i1) {
|
||||
// CHECK-NEXT: shape.cstr_require
|
||||
// CHECK-NEXT: consume.witness
|
||||
// CHECK-NEXT: return
|
||||
%0 = shape.cstr_require %arg0
|
||||
"consume.witness"(%0) : (!shape.witness) -> ()
|
||||
return
|
||||
}
|
||||
|
||||
// -----
|
||||
// assuming_all with known passing witnesses can be folded
|
||||
// CHECK-LABEL: func @f
|
||||
func @f() {
|
||||
|
|
|
@ -100,12 +100,14 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
|
|||
func @test_constraints() {
|
||||
%0 = shape.const_shape [] : !shape.shape
|
||||
%1 = shape.const_shape [1, 2, 3] : !shape.shape
|
||||
%true = constant true
|
||||
%w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
|
||||
%w1 = shape.cstr_eq %0, %1
|
||||
%w2 = shape.const_witness true
|
||||
%w3 = shape.const_witness false
|
||||
%w4 = shape.assuming_all %w0, %w1, %w2, %w3
|
||||
shape.assuming %w4 -> !shape.shape {
|
||||
%w4 = shape.cstr_require %true
|
||||
%w_all = shape.assuming_all %w0, %w1, %w2, %w3, %w4
|
||||
shape.assuming %w_all -> !shape.shape {
|
||||
%2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
shape.assuming_yield %2 : !shape.shape
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue