Exclude all ShapedType subclasses other than TensorType subclasses from having non-scalar elements.

The current logic assumes that ShapedType indicates a vector or tensor, which will not be true soon when MemRef subclasses ShapedType

--

PiperOrigin-RevId: 250586364
This commit is contained in:
Geoffrey Martin-Noble 2019-05-29 16:07:17 -07:00 committed by Mehdi Amini
parent 17022b1bc5
commit 1c681a7caf
1 changed files with 6 additions and 5 deletions

View File

@ -147,11 +147,12 @@ int64_t ShapedType::getSizeInBits() const {
if (elementType.isIntOrFloat()) if (elementType.isIntOrFloat())
return elementType.getIntOrFloatBitWidth() * getNumElements(); return elementType.getIntOrFloatBitWidth() * getNumElements();
// Tensors can have vectors and other tensors as elements, vectors cannot. // Tensors can have vectors and other tensors as elements, other shaped types
assert(!isa<VectorType>() && "unsupported vector element type"); // cannot.
auto elementShapedType = elementType.dyn_cast<ShapedType>(); assert(isa<TensorType>() && "unsupported element type");
assert(elementShapedType && "unsupported tensor element type"); assert((elementType.isa<VectorType>() || elementType.isa<TensorType>()) &&
return getNumElements() * elementShapedType.getSizeInBits(); "unsupported tensor element type");
return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
} }
ArrayRef<int64_t> ShapedType::getShape() const { ArrayRef<int64_t> ShapedType::getShape() const {