forked from OSchip/llvm-project
[mlir] Remove the non-templated DenseElementsAttr::getSplatValue
This predates the templated variant, and has been simply forwarding to getSplatValue<Attribute> for some time. Removing this makes the API a bit more uniform, and also helps prevent users from thinking it is "cheap".
This commit is contained in:
parent
7480efd6f0
commit
937e40a8cf
|
@ -353,7 +353,6 @@ public:
|
|||
|
||||
/// Return the splat value for this attribute. This asserts that the attribute
|
||||
/// corresponds to a splat.
|
||||
Attribute getSplatValue() const { return getSplatValue<Attribute>(); }
|
||||
template <typename T>
|
||||
typename std::enable_if<!std::is_base_of<Attribute, T>::value ||
|
||||
std::is_same<Attribute, T>::value,
|
||||
|
|
|
@ -110,7 +110,7 @@ struct constant_int_op_binder {
|
|||
if (type.isa<VectorType, RankedTensorType>()) {
|
||||
if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
|
||||
return attr_value_binder<IntegerAttr>(bind_value)
|
||||
.match(splatAttr.getSplatValue());
|
||||
.match(splatAttr.getSplatValue<Attribute>());
|
||||
}
|
||||
}
|
||||
return false;
|
||||
|
|
|
@ -451,7 +451,7 @@ struct GlobalMemrefOpLowering
|
|||
// For scalar memrefs, the global variable created is of the element type,
|
||||
// so unpack the elements attribute to extract the value.
|
||||
if (type.getRank() == 0)
|
||||
initialValue = elementsAttr.getValues<Attribute>()[0];
|
||||
initialValue = elementsAttr.getSplatValue<Attribute>();
|
||||
}
|
||||
|
||||
uint64_t alignment = global.alignment().getValueOr(0);
|
||||
|
|
|
@ -349,7 +349,8 @@ static void convertConstantOp(arith::ConstantOp op,
|
|||
llvm::DenseMap<Value, Value> &valueMapping) {
|
||||
assert(constantSupportsMMAMatrixType(op));
|
||||
OpBuilder b(op);
|
||||
Attribute splat = op.getValue().cast<SplatElementsAttr>().getSplatValue();
|
||||
Attribute splat =
|
||||
op.getValue().cast<SplatElementsAttr>().getSplatValue<Attribute>();
|
||||
auto scalarConstant =
|
||||
b.create<arith::ConstantOp>(op.getLoc(), splat.getType(), splat);
|
||||
const char *fragType = inferFragType(op);
|
||||
|
|
|
@ -1574,7 +1574,7 @@ static bool isZeroAttribute(Attribute value) {
|
|||
if (auto fpValue = value.dyn_cast<FloatAttr>())
|
||||
return fpValue.getValue().isZero();
|
||||
if (auto splatValue = value.dyn_cast<SplatElementsAttr>())
|
||||
return isZeroAttribute(splatValue.getSplatValue());
|
||||
return isZeroAttribute(splatValue.getSplatValue<Attribute>());
|
||||
if (auto elementsValue = value.dyn_cast<ElementsAttr>())
|
||||
return llvm::all_of(elementsValue.getValues<Attribute>(), isZeroAttribute);
|
||||
if (auto arrayValue = value.dyn_cast<ArrayAttr>())
|
||||
|
|
|
@ -1395,7 +1395,7 @@ OpFoldResult BroadcastOp::fold(ArrayRef<Attribute> operands) {
|
|||
if (operands[0].getType().isIntOrIndexOrFloat())
|
||||
return DenseElementsAttr::get(vectorType, operands[0]);
|
||||
if (auto attr = operands[0].dyn_cast<SplatElementsAttr>())
|
||||
return DenseElementsAttr::get(vectorType, attr.getSplatValue());
|
||||
return DenseElementsAttr::get(vectorType, attr.getSplatValue<Attribute>());
|
||||
return {};
|
||||
}
|
||||
|
||||
|
@ -2212,7 +2212,7 @@ public:
|
|||
if (!dense)
|
||||
return failure();
|
||||
auto newAttr = DenseElementsAttr::get(extractStridedSliceOp.getType(),
|
||||
dense.getSplatValue());
|
||||
dense.getSplatValue<Attribute>());
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(extractStridedSliceOp,
|
||||
newAttr);
|
||||
return success();
|
||||
|
@ -3670,8 +3670,9 @@ public:
|
|||
auto dense = constantOp.getValue().dyn_cast<SplatElementsAttr>();
|
||||
if (!dense)
|
||||
return failure();
|
||||
auto newAttr = DenseElementsAttr::get(
|
||||
shapeCastOp.getType().cast<VectorType>(), dense.getSplatValue());
|
||||
auto newAttr =
|
||||
DenseElementsAttr::get(shapeCastOp.getType().cast<VectorType>(),
|
||||
dense.getSplatValue<Attribute>());
|
||||
rewriter.replaceOpWithNewOp<arith::ConstantOp>(shapeCastOp, newAttr);
|
||||
return success();
|
||||
}
|
||||
|
|
|
@ -139,7 +139,7 @@ convertDenseElementsAttr(Location loc, DenseElementsAttr denseElementsAttr,
|
|||
if (denseElementsAttr.isSplat() &&
|
||||
(type.isa<VectorType>() || hasVectorElementType)) {
|
||||
llvm::Constant *splatValue = LLVM::detail::getLLVMConstant(
|
||||
innermostLLVMType, denseElementsAttr.getSplatValue(), loc,
|
||||
innermostLLVMType, denseElementsAttr.getSplatValue<Attribute>(), loc,
|
||||
moduleTranslation, /*isTopLevel=*/false);
|
||||
llvm::Constant *splatVector =
|
||||
llvm::ConstantDataVector::getSplat(0, splatValue);
|
||||
|
@ -254,8 +254,9 @@ llvm::Constant *mlir::LLVM::detail::getLLVMConstant(
|
|||
isa<llvm::ArrayType, llvm::VectorType>(elementType);
|
||||
llvm::Constant *child = getLLVMConstant(
|
||||
elementType,
|
||||
elementTypeSequential ? splatAttr : splatAttr.getSplatValue(), loc,
|
||||
moduleTranslation, false);
|
||||
elementTypeSequential ? splatAttr
|
||||
: splatAttr.getSplatValue<Attribute>(),
|
||||
loc, moduleTranslation, false);
|
||||
if (!child)
|
||||
return nullptr;
|
||||
if (llvmType->isVectorTy())
|
||||
|
|
Loading…
Reference in New Issue