Simplify usages of SplatElementsAttr now that it inherits from DenseElementsAttr.

PiperOrigin-RevId: 253910543
This commit is contained in:
River Riddle 2019-06-18 18:26:26 -07:00 committed by Mehdi Amini
parent 18743a33ac
commit 30bbd91056
12 changed files with 11 additions and 101 deletions

View File

@ -820,25 +820,6 @@ public:
class SplatElementsAttr : public DenseElementsAttr {
public:
using DenseElementsAttr::DenseElementsAttr;
using ValueType = Attribute;
/// 'type' must be a vector or tensor with static shape.
static SplatElementsAttr get(ShapedType type, Attribute elt);
Attribute getValue() const { return getSplatValue(); }
/// Generates a new SplatElementsAttr by mapping each int value to a new
/// underlying APInt. The new values can represent either a integer or float.
/// This ElementsAttr should contain integers.
SplatElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const;
/// Generates a new SplatElementsAttr by mapping each float value to a new
/// underlying APInt. The new values can represent either a integer or float.
/// This ElementsAttr should contain floats.
SplatElementsAttr
mapValues(Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const;
/// Method for support type inquiry through isa, cast and dyn_cast.
static bool classof(Attribute attr) {

View File

@ -113,7 +113,6 @@ public:
TypeAttr getTypeAttr(Type type);
FunctionAttr getFunctionAttr(Function *value);
FunctionAttr getFunctionAttr(StringRef value);
ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt);
ElementsAttr getDenseElementsAttr(ShapedType type,
ArrayRef<Attribute> values);
ElementsAttr getDenseIntElementsAttr(ShapedType type,

View File

@ -104,7 +104,7 @@ struct constant_int_op_binder {
if (type.isa<VectorType>() || type.isa<RankedTensorType>()) {
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
return attr_value_binder<IntegerAttr>(bind_value)
.match(splatAttr.getValue());
.match(splatAttr.getSplatValue());
}
}
return false;

View File

@ -73,7 +73,6 @@ public:
/// DenseFPElementsAttr
/// OpaqueElementsAttr (with Float based type)
/// SparseElementAttr (with Float based type)
/// SplatElementsAttr
class AttributeTensorStatistics : public AbstractTensorStatistics {
public:
AttributeTensorStatistics(Attribute attr) : attr(attr) {}

View File

@ -80,8 +80,8 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier,
}
// Is the constant value a type expressed in a way that we support?
if (!value.isa<FloatAttr>() && !value.isa<SplatElementsAttr>() &&
!value.isa<DenseElementsAttr>() && !value.isa<SparseElementsAttr>()) {
if (!value.isa<FloatAttr>() && !value.isa<DenseElementsAttr>() &&
!value.isa<SparseElementsAttr>()) {
return matchFailure();
}

View File

@ -66,36 +66,6 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr,
return DenseIntElementsAttr::get(newDenseType, quantValues);
}
/// Converts a real expressed SplatElementsAttr to a corresponding
/// SplatElementsAttr containing quantized storage values assuming the given
/// quantizedElementType and converter.
static SplatElementsAttr
convertSplatElementsAttr(SplatElementsAttr realSplatAttr,
QuantizedType quantizedElementType,
const UniformQuantizedValueConverter &converter) {
// Since the splat just references a single primitive value, use the
// function for converting primitives.
// NOTE: When implementing per-channel, we will need to promote the
// splat to a dense and handle channels individually.
Type unusedPrimitiveType;
auto elementAttr =
convertPrimitiveValueAttr(realSplatAttr.getValue(), quantizedElementType,
converter, unusedPrimitiveType);
if (!elementAttr) {
return nullptr;
}
// Cast from an expressed-type-based type to storage-type-based type,
// preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>).
ShapedType newSplatType =
quantizedElementType.castExpressedToStorageType(realSplatAttr.getType())
.dyn_cast_or_null<ShapedType>();
if (!newSplatType) {
return nullptr;
}
return SplatElementsAttr::get(newSplatType, elementAttr);
}
/// Converts a real expressed SplatElementsAttr to a corresponding
/// SplatElementsAttr containing quantized storage values assuming the given
/// quantizedElementType and converter.
@ -134,13 +104,7 @@ Attribute quantizeAttrUniform(Attribute realValue,
const UniformQuantizedValueConverter &converter,
Type &outConvertedType) {
// Fork to handle different variants of constants supported.
if (realValue.isa<SplatElementsAttr>()) {
// Splatted tensor or vector constant.
auto converted = convertSplatElementsAttr(
realValue.cast<SplatElementsAttr>(), quantizedElementType, converter);
outConvertedType = converted.getType();
return converted;
} else if (realValue.isa<DenseFPElementsAttr>()) {
if (realValue.isa<DenseFPElementsAttr>()) {
// Dense tensor or vector constant.
auto converted = convertDenseFPElementsAttr(
realValue.cast<DenseFPElementsAttr>(), quantizedElementType, converter);

View File

@ -940,28 +940,6 @@ Attribute SparseElementsAttr::getValue(ArrayRef<uint64_t> index) const {
return getValues().getValue(it->second);
}
//===----------------------------------------------------------------------===//
// SplatElementsAttr
//===----------------------------------------------------------------------===//
SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) {
return DenseElementsAttr::get(type, elt).cast<SplatElementsAttr>();
}
SplatElementsAttr SplatElementsAttr::mapValues(
Type newElementType,
llvm::function_ref<APInt(const APInt &)> mapping) const {
return DenseElementsAttr::mapValues(newElementType, mapping)
.cast<SplatElementsAttr>();
}
SplatElementsAttr SplatElementsAttr::mapValues(
Type newElementType,
llvm::function_ref<APInt(const APFloat &)> mapping) const {
return DenseElementsAttr::mapValues(newElementType, mapping)
.cast<SplatElementsAttr>();
}
//===----------------------------------------------------------------------===//
// NamedAttributeList
//===----------------------------------------------------------------------===//

View File

@ -180,10 +180,6 @@ FunctionAttr Builder::getFunctionAttr(StringRef value) {
return FunctionAttr::get(value, getContext());
}
ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) {
return SplatElementsAttr::get(type, elt);
}
ElementsAttr Builder::getDenseElementsAttr(ShapedType type,
ArrayRef<Attribute> values) {
return DenseElementsAttr::get(type, values);
@ -255,7 +251,7 @@ Attribute Builder::getZeroAttr(Type type) {
auto element = getZeroAttr(vtType.getElementType());
if (!element)
return {};
return getSplatElementsAttr(vtType, element);
return getDenseElementsAttr(vtType, element);
}
default:
break;

View File

@ -182,11 +182,11 @@ Attribute constFoldBinaryOp(ArrayRef<Attribute> operands,
return {};
auto elementResult = constFoldBinaryOp<AttrElementT>(
{lhs.getValue(), rhs.getValue()}, calculate);
{lhs.getSplatValue(), rhs.getSplatValue()}, calculate);
if (!elementResult)
return {};
return SplatElementsAttr::get(lhs.getType(), elementResult);
return DenseElementsAttr::get(lhs.getType(), elementResult);
}
return {};
}
@ -1614,7 +1614,7 @@ OpFoldResult ExtractElementOp::fold(ArrayRef<Attribute> operands) {
// If this is a splat elements attribute, simply return the value. All of the
// elements of a splat attribute are the same.
if (auto splatAggregate = aggregate.dyn_cast<SplatElementsAttr>())
return splatAggregate.getValue();
return splatAggregate.getSplatValue();
// Otherwise, collect the constant indices into the aggregate.
SmallVector<uint64_t, 8> indices;

View File

@ -92,18 +92,11 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
auto *vectorType = cast<llvm::VectorType>(llvmType);
auto *child = getLLVMConstant(vectorType->getElementType(),
splatAttr.getValue(), loc);
splatAttr.getSplatValue(), loc);
return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child);
}
if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
auto *vectorType = cast<llvm::VectorType>(llvmType);
if (denseAttr.isSplat()) {
auto *child = getLLVMConstant(vectorType->getElementType(),
denseAttr.getSplatValue(), loc);
return llvm::ConstantVector::getSplat(vectorType->getNumElements(),
child);
}
SmallVector<llvm::Constant *, 8> constants;
uint64_t numElements = vectorType->getNumElements();
constants.reserve(numElements);

View File

@ -388,7 +388,7 @@ materializeAttributes(Operation *opInst, VectorType hwVectorType) {
SmallVector<NamedAttribute, 1> res;
for (auto a : opInst->getAttrs()) {
if (auto splat = a.second.dyn_cast<SplatElementsAttr>()) {
auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue());
auto attr = SplatElementsAttr::get(hwVectorType, splat.getSplatValue());
res.push_back(NamedAttribute(a.first, attr));
} else {
res.push_back(a);

View File

@ -128,7 +128,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) {
IntegerType convertedType = IntegerType::get(8, &ctx);
auto quantizedType = getTestQuantizedType(convertedType, &ctx);
TestUniformQuantizedValueConverter converter(quantizedType);
auto realValue = getTestElementsAttr<SplatElementsAttr, Attribute>(
auto realValue = getTestElementsAttr<DenseElementsAttr, Attribute>(
&ctx, {1, 2}, getTestFloatAttr(1.0, &ctx));
Type returnedType;