forked from OSchip/llvm-project
Exclude all ShapedType subclasses other than TensorType subclasses from having non-scalar elements.
The current logic assumes that ShapedType indicates a vector or tensor, which will not be true soon when MemRef subclasses ShapedType -- PiperOrigin-RevId: 250586364
This commit is contained in:
parent
17022b1bc5
commit
1c681a7caf
|
@ -147,11 +147,12 @@ int64_t ShapedType::getSizeInBits() const {
|
|||
if (elementType.isIntOrFloat())
|
||||
return elementType.getIntOrFloatBitWidth() * getNumElements();
|
||||
|
||||
// Tensors can have vectors and other tensors as elements, vectors cannot.
|
||||
assert(!isa<VectorType>() && "unsupported vector element type");
|
||||
auto elementShapedType = elementType.dyn_cast<ShapedType>();
|
||||
assert(elementShapedType && "unsupported tensor element type");
|
||||
return getNumElements() * elementShapedType.getSizeInBits();
|
||||
// Tensors can have vectors and other tensors as elements, other shaped types
|
||||
// cannot.
|
||||
assert(isa<TensorType>() && "unsupported element type");
|
||||
assert((elementType.isa<VectorType>() || elementType.isa<TensorType>()) &&
|
||||
"unsupported tensor element type");
|
||||
return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
|
||||
}
|
||||
|
||||
ArrayRef<int64_t> ShapedType::getShape() const {
|
||||
|
|
Loading…
Reference in New Issue