forked from OSchip/llvm-project
[mlir][shape] Update meet to handle all size & shape types
Also tighten up return type inference & compatibility functions. Differential Revision: https://reviews.llvm.org/D130866
This commit is contained in:
parent
8a4c40bfe8
commit
1f02ad7131
mlir
include/mlir/Dialect/Shape/IR
lib/Dialect/Shape/IR
test/Dialect/Shape
|
@ -110,6 +110,11 @@ def Shape_ShapeOrExtentTensorType : AnyTypeOf<[Shape_ShapeType,
|
||||||
|
|
||||||
def Shape_SizeOrIndexType : AnyTypeOf<[Shape_SizeType, Index], "size or index">;
|
def Shape_SizeOrIndexType : AnyTypeOf<[Shape_SizeType, Index], "size or index">;
|
||||||
|
|
||||||
|
// Any type representing a shape or size/dim.
|
||||||
|
def Shape_AnyShapeOrSizeType : AnyTypeOf<
|
||||||
|
[Shape_SizeOrIndexType, Shape_ShapeOrExtentTensorType],
|
||||||
|
"any shape or size">;
|
||||||
|
|
||||||
def Shape_WitnessType : Shape_Type<"Witness", "witness"> {
|
def Shape_WitnessType : Shape_Type<"Witness", "witness"> {
|
||||||
let description = [{
|
let description = [{
|
||||||
A witness is a structural device in the compiler to maintain ordering of
|
A witness is a structural device in the compiler to maintain ordering of
|
||||||
|
|
|
@ -406,11 +406,11 @@ def Shape_MaxOp : Shape_Op<"max",
|
||||||
|
|
||||||
def Shape_MeetOp : Shape_Op<"meet",
|
def Shape_MeetOp : Shape_Op<"meet",
|
||||||
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
[Commutative, DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
|
||||||
let summary = "Returns the least general shape.shape of its operands";
|
let summary = "Returns the least general shape or size of its operands";
|
||||||
let description = [{
|
let description = [{
|
||||||
An operation that computes the least general shape of input operands.
|
An operation that computes the least general shape or dim of input operands.
|
||||||
This effectively asserts that corresponding static dimensions are equal.
|
This effectively asserts that corresponding static dimensions are equal.
|
||||||
The behavior is to match each element of the `shape.shape` and propagate the
|
The behavior is to match each element of the shape/size and propagate the
|
||||||
most restrictive information, returning an invalid shape if there are
|
most restrictive information, returning an invalid shape if there are
|
||||||
contradictory requirements. E.g., using pseudo code
|
contradictory requirements. E.g., using pseudo code
|
||||||
|
|
||||||
|
@ -433,9 +433,11 @@ def Shape_MeetOp : Shape_Op<"meet",
|
||||||
```
|
```
|
||||||
}];
|
}];
|
||||||
|
|
||||||
let arguments = (ins Shape_ShapeOrSizeType:$arg0, Shape_ShapeOrSizeType:$arg1,
|
let arguments = (ins
|
||||||
|
Shape_AnyShapeOrSizeType:$arg0,
|
||||||
|
Shape_AnyShapeOrSizeType:$arg1,
|
||||||
OptionalAttr<StrAttr>:$error);
|
OptionalAttr<StrAttr>:$error);
|
||||||
let results = (outs Shape_ShapeOrSizeType:$result);
|
let results = (outs Shape_AnyShapeOrSizeType:$result);
|
||||||
|
|
||||||
let assemblyFormat = [{
|
let assemblyFormat = [{
|
||||||
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
|
$arg0 `,` $arg1 (`,` `error` `=` $error^)? attr-dict `:`
|
||||||
|
|
|
@ -1309,7 +1309,53 @@ LogicalResult mlir::shape::MeetOp::inferReturnTypes(
|
||||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||||
DictionaryAttr attributes, RegionRange regions,
|
DictionaryAttr attributes, RegionRange regions,
|
||||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||||
inferredReturnTypes.assign({operands[0].getType()});
|
if (operands.empty())
|
||||||
|
return failure();
|
||||||
|
|
||||||
|
auto isShapeType = [](Type arg) {
|
||||||
|
if (arg.isa<ShapeType>())
|
||||||
|
return true;
|
||||||
|
return isExtentTensorType(arg);
|
||||||
|
};
|
||||||
|
|
||||||
|
ValueRange::type_range types = operands.getTypes();
|
||||||
|
Type acc = types.front();
|
||||||
|
for (auto t : drop_begin(types)) {
|
||||||
|
Type l = acc, r = t;
|
||||||
|
if (!l.isa<ShapeType, SizeType>())
|
||||||
|
std::swap(l, r);
|
||||||
|
|
||||||
|
// Handle sizes, propagate error type if present.
|
||||||
|
if (l.isa<SizeType>()) {
|
||||||
|
if (r.isa<SizeType, IndexType>())
|
||||||
|
acc = l;
|
||||||
|
else
|
||||||
|
return emitOptionalError(location, "requires all sizes or shapes");
|
||||||
|
} else if (l.isa<IndexType>()) {
|
||||||
|
if (r.isa<IndexType>())
|
||||||
|
acc = r;
|
||||||
|
else
|
||||||
|
return emitOptionalError(location, "requires all sizes or shapes");
|
||||||
|
} else if (l.isa<ShapeType>()) {
|
||||||
|
// Handle shapes, propagate error type if present.
|
||||||
|
if (isShapeType(r))
|
||||||
|
acc = l;
|
||||||
|
else
|
||||||
|
return emitOptionalError(location, "requires all sizes or shapes");
|
||||||
|
} else if (isExtentTensorType(l)) {
|
||||||
|
auto rank1 = l.cast<RankedTensorType>().getShape()[0];
|
||||||
|
auto rank2 = r.cast<RankedTensorType>().getShape()[0];
|
||||||
|
if (ShapedType::isDynamic(rank1))
|
||||||
|
acc = l;
|
||||||
|
else if (ShapedType::isDynamic(rank2))
|
||||||
|
acc = r;
|
||||||
|
else if (rank1 != rank2)
|
||||||
|
return emitOptionalError(location, "unequal shape cardinality");
|
||||||
|
else
|
||||||
|
acc = l;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
inferredReturnTypes.assign({acc});
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -1322,11 +1368,13 @@ bool mlir::shape::MeetOp::isCompatibleReturnTypes(TypeRange l, TypeRange r) {
|
||||||
Type lhs = l.front();
|
Type lhs = l.front();
|
||||||
Type rhs = r.front();
|
Type rhs = r.front();
|
||||||
|
|
||||||
if (lhs != rhs)
|
if (!lhs.isa<ShapeType, SizeType>())
|
||||||
return false;
|
std::swap(lhs, rhs);
|
||||||
|
|
||||||
if (lhs.isa<SizeType>() || lhs.isa<ShapeType>())
|
if (lhs.isa<SizeType>())
|
||||||
return true;
|
return rhs.isa<SizeType, IndexType>();
|
||||||
|
if (lhs.isa<ShapeType>())
|
||||||
|
return rhs.isa<ShapeType, TensorType>();
|
||||||
|
|
||||||
if (succeeded(verifyCompatibleShapes({lhs, rhs})))
|
if (succeeded(verifyCompatibleShapes({lhs, rhs})))
|
||||||
return true;
|
return true;
|
||||||
|
|
|
@ -272,3 +272,20 @@ func.func @const_shape() {
|
||||||
%0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
|
%0 = shape.const_shape [4, 5, 6] : tensor<2xindex>
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @invalid_meet(%arg0 : !shape.shape, %arg1 : index) -> index {
|
||||||
|
// expected-error@+1 {{requires all sizes or shapes}}
|
||||||
|
%result = shape.meet %arg0, %arg1 : !shape.shape, index -> index
|
||||||
|
return %result : index
|
||||||
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
|
||||||
|
func.func @invalid_meet(%arg0 : tensor<2xindex>, %arg1 : tensor<3xindex>) -> tensor<?xindex> {
|
||||||
|
// expected-error@+1 {{unequal shape cardinality}}
|
||||||
|
%result = shape.meet %arg0, %arg1 : tensor<2xindex>, tensor<3xindex> -> tensor<?xindex>
|
||||||
|
return %result : tensor<?xindex>
|
||||||
|
}
|
||||||
|
|
||||||
|
|
|
@ -325,3 +325,9 @@ func.func @size_lower_bounded_by_constant(%a: !shape.size) -> !shape.size {
|
||||||
!shape.size, !shape.size -> !shape.size
|
!shape.size, !shape.size -> !shape.size
|
||||||
return %2 : !shape.size
|
return %2 : !shape.size
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func.func @meet_index(%arg0 : index, %arg1 : index) -> index {
|
||||||
|
%result = shape.meet %arg0, %arg1 : index, index -> index
|
||||||
|
return %result : index
|
||||||
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue