forked from OSchip/llvm-project
[MLIR] Add shape.witness type and ops
Summary: These represent shape based preconditions on execution of code. Differential Revision: https://reviews.llvm.org/D79717
This commit is contained in:
parent
9d4b4f344d
commit
a26883e5aa
|
@ -30,7 +30,8 @@ enum Kind {
|
|||
Shape,
|
||||
Size,
|
||||
ValueShape,
|
||||
LAST_SHAPE_TYPE = ValueShape
|
||||
Witness,
|
||||
LAST_SHAPE_TYPE = Witness
|
||||
};
|
||||
} // namespace ShapeTypes
|
||||
|
||||
|
@ -105,6 +106,22 @@ public:
|
|||
}
|
||||
};
|
||||
|
||||
/// The Witness represents a runtime constraint, to be used as shape related
|
||||
/// preconditions on code execution.
|
||||
class WitnessType : public Type::TypeBase<WitnessType, Type> {
|
||||
public:
|
||||
using Base::Base;
|
||||
|
||||
static WitnessType get(MLIRContext *context) {
|
||||
return Base::get(context, ShapeTypes::Kind::Witness);
|
||||
}
|
||||
|
||||
/// Support method to enable LLVM-style type casting.
|
||||
static bool kindof(unsigned kind) {
|
||||
return kind == ShapeTypes::Kind::Witness;
|
||||
}
|
||||
};
|
||||
|
||||
#define GET_OP_CLASSES
|
||||
#include "mlir/Dialect/Shape/IR/ShapeOps.h.inc"
|
||||
|
||||
|
|
|
@ -17,6 +17,32 @@ include "mlir/Dialect/Shape/IR/ShapeBase.td"
|
|||
include "mlir/Interfaces/InferTypeOpInterface.td"
|
||||
include "mlir/Interfaces/SideEffectInterfaces.td"
|
||||
|
||||
def Shape_WitnessType : DialectType<ShapeDialect,
|
||||
CPred<"$_self.isa<::mlir::shape::WitnessType>()">, "witness">,
|
||||
BuildableType<"$_builder.getType<::mlir::shape::WitnessType>()"> {
|
||||
let typeDescription = [{
|
||||
A witness is a structural device in the compiler to maintain ordering of
|
||||
code relying on information obtained from passing assertions. Witnesses do
|
||||
not represent any physical data.
|
||||
|
||||
"cstr_" operations will return witnesses and be lowered into assertion logic
|
||||
when not resolvable at compile time.
|
||||
|
||||
"assuming_" operations will take witnesses as input and represent only
|
||||
information to the compiler, so they do not exist in executing code. Code
|
||||
that is dependent on "assuming_" operations can assume all cstr operations
|
||||
transitively before are honored as true.
|
||||
|
||||
These abstractions are intended to allow the compiler more freedom with
|
||||
assertions by merely showing the assertion through dataflow at this time
|
||||
rather than a side effecting operation that acts as a barrier. This can be
|
||||
viewed similarly to a compiler representation of promises from asynchronous,
|
||||
possibly crashing assertions. Reliant code will not be reordered to before
|
||||
the code and non-reliant code can be reordered freely, and there are no
|
||||
guarantees on the final ordering of the assertions or their related code.
|
||||
}];
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shape op definitions
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -313,4 +339,123 @@ def Shape_ConcatOp : Shape_Op<"concat",
|
|||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// Shape constraint related ops.
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
//TODO(tpopp): Move the code below and witnesses to a different file.
|
||||
def Shape_AnyOp : Shape_Op<"any",
|
||||
[DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||
let summary = "Return any combination of the input shapes.";
|
||||
let description = [{
|
||||
This operation takes multiple input shapes and returns some combination of
|
||||
their dimensions. This can be best seen with examples below.
|
||||
|
||||
The result is undefined, but still side-effect free, in cases where the
|
||||
inputs have differing ranks or differ in extents of shared dimensions.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
%s0 = shape.any([2,?], [?,3]) // [2,3]
|
||||
%s1 = shape.any([?,?], [1,2]) // [1,2]
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
|
||||
let results = (outs Shape_ShapeType:$result);
|
||||
}
|
||||
|
||||
def Shape_AssumingAllOp : Shape_Op<"assuming_all", []> {
|
||||
let summary = "Return a logical AND of all witnesses.";
|
||||
let description = [{
|
||||
Used to simplify constraints as any single failing precondition is enough
|
||||
to prevent execution.
|
||||
|
||||
"assuming" operations represent an execution order restriction to the
|
||||
compiler, information for dependent code to rely on (by assuming), and
|
||||
nothing else. They should not exist after a program is fully lowered and
|
||||
ready to execute.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
%w0 = shape.cstr_broadcastable([2,2], [3,1,2]) // Success
|
||||
%w1 = shape.cstr_broadcastable([2,2], [3,2]) // Failure
|
||||
%w2 = shape.cstr_eq([1,2], [1,2], [1,2]) // Success
|
||||
%wf = shape.assume_all(%w0, %w1) // Failure
|
||||
%wt = shape.assume_all(%w0, %w2) // Success
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<Shape_WitnessType>:$inputs);
|
||||
let results = (outs Shape_WitnessType:$result);
|
||||
}
|
||||
|
||||
def Shape_AssumingOp : Shape_Op<"assuming",
|
||||
[SingleBlockImplicitTerminator<"AssumingYieldOp">,
|
||||
RecursiveSideEffects]> {
|
||||
let summary = "Execute the region.";
|
||||
let description = [{
|
||||
Executes the region assuming all witnesses are true.
|
||||
|
||||
"assuming" operations represent an execution order restriction to the
|
||||
compiler, information for dependent code to rely on (by assuming), and
|
||||
nothing else. They should not exist after a program is fully lowered and
|
||||
ready to execute.
|
||||
}];
|
||||
let arguments = (ins Shape_WitnessType);
|
||||
let regions = (region SizedRegion<1>:$thenRegion);
|
||||
let results = (outs Variadic<AnyType>:$results);
|
||||
}
|
||||
|
||||
def Shape_AssumingYieldOp : Shape_Op<"assuming_yield", [Terminator]> {
|
||||
let summary = "Yield operation";
|
||||
let description = [{
|
||||
This yield operation represents a return operation within the assert_and_exec
|
||||
region. The operation takes variable number of operands and produces no
|
||||
results. The operand number and types must match the return signature of
|
||||
the region that contains the operation.
|
||||
}];
|
||||
|
||||
let arguments = (ins Variadic<AnyType>:$operands);
|
||||
}
|
||||
|
||||
def Shape_CstrBroadcastableOp : Shape_Op<"cstr_broadcastable", []> {
|
||||
let summary = "Determines if 2 shapes can be successfully broadcasted.";
|
||||
let description = [{
|
||||
Given 2 input shapes, return a witness specifying if they are broadcastable.
|
||||
This broadcastable follows the same logic as what shape.broadcast documents.
|
||||
|
||||
"cstr" operations represent runtime assertions.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
%w0 = shape.cstr_broadcastable([2,2], [3,1,2]) // Success
|
||||
%w1 = shape.cstr_broadcastable([2,2], [3,2]) // Failure
|
||||
```
|
||||
}];
|
||||
|
||||
let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
|
||||
let results = (outs Shape_WitnessType:$result);
|
||||
}
|
||||
|
||||
def Shape_CstrEqOp : Shape_Op<"cstr_eq", []> {
|
||||
let summary = "Determines if all input shapes are equal.";
|
||||
let description = [{
|
||||
Given 1 or more input shapes, determine if all shapes are the exact same.
|
||||
|
||||
"cstr" operations represent runtime assertions.
|
||||
|
||||
Example:
|
||||
```mlir
|
||||
%w0 = shape.cstr_eq([1,2], [1,2], [1,2]) // Success
|
||||
%w1 = shape.cstr_eq([2,2], [1,2]) // Failure
|
||||
```
|
||||
}];
|
||||
let arguments = (ins Variadic<Shape_ShapeType>:$inputs);
|
||||
let results = (outs Shape_WitnessType:$result);
|
||||
}
|
||||
|
||||
|
||||
// Canonicalization patterns.
|
||||
|
||||
#endif // SHAPE_OPS
|
||||
|
|
|
@ -24,7 +24,8 @@ ShapeDialect::ShapeDialect(MLIRContext *context)
|
|||
#define GET_OP_LIST
|
||||
#include "mlir/Dialect/Shape/IR/ShapeOps.cpp.inc"
|
||||
>();
|
||||
addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType>();
|
||||
addTypes<ComponentType, ElementType, ShapeType, SizeType, ValueShapeType,
|
||||
WitnessType>();
|
||||
// Allow unknown operations during prototyping and testing. As the dialect is
|
||||
// still evolving it makes it simple to start with an unregistered ops and
|
||||
// try different variants before actually defining the op.
|
||||
|
@ -60,6 +61,8 @@ Type ShapeDialect::parseType(DialectAsmParser &parser) const {
|
|||
return SizeType::get(getContext());
|
||||
if (keyword == "value_shape")
|
||||
return ValueShapeType::get(getContext());
|
||||
if (keyword == "witness")
|
||||
return WitnessType::get(getContext());
|
||||
|
||||
parser.emitError(parser.getNameLoc(), "unknown shape type: ") << keyword;
|
||||
return Type();
|
||||
|
@ -83,11 +86,27 @@ void ShapeDialect::printType(Type type, DialectAsmPrinter &os) const {
|
|||
case ShapeTypes::ValueShape:
|
||||
os << "value_shape";
|
||||
return;
|
||||
case ShapeTypes::Witness:
|
||||
os << "witness";
|
||||
return;
|
||||
default:
|
||||
llvm_unreachable("unexpected 'shape' type kind");
|
||||
}
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// AnyOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
AnyOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
|
||||
ValueRange operands, DictionaryAttr attributes,
|
||||
RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
inferredReturnTypes.push_back(ShapeType::get(context));
|
||||
return success();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// BroadcastOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -67,3 +67,16 @@ func @test_shape_of(%arg0: tensor<?xf32>) -> !shape.shape {
|
|||
%0 = shape.shape_of %arg0 : tensor<?xf32>
|
||||
return %0 : !shape.shape
|
||||
}
|
||||
|
||||
func @test_constraints() {
|
||||
%0 = shape.const_shape []
|
||||
%1 = shape.const_shape [1, 2, 3]
|
||||
%w0 = "shape.cstr_broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness
|
||||
%w1 = "shape.cstr_eq"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.witness
|
||||
%w3 = "shape.assuming_all"(%w0, %w1) : (!shape.witness, !shape.witness) -> !shape.witness
|
||||
"shape.assuming"(%w3) ( {
|
||||
%2 = "shape.any"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
|
||||
"shape.assuming_yield"(%2) : (!shape.shape) -> ()
|
||||
}) : (!shape.witness) -> !shape.shape
|
||||
return
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue