forked from OSchip/llvm-project
Simplify usages of SplatElementsAttr now that it inherits from DenseElementsAttr.
PiperOrigin-RevId: 253910543
This commit is contained in:
parent
18743a33ac
commit
30bbd91056
|
@ -820,25 +820,6 @@ public:
|
|||
class SplatElementsAttr : public DenseElementsAttr {
|
||||
public:
|
||||
using DenseElementsAttr::DenseElementsAttr;
|
||||
using ValueType = Attribute;
|
||||
|
||||
/// 'type' must be a vector or tensor with static shape.
|
||||
static SplatElementsAttr get(ShapedType type, Attribute elt);
|
||||
Attribute getValue() const { return getSplatValue(); }
|
||||
|
||||
/// Generates a new SplatElementsAttr by mapping each int value to a new
|
||||
/// underlying APInt. The new values can represent either a integer or float.
|
||||
/// This ElementsAttr should contain integers.
|
||||
SplatElementsAttr
|
||||
mapValues(Type newElementType,
|
||||
llvm::function_ref<APInt(const APInt &)> mapping) const;
|
||||
|
||||
/// Generates a new SplatElementsAttr by mapping each float value to a new
|
||||
/// underlying APInt. The new values can represent either a integer or float.
|
||||
/// This ElementsAttr should contain floats.
|
||||
SplatElementsAttr
|
||||
mapValues(Type newElementType,
|
||||
llvm::function_ref<APInt(const APFloat &)> mapping) const;
|
||||
|
||||
/// Method for support type inquiry through isa, cast and dyn_cast.
|
||||
static bool classof(Attribute attr) {
|
||||
|
|
|
@ -113,7 +113,6 @@ public:
|
|||
TypeAttr getTypeAttr(Type type);
|
||||
FunctionAttr getFunctionAttr(Function *value);
|
||||
FunctionAttr getFunctionAttr(StringRef value);
|
||||
ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
|
||||
ElementsAttr getDenseElementsAttr(ShapedType type,
|
||||
ArrayRef<Attribute> values);
|
||||
ElementsAttr getDenseIntElementsAttr(ShapedType type,
|
||||
|
|
|
@ -104,7 +104,7 @@ struct constant_int_op_binder {
|
|||
if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
|
||||
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
|
||||
return attr_value_binder<IntegerAttr>(bind_value)
|
||||
.match(splatAttr.getValue());
|
||||
.match(splatAttr.getSplatValue());
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
|
|
@ -73,7 +73,6 @@ public:
|
|||
/// DenseFPElementsAttr
|
||||
/// OpaqueElementsAttr (with Float based type)
|
||||
/// SparseElementAttr (with Float based type)
|
||||
/// SplatElementsAttr
|
||||
class AttributeTensorStatistics : public AbstractTensorStatistics {
|
||||
public:
|
||||
AttributeTensorStatistics(Attribute attr) : attr(attr) {}
|
||||
|
|
|
@ -80,8 +80,8 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
|
|||
}
|
||||
|
||||
// Is the constant value a type expressed in a way that we support?
|
||||
if (!value.isa<FloatAttr>() && !value.isa<SplatElementsAttr>() &&
|
||||
!value.isa<DenseElementsAttr>() && !value.isa<SparseElementsAttr>()) {
|
||||
if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
|
||||
!value.isa<SparseElementsAttr>()) {
|
||||
return matchFailure();
|
||||
}
|
||||
|
||||
|
|
|
@ -66,36 +66,6 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
|
|||
return DenseIntElementsAttr::get(newDenseType, quantValues);
|
||||
}
|
||||
|
||||
/// Converts a real expressed SplatElementsAttr to a corresponding
|
||||
/// SplatElementsAttr containing quantized storage values assuming the given
|
||||
/// quantizedElementType and converter.
|
||||
static SplatElementsAttr
|
||||
convertSplatElementsAttr(SplatElementsAttr realSplatAttr,
|
||||
QuantizedType quantizedElementType,
|
||||
const UniformQuantizedValueConverter &converter) {
|
||||
// Since the splat just references a single primitive value, use the
|
||||
// function for converting primitives.
|
||||
// NOTE: When implementing per-channel, we will need to promote the
|
||||
// splat to a dense and handle channels individually.
|
||||
Type unusedPrimitiveType;
|
||||
auto elementAttr =
|
||||
convertPrimitiveValueAttr(realSplatAttr.getValue(), quantizedElementType,
|
||||
converter, unusedPrimitiveType);
|
||||
if (!elementAttr) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
// Cast from an expressed-type-based type to storage-type-based type,
|
||||
// preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>).
|
||||
ShapedType newSplatType =
|
||||
quantizedElementType.castExpressedToStorageType(realSplatAttr.getType())
|
||||
.dyn_cast_or_null<ShapedType>();
|
||||
if (!newSplatType) {
|
||||
return nullptr;
|
||||
}
|
||||
return SplatElementsAttr::get(newSplatType, elementAttr);
|
||||
}
|
||||
|
||||
/// Converts a real expressed SplatElementsAttr to a corresponding
|
||||
/// SplatElementsAttr containing quantized storage values assuming the given
|
||||
/// quantizedElementType and converter.
|
||||
|
@ -134,13 +104,7 @@ Attribute quantizeAttrUniform(Attribute realValue,
|
|||
const UniformQuantizedValueConverter &converter,
|
||||
Type &outConvertedType) {
|
||||
// Fork to handle different variants of constants supported.
|
||||
if (realValue.isa<SplatElementsAttr>()) {
|
||||
// Splatted tensor or vector constant.
|
||||
auto converted = convertSplatElementsAttr(
|
||||
realValue.cast<SplatElementsAttr>(), quantizedElementType, converter);
|
||||
outConvertedType = converted.getType();
|
||||
return converted;
|
||||
} else if (realValue.isa<DenseFPElementsAttr>()) {
|
||||
if (realValue.isa<DenseFPElementsAttr>()) {
|
||||
// Dense tensor or vector constant.
|
||||
auto converted = convertDenseFPElementsAttr(
|
||||
realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);
|
||||
|
|
|
@ -940,28 +940,6 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
return getValues().getValue(it->second);
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// SplatElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) {
|
||||
return DenseElementsAttr::get(type, elt).cast<SplatElementsAttr>();
|
||||
}
|
||||
|
||||
SplatElementsAttr SplatElementsAttr::mapValues(
|
||||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APInt &)> mapping) const {
|
||||
return DenseElementsAttr::mapValues(newElementType, mapping)
|
||||
.cast<SplatElementsAttr>();
|
||||
}
|
||||
|
||||
SplatElementsAttr SplatElementsAttr::mapValues(
|
||||
Type newElementType,
|
||||
llvm::function_ref<APInt(const APFloat &)> mapping) const {
|
||||
return DenseElementsAttr::mapValues(newElementType, mapping)
|
||||
.cast<SplatElementsAttr>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// NamedAttributeList
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
|
|
@ -180,10 +180,6 @@ FunctionAttr Builder::getFunctionAttr(StringRef value) {
|
|||
return FunctionAttr::get(value, getContext());
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) {
|
||||
return SplatElementsAttr::get(type, elt);
|
||||
}
|
||||
|
||||
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
|
||||
ArrayRef<Attribute> values) {
|
||||
return DenseElementsAttr::get(type, values);
|
||||
|
@ -255,7 +251,7 @@ Attribute Builder::getZeroAttr(Type type) {
|
|||
auto element = getZeroAttr(vtType.getElementType());
|
||||
if (!element)
|
||||
return {};
|
||||
return getSplatElementsAttr(vtType, element);
|
||||
return getDenseElementsAttr(vtType, element);
|
||||
}
|
||||
default:
|
||||
break;
|
||||
|
|
|
@ -182,11 +182,11 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
|
|||
return {};
|
||||
|
||||
auto elementResult = constFoldBinaryOp<AttrElementT>(
|
||||
{lhs.getValue(), rhs.getValue()}, calculate);
|
||||
{lhs.getSplatValue(), rhs.getSplatValue()}, calculate);
|
||||
if (!elementResult)
|
||||
return {};
|
||||
|
||||
return SplatElementsAttr::get(lhs.getType(), elementResult);
|
||||
return DenseElementsAttr::get(lhs.getType(), elementResult);
|
||||
}
|
||||
return {};
|
||||
}
|
||||
|
@ -1614,7 +1614,7 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
|
|||
// If this is a splat elements attribute, simply return the value. All of the
|
||||
// elements of a splat attribute are the same.
|
||||
if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
|
||||
return splatAggregate.getValue();
|
||||
return splatAggregate.getSplatValue();
|
||||
|
||||
// Otherwise, collect the constant indices into the aggregate.
|
||||
SmallVector<uint64_t, 8> indices;
|
||||
|
|
|
@ -92,18 +92,11 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
|
|||
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
|
||||
auto *vectorType = cast<llvm::VectorType>(llvmType);
|
||||
auto *child = getLLVMConstant(vectorType->getElementType(),
|
||||
splatAttr.getValue(), loc);
|
||||
splatAttr.getSplatValue(), loc);
|
||||
return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child);
|
||||
}
|
||||
if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
|
||||
auto *vectorType = cast<llvm::VectorType>(llvmType);
|
||||
if (denseAttr.isSplat()) {
|
||||
auto *child = getLLVMConstant(vectorType->getElementType(),
|
||||
denseAttr.getSplatValue(), loc);
|
||||
return llvm::ConstantVector::getSplat(vectorType->getNumElements(),
|
||||
child);
|
||||
}
|
||||
|
||||
SmallVector<llvm::Constant *, 8> constants;
|
||||
uint64_t numElements = vectorType->getNumElements();
|
||||
constants.reserve(numElements);
|
||||
|
|
|
@ -388,7 +388,7 @@ materializeAttributes(Operation *opInst, VectorType hwVectorType) {
|
|||
SmallVector<NamedAttribute, 1> res;
|
||||
for (auto a : opInst->getAttrs()) {
|
||||
if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) {
|
||||
auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue());
|
||||
auto attr = SplatElementsAttr::get(hwVectorType, splat.getSplatValue());
|
||||
res.push_back(NamedAttribute(a.first, attr));
|
||||
} else {
|
||||
res.push_back(a);
|
||||
|
|
|
@ -128,7 +128,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
|
|||
IntegerType convertedType = IntegerType::get(8, &ctx);
|
||||
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
|
||||
TestUniformQuantizedValueConverter converter(quantizedType);
|
||||
auto realValue = getTestElementsAttr<SplatElementsAttr, Attribute>(
|
||||
auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
|
||||
&ctx, {1, 2}, getTestFloatAttr(1.0, &ctx));
|
||||
|
||||
Type returnedType;
|
||||
|
|
Loading…
Reference in New Issue