[mlir] Change DenseArrayAttr to TensorType

Previously, DenseArrayAttr used VectorType for its shaped type.
VectorType is problematic for arrays because it doesn't support zero
dimensions, meaning that an empty array would have `vector<i32>` as its
type. ElementsAttr would think that an empty dense array is size 1, not
0. This patch switches over to TensorType, which does support zero
dimensions.

Fixes #56860

Reviewed By: mehdi_amini

Differential Revision: https://reviews.llvm.org/D130921
This commit is contained in:
Jeff Niu 2022-08-01 14:20:26 -04:00
parent 7a4902a0cc
commit ff52ad796c
2 changed files with 9 additions and 9 deletions

View File

@ -1880,7 +1880,7 @@ void AsmPrinter::Impl::printAttribute(Attribute attr,
os << "[:f64"; os << "[:f64";
break; break;
} }
if (denseArrayAttr.getType().getRank()) if (denseArrayAttr.size())
os << " "; os << " ";
denseArrayAttr.printWithoutBraces(os); denseArrayAttr.printWithoutBraces(os);
os << "]"; os << "]";

View File

@ -884,7 +884,7 @@ struct denseArrayAttrEltTypeBuilder<int8_t> {
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8; constexpr static auto eltType = DenseArrayBaseAttr::EltType::I8;
static ShapedType getShapedType(MLIRContext *context, static ShapedType getShapedType(MLIRContext *context,
ArrayRef<int64_t> shape) { ArrayRef<int64_t> shape) {
return VectorType::get(shape, IntegerType::get(context, 8)); return RankedTensorType::get(shape, IntegerType::get(context, 8));
} }
}; };
template <> template <>
@ -892,7 +892,7 @@ struct denseArrayAttrEltTypeBuilder<int16_t> {
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16; constexpr static auto eltType = DenseArrayBaseAttr::EltType::I16;
static ShapedType getShapedType(MLIRContext *context, static ShapedType getShapedType(MLIRContext *context,
ArrayRef<int64_t> shape) { ArrayRef<int64_t> shape) {
return VectorType::get(shape, IntegerType::get(context, 16)); return RankedTensorType::get(shape, IntegerType::get(context, 16));
} }
}; };
template <> template <>
@ -900,7 +900,7 @@ struct denseArrayAttrEltTypeBuilder<int32_t> {
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32; constexpr static auto eltType = DenseArrayBaseAttr::EltType::I32;
static ShapedType getShapedType(MLIRContext *context, static ShapedType getShapedType(MLIRContext *context,
ArrayRef<int64_t> shape) { ArrayRef<int64_t> shape) {
return VectorType::get(shape, IntegerType::get(context, 32)); return RankedTensorType::get(shape, IntegerType::get(context, 32));
} }
}; };
template <> template <>
@ -908,7 +908,7 @@ struct denseArrayAttrEltTypeBuilder<int64_t> {
constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64; constexpr static auto eltType = DenseArrayBaseAttr::EltType::I64;
static ShapedType getShapedType(MLIRContext *context, static ShapedType getShapedType(MLIRContext *context,
ArrayRef<int64_t> shape) { ArrayRef<int64_t> shape) {
return VectorType::get(shape, IntegerType::get(context, 64)); return RankedTensorType::get(shape, IntegerType::get(context, 64));
} }
}; };
template <> template <>
@ -916,7 +916,7 @@ struct denseArrayAttrEltTypeBuilder<float> {
constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32; constexpr static auto eltType = DenseArrayBaseAttr::EltType::F32;
static ShapedType getShapedType(MLIRContext *context, static ShapedType getShapedType(MLIRContext *context,
ArrayRef<int64_t> shape) { ArrayRef<int64_t> shape) {
return VectorType::get(shape, Float32Type::get(context)); return RankedTensorType::get(shape, Float32Type::get(context));
} }
}; };
template <> template <>
@ -924,7 +924,7 @@ struct denseArrayAttrEltTypeBuilder<double> {
constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64; constexpr static auto eltType = DenseArrayBaseAttr::EltType::F64;
static ShapedType getShapedType(MLIRContext *context, static ShapedType getShapedType(MLIRContext *context,
ArrayRef<int64_t> shape) { ArrayRef<int64_t> shape) {
return VectorType::get(shape, Float64Type::get(context)); return RankedTensorType::get(shape, Float64Type::get(context));
} }
}; };
} // namespace } // namespace
@ -934,8 +934,8 @@ template <typename T>
DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context, DenseArrayAttr<T> DenseArrayAttr<T>::get(MLIRContext *context,
ArrayRef<T> content) { ArrayRef<T> content) {
auto size = static_cast<int64_t>(content.size()); auto size = static_cast<int64_t>(content.size());
auto shapedType = denseArrayAttrEltTypeBuilder<T>::getShapedType( auto shapedType =
context, size ? ArrayRef<int64_t>{size} : ArrayRef<int64_t>{}); denseArrayAttrEltTypeBuilder<T>::getShapedType(context, size);
auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType; auto eltType = denseArrayAttrEltTypeBuilder<T>::eltType;
auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()), auto rawArray = ArrayRef<char>(reinterpret_cast<const char *>(content.data()),
content.size() * sizeof(T)); content.size() * sizeof(T));