From 30bbd910565a1319bf121b0ef87031b8217cf1c2 Mon Sep 17 00:00:00 2001 From: River Riddle Date: Tue, 18 Jun 2019 18:26:26 -0700 Subject: [PATCH] Simplify usages of SplatElementsAttr now that it inherits from DenseElementsAttr. PiperOrigin-RevId: 253910543 --- mlir/include/mlir/IR/Attributes.h | 19 ---------- mlir/include/mlir/IR/Builders.h | 1 - mlir/include/mlir/IR/Matchers.h | 2 +- .../mlir/Quantizer/Support/Statistics.h | 1 - .../QuantOps/Transforms/ConvertConst.cpp | 4 +- .../Dialect/QuantOps/Utils/QuantizeUtils.cpp | 38 +------------------ mlir/lib/IR/Attributes.cpp | 22 ----------- mlir/lib/IR/Builders.cpp | 6 +-- mlir/lib/StandardOps/Ops.cpp | 6 +-- mlir/lib/Target/LLVMIR/ModuleTranslation.cpp | 9 +---- mlir/lib/Transforms/MaterializeVectors.cpp | 2 +- .../QuantOps/QuantizationUtilsTest.cpp | 2 +- 12 files changed, 11 insertions(+), 101 deletions(-) diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 76e57e527e4b..04ba52b8c65a 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -820,25 +820,6 @@ public: class SplatElementsAttr : public DenseElementsAttr { public: using DenseElementsAttr::DenseElementsAttr; - using ValueType = Attribute; - - /// 'type' must be a vector or tensor with static shape. - static SplatElementsAttr get(ShapedType type, Attribute elt); - Attribute getValue() const { return getSplatValue(); } - - /// Generates a new SplatElementsAttr by mapping each int value to a new - /// underlying APInt. The new values can represent either a integer or float. - /// This ElementsAttr should contain integers. - SplatElementsAttr - mapValues(Type newElementType, - llvm::function_ref mapping) const; - - /// Generates a new SplatElementsAttr by mapping each float value to a new - /// underlying APInt. The new values can represent either a integer or float. - /// This ElementsAttr should contain floats. - SplatElementsAttr - mapValues(Type newElementType, - llvm::function_ref mapping) const; /// Method for support type inquiry through isa, cast and dyn_cast. static bool classof(Attribute attr) { diff --git a/mlir/include/mlir/IR/Builders.h b/mlir/include/mlir/IR/Builders.h index 80bdbc35cf4c..c04ca7ad214c 100644 --- a/mlir/include/mlir/IR/Builders.h +++ b/mlir/include/mlir/IR/Builders.h @@ -113,7 +113,6 @@ public: TypeAttr getTypeAttr(Type type); FunctionAttr getFunctionAttr(Function *value); FunctionAttr getFunctionAttr(StringRef value); - ElementsAttr getSplatElementsAttr(ShapedType type, Attribute elt); ElementsAttr getDenseElementsAttr(ShapedType type, ArrayRef values); ElementsAttr getDenseIntElementsAttr(ShapedType type, diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index 61796ff09ab1..4ea1ce2c6210 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -104,7 +104,7 @@ struct constant_int_op_binder { if (type.isa() || type.isa()) { if (auto splatAttr = attr.dyn_cast()) { return attr_value_binder(bind_value) - .match(splatAttr.getValue()); + .match(splatAttr.getSplatValue()); } } return false; diff --git a/mlir/include/mlir/Quantizer/Support/Statistics.h b/mlir/include/mlir/Quantizer/Support/Statistics.h index d4641d66cf2e..c6f059efd796 100644 --- a/mlir/include/mlir/Quantizer/Support/Statistics.h +++ b/mlir/include/mlir/Quantizer/Support/Statistics.h @@ -73,7 +73,6 @@ public: /// DenseFPElementsAttr /// OpaqueElementsAttr (with Float based type) /// SparseElementAttr (with Float based type) -/// SplatElementsAttr class AttributeTensorStatistics : public AbstractTensorStatistics { public: AttributeTensorStatistics(Attribute attr) : attr(attr) {} diff --git a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp index 0c8ba3171aa3..9dcc6df6beac 100644 --- a/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp +++ b/mlir/lib/Dialect/QuantOps/Transforms/ConvertConst.cpp @@ -80,8 +80,8 @@ QuantizedConstRewrite::matchAndRewrite(QuantizeCastOp qbarrier, } // Is the constant value a type expressed in a way that we support? - if (!value.isa() && !value.isa() && - !value.isa() && !value.isa()) { + if (!value.isa() && !value.isa() && + !value.isa()) { return matchFailure(); } diff --git a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp index 850f1224cbb9..7cfedf9412df 100644 --- a/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp +++ b/mlir/lib/Dialect/QuantOps/Utils/QuantizeUtils.cpp @@ -66,36 +66,6 @@ convertDenseFPElementsAttr(DenseFPElementsAttr realFPElementsAttr, return DenseIntElementsAttr::get(newDenseType, quantValues); } -/// Converts a real expressed SplatElementsAttr to a corresponding -/// SplatElementsAttr containing quantized storage values assuming the given -/// quantizedElementType and converter. -static SplatElementsAttr -convertSplatElementsAttr(SplatElementsAttr realSplatAttr, - QuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter) { - // Since the splat just references a single primitive value, use the - // function for converting primitives. - // NOTE: When implementing per-channel, we will need to promote the - // splat to a dense and handle channels individually. - Type unusedPrimitiveType; - auto elementAttr = - convertPrimitiveValueAttr(realSplatAttr.getValue(), quantizedElementType, - converter, unusedPrimitiveType); - if (!elementAttr) { - return nullptr; - } - - // Cast from an expressed-type-based type to storage-type-based type, - // preserving the splat shape (i.e. tensor<4xf32> -> tensor<4xi8>). - ShapedType newSplatType = - quantizedElementType.castExpressedToStorageType(realSplatAttr.getType()) - .dyn_cast_or_null(); - if (!newSplatType) { - return nullptr; - } - return SplatElementsAttr::get(newSplatType, elementAttr); -} - /// Converts a real expressed SplatElementsAttr to a corresponding /// SplatElementsAttr containing quantized storage values assuming the given /// quantizedElementType and converter. @@ -134,13 +104,7 @@ Attribute quantizeAttrUniform(Attribute realValue, const UniformQuantizedValueConverter &converter, Type &outConvertedType) { // Fork to handle different variants of constants supported. - if (realValue.isa()) { - // Splatted tensor or vector constant. - auto converted = convertSplatElementsAttr( - realValue.cast(), quantizedElementType, converter); - outConvertedType = converted.getType(); - return converted; - } else if (realValue.isa()) { + if (realValue.isa()) { // Dense tensor or vector constant. auto converted = convertDenseFPElementsAttr( realValue.cast(), quantizedElementType, converter); diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index ce33508830c5..f4a6cf11bca5 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -940,28 +940,6 @@ Attribute SparseElementsAttr::getValue(ArrayRef index) const { return getValues().getValue(it->second); } -//===----------------------------------------------------------------------===// -// SplatElementsAttr -//===----------------------------------------------------------------------===// - -SplatElementsAttr SplatElementsAttr::get(ShapedType type, Attribute elt) { - return DenseElementsAttr::get(type, elt).cast(); -} - -SplatElementsAttr SplatElementsAttr::mapValues( - Type newElementType, - llvm::function_ref mapping) const { - return DenseElementsAttr::mapValues(newElementType, mapping) - .cast(); -} - -SplatElementsAttr SplatElementsAttr::mapValues( - Type newElementType, - llvm::function_ref mapping) const { - return DenseElementsAttr::mapValues(newElementType, mapping) - .cast(); -} - //===----------------------------------------------------------------------===// // NamedAttributeList //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/Builders.cpp b/mlir/lib/IR/Builders.cpp index 72eaa91211ea..e2c3a55421b6 100644 --- a/mlir/lib/IR/Builders.cpp +++ b/mlir/lib/IR/Builders.cpp @@ -180,10 +180,6 @@ FunctionAttr Builder::getFunctionAttr(StringRef value) { return FunctionAttr::get(value, getContext()); } -ElementsAttr Builder::getSplatElementsAttr(ShapedType type, Attribute elt) { - return SplatElementsAttr::get(type, elt); -} - ElementsAttr Builder::getDenseElementsAttr(ShapedType type, ArrayRef values) { return DenseElementsAttr::get(type, values); @@ -255,7 +251,7 @@ Attribute Builder::getZeroAttr(Type type) { auto element = getZeroAttr(vtType.getElementType()); if (!element) return {}; - return getSplatElementsAttr(vtType, element); + return getDenseElementsAttr(vtType, element); } default: break; diff --git a/mlir/lib/StandardOps/Ops.cpp b/mlir/lib/StandardOps/Ops.cpp index 7e4ce29d1469..9a4a3f26d65b 100644 --- a/mlir/lib/StandardOps/Ops.cpp +++ b/mlir/lib/StandardOps/Ops.cpp @@ -182,11 +182,11 @@ Attribute constFoldBinaryOp(ArrayRef operands, return {}; auto elementResult = constFoldBinaryOp( - {lhs.getValue(), rhs.getValue()}, calculate); + {lhs.getSplatValue(), rhs.getSplatValue()}, calculate); if (!elementResult) return {}; - return SplatElementsAttr::get(lhs.getType(), elementResult); + return DenseElementsAttr::get(lhs.getType(), elementResult); } return {}; } @@ -1614,7 +1614,7 @@ OpFoldResult ExtractElementOp::fold(ArrayRef operands) { // If this is a splat elements attribute, simply return the value. All of the // elements of a splat attribute are the same. if (auto splatAggregate = aggregate.dyn_cast()) - return splatAggregate.getValue(); + return splatAggregate.getSplatValue(); // Otherwise, collect the constant indices into the aggregate. SmallVector indices; diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp index ef9cbe82eb2d..36d04a9ae6c4 100644 --- a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp @@ -92,18 +92,11 @@ llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType, if (auto splatAttr = attr.dyn_cast()) { auto *vectorType = cast(llvmType); auto *child = getLLVMConstant(vectorType->getElementType(), - splatAttr.getValue(), loc); + splatAttr.getSplatValue(), loc); return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child); } if (auto denseAttr = attr.dyn_cast()) { auto *vectorType = cast(llvmType); - if (denseAttr.isSplat()) { - auto *child = getLLVMConstant(vectorType->getElementType(), - denseAttr.getSplatValue(), loc); - return llvm::ConstantVector::getSplat(vectorType->getNumElements(), - child); - } - SmallVector constants; uint64_t numElements = vectorType->getNumElements(); constants.reserve(numElements); diff --git a/mlir/lib/Transforms/MaterializeVectors.cpp b/mlir/lib/Transforms/MaterializeVectors.cpp index 6819a4ef62a2..2204c42dec12 100644 --- a/mlir/lib/Transforms/MaterializeVectors.cpp +++ b/mlir/lib/Transforms/MaterializeVectors.cpp @@ -388,7 +388,7 @@ materializeAttributes(Operation *opInst, VectorType hwVectorType) { SmallVector res; for (auto a : opInst->getAttrs()) { if (auto splat = a.second.dyn_cast()) { - auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue()); + auto attr = SplatElementsAttr::get(hwVectorType, splat.getSplatValue()); res.push_back(NamedAttribute(a.first, attr)); } else { res.push_back(a); diff --git a/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp b/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp index d2b551f02966..d10623e3d1da 100644 --- a/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp +++ b/mlir/unittests/Dialect/QuantOps/QuantizationUtilsTest.cpp @@ -128,7 +128,7 @@ TEST(QuantizationUtilsTest, convertRankedSplatAttrUniform) { IntegerType convertedType = IntegerType::get(8, &ctx); auto quantizedType = getTestQuantizedType(convertedType, &ctx); TestUniformQuantizedValueConverter converter(quantizedType); - auto realValue = getTestElementsAttr( + auto realValue = getTestElementsAttr( &ctx, {1, 2}, getTestFloatAttr(1.0, &ctx)); Type returnedType;