[mlir][shape] Update meet to handle all size & shape types

Also tighten up return type inference & compatibility functions.

Differential Revision: https://reviews.llvm.org/D130866
This commit is contained in:
Jacques Pienaar 2022-08-10 05:08:24 -07:00
parent 8a4c40bfe8
commit 1f02ad7131
5 changed files with 89 additions and 11 deletions

View File

@ -110,6 +110,11 @@ def Shape_ShapeOrExtentTensorType : AnyTypeOf<[Shape_ShapeType,
def Shape_SizeOrIndexType : AnyTypeOf<[Shape_SizeType, Index], "size or index">;
// Any type representing a shape or size/dim.
def Shape_AnyShapeOrSizeType : AnyTypeOf<
[Shape_SizeOrIndexType, Shape_ShapeOrExtentTensorType],
"any shape or size">;
def Shape_WitnessType : Shape_Type<"Witness", "witness"> {
let description = [{
A witness is a structural device in the compiler to maintain ordering of

View File

@ -406,11 +406,11 @@ def Shape_MaxOp : Shape_Op<"max",
def Shape_MeetOp : Shape_Op<"meet",
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Returns the least general shape.shape of its operands";
let summary = "Returns the least general shape or size of its operands";
let description = [{
An operation that computes the least general shape of input operands.
An operation that computes the least general shape or dim of input operands.
This effectively asserts that corresponding static dimensions are equal.
The behavior is to match each element of the `shape.shape` and propagate the
The behavior is to match each element of the shape/size and propagate the
most restrictive information, returning an invalid shape if there are
contradictory requirements. E.g., using pseudo code
@ -433,9 +433,11 @@ def Shape_MeetOp : Shape_Op<"meet",
```
}];
let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
OptionalAttr<StrAttr>:$error);
let results = (outs Shape_ShapeOrSizeType:$result);
let arguments = (ins
Shape_AnyShapeOrSizeType:$arg0,
Shape_AnyShapeOrSizeType:$arg1,
OptionalAttr<StrAttr>:$error);
let results = (outs Shape_AnyShapeOrSizeType:$result);
let assemblyFormat = [{
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`

View File

@ -1309,7 +1309,53 @@ LogicalResult mlir::shape::MeetOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.assign({operands[0].getType()});
if (operands.empty())
return failure();
auto isShapeType = [](Type arg) {
if (arg.isa<ShapeType>())
return true;
return isExtentTensorType(arg);
};
ValueRange::type_range types = operands.getTypes();
Type acc = types.front();
for (auto t : drop_begin(types)) {
Type l = acc, r = t;
if (!l.isa<ShapeType, SizeType>())
std::swap(l, r);
// Handle sizes, propagate error type if present.
if (l.isa<SizeType>()) {
if (r.isa<SizeType, IndexType>())
acc = l;
else
return emitOptionalError(location, "requires all sizes or shapes");
} else if (l.isa<IndexType>()) {
if (r.isa<IndexType>())
acc = r;
else
return emitOptionalError(location, "requires all sizes or shapes");
} else if (l.isa<ShapeType>()) {
// Handle shapes, propagate error type if present.
if (isShapeType(r))
acc = l;
else
return emitOptionalError(location, "requires all sizes or shapes");
} else if (isExtentTensorType(l)) {
auto rank1 = l.cast<RankedTensorType>().getShape()[0];
auto rank2 = r.cast<RankedTensorType>().getShape()[0];
if (ShapedType::isDynamic(rank1))
acc = l;
else if (ShapedType::isDynamic(rank2))
acc = r;
else if (rank1 != rank2)
return emitOptionalError(location, "unequal shape cardinality");
else
acc = l;
}
}
inferredReturnTypes.assign({acc});
return success();
}
@ -1322,11 +1368,13 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
Type lhs = l.front();
Type rhs = r.front();
if (lhs != rhs)
return false;
if (!lhs.isa<ShapeType, SizeType>())
std::swap(lhs, rhs);
if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
return true;
if (lhs.isa<SizeType>())
return rhs.isa<SizeType, IndexType>();
if (lhs.isa<ShapeType>())
return rhs.isa<ShapeType, TensorType>();
if (succeeded(verifyCompatibleShapes({lhs, rhs})))
return true;

View File

@ -272,3 +272,20 @@ func.func @const_shape() {
%0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
return
}
// -----
func.func @invalid_meet(%arg0 : !shape.shape, %arg1 : index) -> index {
// expected-error@+1 {{requires all sizes or shapes}}
%result = shape.meet %arg0, %arg1 : !shape.shape, index -> index
return %result : index
}
// -----
func.func @invalid_meet(%arg0 : tensor<2xindex>, %arg1 : tensor<3xindex>) -> tensor<?xindex> {
// expected-error@+1 {{unequal shape cardinality}}
%result = shape.meet %arg0, %arg1 : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
return %result : tensor<?xindex>
}

View File

@ -325,3 +325,9 @@ func.func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
!shape.size, !shape.size -> !shape.size
return %2 : !shape.size
}
func.func @meet_index(%arg0 : index, %arg1 : index) -> index {
%result = shape.meet %arg0, %arg1 : index, index -> index
return %result : index
}