forked from OSchip/llvm-project
Exclude all ShapedType subclasses other than TensorType subclasses from having non-scalar elements.
The current logic assumes that ShapedType indicates a vector or tensor, which will not be true soon when MemRef subclasses ShapedType -- PiperOrigin-RevId: 250586364
This commit is contained in:
parent
17022b1bc5
commit
1c681a7caf
|
@ -147,11 +147,12 @@ int64_t ShapedType::getSizeInBits() const {
|
||||||
if (elementType.isIntOrFloat())
|
if (elementType.isIntOrFloat())
|
||||||
return elementType.getIntOrFloatBitWidth() * getNumElements();
|
return elementType.getIntOrFloatBitWidth() * getNumElements();
|
||||||
|
|
||||||
// Tensors can have vectors and other tensors as elements, vectors cannot.
|
// Tensors can have vectors and other tensors as elements, other shaped types
|
||||||
assert(!isa<VectorType>() && "unsupported vector element type");
|
// cannot.
|
||||||
auto elementShapedType = elementType.dyn_cast<ShapedType>();
|
assert(isa<TensorType>() && "unsupported element type");
|
||||||
assert(elementShapedType && "unsupported tensor element type");
|
assert((elementType.isa<VectorType>() || elementType.isa<TensorType>()) &&
|
||||||
return getNumElements() * elementShapedType.getSizeInBits();
|
"unsupported tensor element type");
|
||||||
|
return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
|
||||||
}
|
}
|
||||||
|
|
||||||
ArrayRef<int64_t> ShapedType::getShape() const {
|
ArrayRef<int64_t> ShapedType::getShape() const {
|
||||||
|
|
Loading…
Reference in New Issue