[mlir] Add a shape op that returns a constant witness

This will later be used during canonicalization and folding steps to replace
statically known passing constraints.

Differential Revision: https://reviews.llvm.org/D80307
This commit is contained in:
Tres Popp 2020-05-20 15:56:12 +02:00
parent 5a675f0552
commit 1c3e38d98c
3 changed files with 41 additions and 8 deletions

View File

@ -473,11 +473,11 @@ def Shape_AssumingAllOp : Shape_Op<"assuming_all", [NoSideEffect]> {
Example:
```mlir
%w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Success
%w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
%w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
%w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Success
%w2 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
%wf = shape.assuming_all %w0, %w1 // Failure
%wt = shape.assuming_all %w0, %w2 // Success
%wt = shape.assuming_all %w0, %w2 // Passing
```
}];
@ -537,7 +537,7 @@ def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
Example:
```mlir
%w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Success
%w0 = shape.cstr_broadcastable [2,2], [3,1,2] // Passing
%w1 = shape.cstr_broadcastable [2,2], [3,2] // Failure
```
}];
@ -557,7 +557,7 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
Example:
```mlir
%w0 = shape.cstr_eq [1,2], [1,2], [1,2] // Success
%w0 = shape.cstr_eq [1,2], [1,2], [1,2] // Passing
%w1 = shape.cstr_eq [2,2], [1,2] // Failure
```
}];
@ -567,6 +567,28 @@ def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
let assemblyFormat = "$inputs attr-dict";
}
// Canonicalization patterns.
def Shape_ConstWitnessOp : Shape_Op<"const_witness", [ConstantLike, NoSideEffect]> {
let summary = "An operation that returns a statically known witness value";
let description = [{
This operation represents a statically known witness result. This can be
often used to canonicalize/fold constraint and assuming code that will always
pass.
```mlir
%0 = shape.const_shape [1,2,3]
%1 = shape.const_shape [1, 2, 3]
%w0 = shape.cstr_eq(%0, %1) // Can be folded to "const_witness true"
%w1 = shape.const_witness true
%w2 = shape.assuming_all(%w0, %w2) // Can be folded to "const_witness true"
```
}];
let arguments = (ins BoolAttr:$passing);
let results = (outs Shape_WitnessType:$result);
let assemblyFormat = "$passing attr-dict";
let hasFolder = 1;
}
#endif // SHAPE_OPS

View File

@ -42,6 +42,9 @@ Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
if (auto sizeType = type.dyn_cast<SizeType>()) {
return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
}
if (auto witnessType = type.dyn_cast<WitnessType>()) {
return builder.create<ConstWitnessOp>(loc, type, value.cast<BoolAttr>());
}
return nullptr;
}
@ -229,6 +232,12 @@ OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shapeAttr(); }
OpFoldResult ConstSizeOp::fold(ArrayRef<Attribute>) { return valueAttr(); }
//===----------------------------------------------------------------------===//
// ConstWitnessOp
//===----------------------------------------------------------------------===//
OpFoldResult ConstWitnessOp::fold(ArrayRef<Attribute>) { return passingAttr(); }
//===----------------------------------------------------------------------===//
// IndexToSizeOp
//===----------------------------------------------------------------------===//

View File

@ -77,8 +77,10 @@ func @test_constraints() {
%1 = shape.const_shape [1, 2, 3]
%w0 = shape.cstr_broadcastable %0, %1
%w1 = shape.cstr_eq %0, %1
%w3 = shape.assuming_all %w0, %w1
shape.assuming %w3 -> !shape.shape {
%w2 = shape.const_witness true
%w3 = shape.const_witness false
%w4 = shape.assuming_all %w0, %w1, %w2, %w3
shape.assuming %w4 -> !shape.shape {
%2 = shape.any %0, %1
shape.assuming_yield %2 : !shape.shape
}