Make shape matching work for any shaped type.

The current implementation makes some assumptions about what can be a shaped type, which aren't really necessary. It also has strange behavior for types that aren't in the limited set it handles (e.g. dialect-defined types)

    Updated the comment to match the implementation.

    This is partially motivated by the desire to make MemRef a subclass of ShapedType

--

PiperOrigin-RevId: 248859674
This commit is contained in:
Geoffrey Martin-Noble 2019-05-18 05:31:35 -07:00 committed by Mehdi Amini
parent 3de0c7696b
commit 22a8bc6ec3
1 changed files with 11 additions and 14 deletions

View File

@ -767,24 +767,21 @@ LogicalResult OpTrait::impl::verifyAtLeastNResults(Operation *op,
}
/// Returns success if the given two types have the same shape. That is,
/// they are both scalars, or they are both static shaped types with the same
/// dimension specifications. The element type does not matter.
/// they are both scalars (not shaped), or they are both shaped types and at
/// least one is unranked or they have the same shape. The element type does not
/// matter.
static LogicalResult verifyShapeMatch(Type type1, Type type2) {
// Check scalar cases
if (type1.isIntOrIndexOrFloat())
return success(type2.isIntOrIndexOrFloat());
auto sType1 = type1.dyn_cast<ShapedType>();
auto sType2 = type2.dyn_cast<ShapedType>();
// Check unranked tensor cases
if (type1.isa<UnrankedTensorType>() || type2.isa<UnrankedTensorType>())
// Either both or neither type should be shaped.
if (!sType1)
return success(!sType2);
if (sType1.getRank() == -1 || sType2.getRank() == -1)
return success();
// Check normal vector/tensor cases
if (auto sType1 = type1.dyn_cast<ShapedType>()) {
auto sType2 = type2.dyn_cast<ShapedType>();
return success(sType2 && sType1.getShape() == sType2.getShape());
}
return success();
return success(sType1.getShape() == sType2.getShape());
}
LogicalResult OpTrait::impl::verifySameOperandsAndResultShape(Operation *op) {