[mlir][shape] Update meet to handle all size & shape types
Also tighten up return type inference & compatibility functions.

Differential Revision: https://reviews.llvm.org/D130866
commit 1f02ad7131
parent 8a4c40bfe8
@@ -110,6 +110,11 @@ def Shape_ShapeOrExtentTensorType : AnyTypeOf<[Shape_ShapeType,
 def Shape_SizeOrIndexType : AnyTypeOf<[Shape_SizeType, Index], "size or index">;
 
+// Any type representing a shape or size/dim.
+def Shape_AnyShapeOrSizeType : AnyTypeOf<
+    [Shape_SizeOrIndexType, Shape_ShapeOrExtentTensorType],
+    "any shape or size">;
+
 def Shape_WitnessType : Shape_Type<"Witness", "witness"> {
   let description = [{
     A witness is a structural device in the compiler to maintain ordering of
@@ -406,11 +406,11 @@ def Shape_MaxOp : Shape_Op<"max",
 def Shape_MeetOp : Shape_Op<"meet",
     [Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
-  let summary = "Returns the least general shape.shape of its operands";
+  let summary = "Returns the least general shape or size of its operands";
   let description = [{
-    An operation that computes the least general shape of input operands.
+    An operation that computes the least general shape or dim of input operands.
     This effectively asserts that corresponding static dimensions are equal.
-    The behavior is to match each element of the `shape.shape` and propagate the
+    The behavior is to match each element of the shape/size and propagate the
     most restrictive information, returning an invalid shape if there are
     contradictory requirements. E.g., using pseudo code
@@ -433,9 +433,11 @@ def Shape_MeetOp : Shape_Op<"meet",
     ```
   }];
 
-  let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
+  let arguments = (ins
+    Shape_AnyShapeOrSizeType:$arg0,
+    Shape_AnyShapeOrSizeType:$arg1,
     OptionalAttr<StrAttr>:$error);
-  let results = (outs Shape_ShapeOrSizeType:$result);
+  let results = (outs Shape_AnyShapeOrSizeType:$result);
 
   let assemblyFormat = [{
     $arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
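To make the relaxed signature concrete, here is a minimal sketch of operand mixes the new `Shape_AnyShapeOrSizeType` constraint accepts, written in the same textual syntax as the tests below (the function and value names are illustrative, not part of the patch):

```mlir
// Illustrative only: sizes may now meet indices, and shapes may meet
// 1-D extent tensors, all through the same op.
func.func @meet_mixed(%size : !shape.size, %idx : index,
                      %shape : !shape.shape, %extents : tensor<?xindex>) {
  %s = shape.meet %size, %idx : !shape.size, index -> !shape.size
  %t = shape.meet %shape, %extents : !shape.shape, tensor<?xindex> -> !shape.shape
  return
}
```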
@@ -1309,7 +1309,53 @@ LogicalResult mlir::shape::MeetOp::inferReturnTypes(
     MLIRContext *context, Optional<Location> location, ValueRange operands,
     DictionaryAttr attributes, RegionRange regions,
     SmallVectorImpl<Type> &inferredReturnTypes) {
-  inferredReturnTypes.assign({operands[0].getType()});
+  if (operands.empty())
+    return failure();
+
+  auto isShapeType = [](Type arg) {
+    if (arg.isa<ShapeType>())
+      return true;
+    return isExtentTensorType(arg);
+  };
+
+  ValueRange::type_range types = operands.getTypes();
+  Type acc = types.front();
+  for (auto t : drop_begin(types)) {
+    Type l = acc, r = t;
+    if (!l.isa<ShapeType, SizeType>())
+      std::swap(l, r);
+
+    // Handle sizes, propagate error type if present.
+    if (l.isa<SizeType>()) {
+      if (r.isa<SizeType, IndexType>())
+        acc = l;
+      else
+        return emitOptionalError(location, "requires all sizes or shapes");
+    } else if (l.isa<IndexType>()) {
+      if (r.isa<IndexType>())
+        acc = r;
+      else
+        return emitOptionalError(location, "requires all sizes or shapes");
+    } else if (l.isa<ShapeType>()) {
+      // Handle shapes, propagate error type if present.
+      if (isShapeType(r))
+        acc = l;
+      else
+        return emitOptionalError(location, "requires all sizes or shapes");
+    } else if (isExtentTensorType(l)) {
+      auto rank1 = l.cast<RankedTensorType>().getShape()[0];
+      auto rank2 = r.cast<RankedTensorType>().getShape()[0];
+      if (ShapedType::isDynamic(rank1))
+        acc = l;
+      else if (ShapedType::isDynamic(rank2))
+        acc = r;
+      else if (rank1 != rank2)
+        return emitOptionalError(location, "unequal shape cardinality");
+      else
+        acc = l;
+    }
+  }
+  inferredReturnTypes.assign({acc});
   return success();
 }
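Read as a lattice walk: within the size category the `!shape.size` type absorbs `index`, within the shape category `!shape.shape` absorbs extent tensors, and between two extent tensors a dynamically sized one absorbs a statically sized one, while unequal static sizes are rejected. A sketch of the inferred types, derived from the code above (function names hypothetical):

```mlir
// !shape.size meet index infers !shape.size.
func.func @infer_size(%a : !shape.size, %b : index) -> !shape.size {
  %r = shape.meet %a, %b : !shape.size, index -> !shape.size
  return %r : !shape.size
}

// A dynamically sized extent tensor absorbs a statically sized one.
func.func @infer_extents(%a : tensor<?xindex>, %b : tensor<3xindex>) -> tensor<?xindex> {
  %r = shape.meet %a, %b : tensor<?xindex>, tensor<3xindex> -> tensor<?xindex>
  return %r : tensor<?xindex>
}
```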
@@ -1322,11 +1368,13 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
   Type lhs = l.front();
   Type rhs = r.front();
 
-  if (lhs != rhs)
-    return false;
+  if (!lhs.isa<ShapeType, SizeType>())
+    std::swap(lhs, rhs);
 
-  if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
-    return true;
+  if (lhs.isa<SizeType>())
+    return rhs.isa<SizeType, IndexType>();
+  if (lhs.isa<ShapeType>())
+    return rhs.isa<ShapeType, TensorType>();
 
+  if (succeeded(verifyCompatibleShapes({lhs, rhs})))
+    return true;
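One consequence of the new check: the declared result type no longer has to equal the inferred type, only to be compatible with it. A sketch under that assumption (hypothetical function name):

```mlir
// The inferred result type here is !shape.size, but declaring index is
// accepted because size and index are treated as compatible.
func.func @meet_declared_index(%a : !shape.size, %b : index) -> index {
  %r = shape.meet %a, %b : !shape.size, index -> index
  return %r : index
}
```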
@@ -272,3 +272,20 @@ func.func @const_shape() {
   %0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
   return
 }
+
+// -----
+
+func.func @invalid_meet(%arg0 : !shape.shape, %arg1 : index) -> index {
+  // expected-error@+1 {{requires all sizes or shapes}}
+  %result = shape.meet %arg0, %arg1 : !shape.shape, index -> index
+  return %result : index
+}
+
+// -----
+
+func.func @invalid_meet(%arg0 : tensor<2xindex>, %arg1 : tensor<3xindex>) -> tensor<?xindex> {
+  // expected-error@+1 {{unequal shape cardinality}}
+  %result = shape.meet %arg0, %arg1 : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
+  return %result : tensor<?xindex>
+}
@@ -325,3 +325,9 @@ func.func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
       !shape.size, !shape.size -> !shape.size
   return %2 : !shape.size
 }
+
+func.func @meet_index(%arg0 : index, %arg1 : index) -> index {
+  %result = shape.meet %arg0, %arg1 : index, index -> index
+  return %result : index
+}