[mlir][shape] Add `shape.cstr_require %bool`

This op is a catch-all for creating witnesses from various random kinds
of constraints. In particular, I when dealing with extents directly,
which are of `index` type, one can directly use std ops for calculating
the predicates, and then use cstr_require for the final conversion to a
witness.

Differential Revision: https://reviews.llvm.org/D87871
This commit is contained in:
Sean Silva 2020-09-17 16:20:47 -07:00
parent 4926a5ee63
commit bae6374205
4 changed files with 58 additions and 2 deletions

View File

@ -738,5 +738,27 @@ def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect
let hasFolder = 1;
}
def Shape_CstrRequireOp : Shape_Op<"cstr_require", []> {
let summary = "Represents a runtime assertion that an i1 is `true`";
let description = [{
Represents a runtime assretion that an i1 is true. It returns a
!shape.witness to order this assertion.
For simplicity, prefer using other cstr_* ops if they are available for a
given constraint.
Example:
```mlir
%bool = ...
%w0 = shape.cstr_require %bool // Passing if `%bool` is true.
```
}];
let arguments = (ins I1:$pred);
let results = (outs Shape_WitnessType:$result);
let assemblyFormat = "$pred attr-dict";
let hasFolder = 1;
}
#endif // SHAPE_OPS

View File

@ -490,6 +490,14 @@ void ConstSizeOp::getAsmResultNames(
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
//===----------------------------------------------------------------------===//
// CstrRequireOp
//===----------------------------------------------------------------------===//
OpFoldResult CstrRequireOp::fold(ArrayRef<Attribute> operands) {
return operands[0];
}
//===----------------------------------------------------------------------===//
// ShapeEqOp
//===----------------------------------------------------------------------===//

View File

@ -386,7 +386,31 @@ func @f(%arg0: !shape.shape, %arg1: !shape.shape) {
}
// -----
// cstr_require with constant can be folded
// CHECK-LABEL: func @cstr_require_fold
func @cstr_require_fold() {
// CHECK-NEXT: shape.const_witness true
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%true = constant true
%0 = shape.cstr_require %true
"consume.witness"(%0) : (!shape.witness) -> ()
return
}
// -----
// cstr_require without constant cannot be folded
// CHECK-LABEL: func @cstr_require_no_fold
func @cstr_require_no_fold(%arg0: i1) {
// CHECK-NEXT: shape.cstr_require
// CHECK-NEXT: consume.witness
// CHECK-NEXT: return
%0 = shape.cstr_require %arg0
"consume.witness"(%0) : (!shape.witness) -> ()
return
}
// -----
// assuming_all with known passing witnesses can be folded
// CHECK-LABEL: func @f
func @f() {

View File

@ -100,12 +100,14 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> tensor<?xindex> {
func @test_constraints() {
%0 = shape.const_shape [] : !shape.shape
%1 = shape.const_shape [1, 2, 3] : !shape.shape
%true = constant true
%w0 = shape.cstr_broadcastable %0, %1 : !shape.shape, !shape.shape
%w1 = shape.cstr_eq %0, %1
%w2 = shape.const_witness true
%w3 = shape.const_witness false
%w4 = shape.assuming_all %w0, %w1, %w2, %w3
shape.assuming %w4 -> !shape.shape {
%w4 = shape.cstr_require %true
%w_all = shape.assuming_all %w0, %w1, %w2, %w3, %w4
shape.assuming %w_all -> !shape.shape {
%2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
shape.assuming_yield %2 : !shape.shape
}