forked from OSchip/llvm-project
Make shape matching work for any shaped type.
The current implementation makes some assumptions about what can be a shaped type, which aren't really necessary. It also has strange behavior for types that aren't in the limited set it handles (e.g. dialect-defined types) Updated the comment to match the implementation. This is partially motivated by the desire to make MemRef a subclass of ShapedType -- PiperOrigin-RevId: 248859674
This commit is contained in:
parent
3de0c7696b
commit
22a8bc6ec3
|
@ -767,24 +767,21 @@ LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op,
|
|||
}
|
||||
|
||||
/// Returns success if the given two types have the same shape. That is,
|
||||
/// they are both scalars, or they are both static shaped types with the same
|
||||
/// dimension specifications. The element type does not matter.
|
||||
/// they are both scalars (not shaped), or they are both shaped types and at
|
||||
/// least one is unranked or they have the same shape. The element type does not
|
||||
/// matter.
|
||||
static LogicalResult verifyShapeMatch(Type type1, Type type2) {
|
||||
// Check scalar cases
|
||||
if (type1.isIntOrIndexOrFloat())
|
||||
return success(type2.isIntOrIndexOrFloat());
|
||||
auto sType1 = type1.dyn_cast<ShapedType>();
|
||||
auto sType2 = type2.dyn_cast<ShapedType>();
|
||||
|
||||
// Check unranked tensor cases
|
||||
if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>())
|
||||
// Either both or neither type should be shaped.
|
||||
if (!sType1)
|
||||
return success(!sType2);
|
||||
|
||||
if (sType1.getRank() == -1 || sType2.getRank() == -1)
|
||||
return success();
|
||||
|
||||
// Check normal vector/tensor cases
|
||||
if (auto sType1 = type1.dyn_cast<ShapedType>()) {
|
||||
auto sType2 = type2.dyn_cast<ShapedType>();
|
||||
return success(sType2 && sType1.getShape() == sType2.getShape());
|
||||
}
|
||||
|
||||
return success();
|
||||
return success(sType1.getShape() == sType2.getShape());
|
||||
}
|
||||
|
||||
LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {
|
||||
|
|
Loading…
Reference in New Issue