[MLIR] Add shape.witness type and ops

Summary: These represent shape-based preconditions on the execution of code.

Differential Revision: https://reviews.llvm.org/D79717
Tres Popp 2020-05-07 14:24:30 +02:00
parent 9d4b4f344d
commit a26883e5aa
4 changed files with 196 additions and 2 deletions


@@ -30,7 +30,8 @@ enum Kind {
Shape,
Size,
ValueShape,
LAST_SHAPE_TYPE = ValueShape
Witness,
LAST_SHAPE_TYPE = Witness
};
} // namespace ShapeTypes
@@ -105,6 +106,22 @@ public:
}
};
/// The Witness represents a runtime constraint, to be used as a shape-related
/// precondition on code execution.
class WitnessType : public Type::TypeBase<WitnessType, Type> {
public:
using Base::Base;
static WitnessType get(MLIRContext *context) {
return Base::get(context, ShapeTypes::Kind::Witness);
}
/// Support method to enable LLVM-style type casting.
static bool kindof(unsigned kind) {
return kind == ShapeTypes::Kind::Witness;
}
};
#define GET_OP_CLASSES
#include "mlir/Dialect/Shape/IR/ShapeOps.h.inc"


@@ -17,6 +17,32 @@ include "mlir/Dialect/Shape/IR/ShapeBase.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
def Shape_WitnessType : DialectType<ShapeDialect,
CPred<"$_self.isa<::mlir::shape::WitnessType>()">, "witness">,
BuildableType<"$_builder.getType<::mlir::shape::WitnessType>()"> {
let typeDescription = [{
A witness is a structural device in the compiler to maintain ordering of
code relying on information obtained from passing assertions. Witnesses do
not represent any physical data.
"cstr_" operations will return witnesses and be lowered into assertion logic
when not resolvable at compile time.
"assuming_" operations will take witnesses as input and represent only
information to the compiler, so they do not exist in executing code. Code
that is dependent on "assuming_" operations can assume all cstr operations
transitively before are honored as true.
These abstractions are intended to give the compiler more freedom with
assertions by expressing them through dataflow rather than as side-effecting
operations that act as barriers. This can be viewed as a compiler
representation of promises from asynchronous, possibly crashing assertions.
Code that relies on an assertion will not be reordered to before it,
non-reliant code can be reordered freely, and there are no guarantees on the
final ordering of the assertions or their related code.
}];
}
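For orientation, here is a minimal sketch of the intended dataflow in generic op form (these ops define no custom assembly in this commit; %a and %b are assumed to be pre-existing !shape.shape values):

```mlir
// Constraints produce witnesses, assuming_all conjoins them, and code in the
// assuming region may rely on every constraint reaching it transitively.
%w0 = "shape.cstr_broadcastable"(%a, %b) : (!shape.shape, !shape.shape) -> !shape.witness
%w1 = "shape.cstr_eq"(%a, %b) : (!shape.shape, !shape.shape) -> !shape.witness
%w = "shape.assuming_all"(%w0, %w1) : (!shape.witness, !shape.witness) -> !shape.witness
%r = "shape.assuming"(%w) ( {
  %s = "shape.any"(%a, %b) : (!shape.shape, !shape.shape) -> !shape.shape
  "shape.assuming_yield"(%s) : (!shape.shape) -> ()
}) : (!shape.witness) -> !shape.shape
```

This mirrors the round-trip test added at the end of this commit.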
//===----------------------------------------------------------------------===//
// Shape op definitions
//===----------------------------------------------------------------------===//
@@ -313,4 +339,123 @@ def Shape_ConcatOp : Shape_Op<"concat",
let hasFolder = 1;
}
//===----------------------------------------------------------------------===//
// Shape constraint related ops.
//===----------------------------------------------------------------------===//
// TODO(tpopp): Move the code below and witnesses to a different file.
def Shape_AnyOp : Shape_Op<"any",
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Return any combination of the input shapes.";
let description = [{
This operation takes multiple input shapes and returns some combination of
their dimensions. This can best be seen in the examples below.
The result is undefined, but still side-effect free, in cases where the
inputs have differing ranks or differ in extents of shared dimensions.
Example:
```mlir
%s0 = shape.any([2,?], [?,3]) // [2,3]
%s1 = shape.any([?,?], [1,2]) // [1,2]
```
}];
let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
let results = (outs Shape_ShapeType:$result);
}
def Shape_AssumingAllOp : Shape_Op<"assuming_all", []> {
let summary = "Return a logical AND of all witnesses.";
let description = [{
Used to simplify constraints as any single failing precondition is enough
to prevent execution.
"assuming" operations represent an execution order restriction to the
compiler, information for dependent code to rely on (by assuming), and
nothing else. They should not exist after a program is fully lowered and
ready to execute.
Example:
```mlir
%w0 = shape.cstr_broadcastable([2,2], [3,1,2]) // Success
%w1 = shape.cstr_broadcastable([2,2], [3,2]) // Failure
%w2 = shape.cstr_eq([1,2], [1,2], [1,2]) // Success
%wf = shape.assuming_all(%w0, %w1) // Failure
%wt = shape.assuming_all(%w0, %w2) // Success
```
}];
let arguments = (ins Variadic<Shape_WitnessType>:$inputs);
let results = (outs Shape_WitnessType:$result);
}
def Shape_AssumingOp : Shape_Op<"assuming",
[SingleBlockImplicitTerminator<"AssumingYieldOp">,
RecursiveSideEffects]> {
let summary = "Execute the region.";
let description = [{
Executes the region assuming all witnesses are true.
"assuming" operations represent an execution order restriction to the
compiler, information for dependent code to rely on (by assuming), and
nothing else. They should not exist after a program is fully lowered and
ready to execute.
}];
let arguments = (ins Shape_WitnessType);
let regions = (region SizedRegion<1>:$thenRegion);
let results = (outs Variadic<AnyType>:$results);
}
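As a sketch of how this might be used (generic form; %w, %a, and %b are assumed to already exist, and `constant` is the standard-dialect constant op of this vintage), note that the region may yield several values of arbitrary type, which become the op's results:

```mlir
%s, %i = "shape.assuming"(%w) ( {
  // Everything in here may assume the constraints behind %w hold.
  %t = "shape.any"(%a, %b) : (!shape.shape, !shape.shape) -> !shape.shape
  %c0 = constant 0 : index
  "shape.assuming_yield"(%t, %c0) : (!shape.shape, index) -> ()
}) : (!shape.witness) -> (!shape.shape, index)
```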
def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", [Terminator]> {
let summary = "Yield operation";
let description = [{
This yield operation represents a return operation within the shape.assuming
region. The operation takes a variable number of operands and produces no
results. The operand number and types must match the result signature of the
enclosing shape.assuming operation.
}];
let arguments = (ins Variadic<AnyType>:$operands);
}
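A tiny sketch of that matching requirement (generic form; %w and %s are assumed to be existing !shape.witness and !shape.shape values):

```mlir
%r = "shape.assuming"(%w) ( {
  // One !shape.shape operand here, so one !shape.shape result on the parent op.
  "shape.assuming_yield"(%s) : (!shape.shape) -> ()
}) : (!shape.witness) -> !shape.shape
```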
def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
let summary = "Determines if 2 shapes can be successfully broadcasted.";
let description = [{
Given two input shapes, return a witness specifying whether they are
broadcastable. Broadcastability here follows the same logic as documented for
shape.broadcast.
"cstr" operations represent runtime assertions.
Example:
```mlir
%w0 = shape.cstr_broadcastable([2,2], [3,1,2]) // Success
%w1 = shape.cstr_broadcastable([2,2], [3,2]) // Failure
```
}];
let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
let results = (outs Shape_WitnessType:$result);
}
def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
let summary = "Determines if all input shapes are equal.";
let description = [{
Given one or more input shapes, determine if all shapes are exactly the same.
"cstr" operations represent runtime assertions.
Example:
```mlir
%w0 = shape.cstr_eq([1,2], [1,2], [1,2]) // Success
%w1 = shape.cstr_eq([2,2], [1,2]) // Failure
```
}];
let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
let results = (outs Shape_WitnessType:$result);
}
// Canonicalization patterns.
#endif // SHAPE_OPS


@@ -24,7 +24,8 @@ ShapeDialect::ShapeDialect(MLIRContext *context)
#define GET_OP_LIST
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
>();
addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>();
addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
WitnessType>();
// Allow unknown operations during prototyping and testing. As the dialect is
// still evolving it makes it simple to start with unregistered ops and
// try different variants before actually defining the op.
@@ -60,6 +61,8 @@ Type ShapeDialect::parseType(DialectAsmParser &parser) const {
return SizeType::get(getContext());
if (keyword == "value_shape")
return ValueShapeType::get(getContext());
if (keyword == "witness")
return WitnessType::get(getContext());
parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
return Type();
@@ -83,11 +86,27 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
case ShapeTypes::ValueShape:
os << "value_shape";
return;
case ShapeTypes::Witness:
os << "witness";
return;
default:
llvm_unreachable("unexpected 'shape' type kind");
}
}
//===----------------------------------------------------------------------===//
// AnyOp
//===----------------------------------------------------------------------===//
LogicalResult
AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
ValueRange operands, DictionaryAttr attributes,
RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(ShapeType::get(context));
return success();
}
//===----------------------------------------------------------------------===//
// BroadcastOp
//===----------------------------------------------------------------------===//


@@ -67,3 +67,16 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
%0 = shape.shape_of %arg0 : tensor<?xf32>
return %0 : !shape.shape
}
func @test_constraints() {
%0 = shape.const_shape []
%1 = shape.const_shape [1, 2, 3]
%w0 = "shape.cstr_broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness
%w1 = "shape.cstr_eq"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness
%w3 = "shape.assuming_all"(%w0, %w1) : (!shape.witness, !shape.witness) -> !shape.witness
"shape.assuming"(%w3) ( {
%2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
"shape.assuming_yield"(%2) : (!shape.shape) -> ()
}) : (!shape.witness) -> !shape.shape
return
}