forked from OSchip/llvm-project
Switch IntegerAttr to use APInt.
Change the storage type to APInt from int64_t for IntegerAttr (following the change to APFloat storage in FloatAttr). Effectively a direct change from int64_t to 64-bit APInt throughout (the bitwidth hardcoded). This change also adds a getInt convenience method to IntegerAttr and replaces previous getValue calls with getInt calls. While this changes updates the storage type, it does not update all constant folding calls. PiperOrigin-RevId: 221082788
This commit is contained in:
parent
b2f77e1b8f
commit
25e6b541cd
|
@ -144,14 +144,17 @@ public:
|
|||
class IntegerAttr : public Attribute {
|
||||
public:
|
||||
using ImplType = detail::IntegerAttributeStorage;
|
||||
using ValueType = int64_t;
|
||||
using ValueType = APInt;
|
||||
|
||||
IntegerAttr() = default;
|
||||
/* implicit */ IntegerAttr(Attribute::ImplType *ptr);
|
||||
|
||||
static IntegerAttr get(int64_t value, MLIRContext *context);
|
||||
static IntegerAttr get(const APInt &value, MLIRContext *context);
|
||||
|
||||
int64_t getValue() const;
|
||||
APInt getValue() const;
|
||||
// TODO(jpienaar): Change callers to use getValue instead.
|
||||
int64_t getInt() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::Integer; }
|
||||
|
|
|
@ -153,7 +153,7 @@ public:
|
|||
Type type);
|
||||
|
||||
int64_t getValue() const {
|
||||
return getAttrOfType<IntegerAttr>("value").getValue();
|
||||
return getAttrOfType<IntegerAttr>("value").getInt();
|
||||
}
|
||||
|
||||
static bool isClassFor(const Operation *op);
|
||||
|
@ -174,7 +174,7 @@ public:
|
|||
static void build(Builder *builder, OperationState *result, int64_t value);
|
||||
|
||||
int64_t getValue() const {
|
||||
return getAttrOfType<IntegerAttr>("value").getValue();
|
||||
return getAttrOfType<IntegerAttr>("value").getInt();
|
||||
}
|
||||
|
||||
static bool isClassFor(const Operation *op);
|
||||
|
|
|
@ -90,10 +90,9 @@ struct constant_int_op_binder {
|
|||
|
||||
// The matcher that matches a given target constant scalar / vector splat /
|
||||
// tensor splat integer value.
|
||||
template <IntegerAttr::ValueType TargetValue>
|
||||
struct constant_int_value_matcher {
|
||||
template <int64_t TargetValue> struct constant_int_value_matcher {
|
||||
bool match(Operation *op) {
|
||||
IntegerAttr::ValueType value;
|
||||
APInt value;
|
||||
|
||||
return constant_int_op_binder(&value).match(op) && TargetValue == value;
|
||||
}
|
||||
|
|
|
@ -85,6 +85,7 @@ class F32Attr : Attr<F32> {
|
|||
class I32Attr : Attr<I<32>> {
|
||||
let storageType = [{ IntegerAttr }];
|
||||
let returnType = [{ int }];
|
||||
let convertFromStorage = [{ return {0}.getSExtValue(); }];
|
||||
}
|
||||
|
||||
// Class representing a Trait (defined in a C++ file that needs to be included
|
||||
|
|
|
@ -226,7 +226,7 @@ class CmpIOp : public Op<CmpIOp, OpTrait::OperandsAreIntegerLike,
|
|||
public:
|
||||
CmpIPredicate getPredicate() const {
|
||||
return (CmpIPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName())
|
||||
.getValue();
|
||||
.getInt();
|
||||
}
|
||||
|
||||
static StringRef getOperationName() { return "cmpi"; }
|
||||
|
@ -293,8 +293,7 @@ public:
|
|||
|
||||
/// This returns the dimension number that the 'dim' is inspecting.
|
||||
unsigned getIndex() const {
|
||||
return static_cast<unsigned>(
|
||||
getAttrOfType<IntegerAttr>("index").getValue());
|
||||
return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
|
||||
}
|
||||
|
||||
static StringRef getOperationName() { return "dim"; }
|
||||
|
|
|
@ -68,6 +68,7 @@ public:
|
|||
}
|
||||
|
||||
private:
|
||||
// TODO: Change these to operate on APInts too.
|
||||
IntegerAttr
|
||||
constantFoldBinExpr(AffineExpr expr,
|
||||
std::function<uint64_t(int64_t, uint64_t)> op) {
|
||||
|
@ -76,8 +77,7 @@ private:
|
|||
auto rhs = constantFold(binOpExpr.getRHS());
|
||||
if (!lhs || !rhs)
|
||||
return nullptr;
|
||||
return IntegerAttr::get(op(lhs.getValue(), rhs.getValue()),
|
||||
expr.getContext());
|
||||
return IntegerAttr::get(op(lhs.getInt(), rhs.getInt()), expr.getContext());
|
||||
}
|
||||
|
||||
// The number of dimension operands in AffineMap containing this expression.
|
||||
|
|
|
@ -46,8 +46,17 @@ struct BoolAttributeStorage : public AttributeStorage {
|
|||
};
|
||||
|
||||
/// An attribute representing a integral value.
|
||||
struct IntegerAttributeStorage : public AttributeStorage {
|
||||
int64_t value;
|
||||
struct IntegerAttributeStorage final
|
||||
: public AttributeStorage,
|
||||
public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> {
|
||||
const unsigned numBits;
|
||||
size_t numObjects;
|
||||
|
||||
/// Returns an APInt representing the stored value.
|
||||
APInt getValue() const {
|
||||
auto val = APInt(numBits, {getTrailingObjects<uint64_t>(), numObjects});
|
||||
return val;
|
||||
}
|
||||
};
|
||||
|
||||
/// An attribute representing a floating point value.
|
||||
|
|
|
@ -37,10 +37,12 @@ bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
|
|||
|
||||
IntegerAttr::IntegerAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
int64_t IntegerAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->value;
|
||||
APInt IntegerAttr::getValue() const {
|
||||
return static_cast<ImplType *>(attr)->getValue();
|
||||
}
|
||||
|
||||
int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
|
||||
|
||||
FloatAttr::FloatAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
|
||||
|
||||
APFloat FloatAttr::getValue() const {
|
||||
|
|
|
@ -187,6 +187,21 @@ struct FloatAttrKeyInfo : DenseMapInfo<FloatAttributeStorage *> {
|
|||
}
|
||||
};
|
||||
|
||||
struct IntegerAttrKeyInfo : DenseMapInfo<IntegerAttributeStorage *> {
|
||||
// Integer attributes are uniqued based on wrapped APInt.
|
||||
using KeyTy = APInt;
|
||||
using DenseMapInfo<IntegerAttributeStorage *>::getHashValue;
|
||||
using DenseMapInfo<IntegerAttributeStorage *>::isEqual;
|
||||
|
||||
static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
|
||||
|
||||
static bool isEqual(const KeyTy &lhs, const IntegerAttributeStorage *rhs) {
|
||||
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
|
||||
return false;
|
||||
return lhs == rhs->getValue();
|
||||
}
|
||||
};
|
||||
|
||||
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttributeStorage *> {
|
||||
// Array attributes are uniqued based on their elements.
|
||||
using KeyTy = ArrayRef<Attribute>;
|
||||
|
@ -377,7 +392,7 @@ public:
|
|||
|
||||
// Attribute uniquing.
|
||||
BoolAttributeStorage *boolAttrs[2] = {nullptr};
|
||||
DenseMap<int64_t, IntegerAttributeStorage *> integerAttrs;
|
||||
DenseSet<IntegerAttributeStorage *, IntegerAttrKeyInfo> integerAttrs;
|
||||
DenseSet<FloatAttributeStorage *, FloatAttrKeyInfo> floatAttrs;
|
||||
StringMap<StringAttributeStorage *> stringAttrs;
|
||||
using ArrayAttrSet = DenseSet<ArrayAttributeStorage *, ArrayAttrKeyInfo>;
|
||||
|
@ -1040,17 +1055,37 @@ BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
|
|||
return result;
|
||||
}
|
||||
|
||||
IntegerAttr IntegerAttr::get(int64_t value, MLIRContext *context) {
|
||||
auto *&result = context->getImpl().integerAttrs[value];
|
||||
if (result)
|
||||
return result;
|
||||
IntegerAttr IntegerAttr::get(const APInt &value, MLIRContext *context) {
|
||||
auto &impl = context->getImpl();
|
||||
|
||||
result = context->getImpl().allocator.Allocate<IntegerAttributeStorage>();
|
||||
new (result) IntegerAttributeStorage{{Attribute::Kind::Integer,
|
||||
/*isOrContainsFunction=*/false},
|
||||
value};
|
||||
result->value = value;
|
||||
return result;
|
||||
// Look to see if the integer attribute has been created already.
|
||||
auto existing = impl.integerAttrs.insert_as(nullptr, value);
|
||||
|
||||
// If it has been created, return it.
|
||||
if (!existing.second)
|
||||
return *existing.first;
|
||||
|
||||
// If it doesn't, create one and return it.
|
||||
auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords());
|
||||
|
||||
auto byteSize =
|
||||
IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
|
||||
auto rawMem =
|
||||
impl.allocator.Allocate(byteSize, alignof(IntegerAttributeStorage));
|
||||
// TODO: This uses 64 bit APInts by default without consideration of value.
|
||||
auto result = ::new (rawMem) IntegerAttributeStorage{
|
||||
{Attribute::Kind::Integer, /*isOrContainsFunction=*/false},
|
||||
{},
|
||||
/*numBits*/ 64,
|
||||
elements.size()};
|
||||
std::uninitialized_copy(elements.begin(), elements.end(),
|
||||
result->getTrailingObjects<uint64_t>());
|
||||
return *existing.first = result;
|
||||
}
|
||||
|
||||
IntegerAttr IntegerAttr::get(int64_t value, MLIRContext *context) {
|
||||
// TODO: This uses 64 bit APInts by default.
|
||||
return get(APInt(64, value, /*isSigned=*/true), context);
|
||||
}
|
||||
|
||||
FloatAttr FloatAttr::get(double value, MLIRContext *context) {
|
||||
|
|
|
@ -700,13 +700,10 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
|
|||
if (!result.isa<IntegerAttr>())
|
||||
return p.emitError("expected tensor literal element has integer type");
|
||||
auto value = result.cast<IntegerAttr>().getValue();
|
||||
// If we couldn't successfully round trip the value, it means some bits
|
||||
// are truncated and we should give up here.
|
||||
llvm::APInt apint(bitsWidth, (uint64_t)value, /*isSigned=*/true);
|
||||
if (apint.getSExtValue() != value)
|
||||
if (value.getMinSignedBits() > bitsWidth)
|
||||
return p.emitError("tensor literal element has more bits than that "
|
||||
"specified in the type");
|
||||
addToStorage((uint64_t)value);
|
||||
addToStorage(value.getSExtValue());
|
||||
break;
|
||||
}
|
||||
default:
|
||||
|
|
|
@ -538,8 +538,8 @@ bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
|
|||
void CmpIOp::print(OpAsmPrinter *p) const {
|
||||
*p << getOperationName() << " ";
|
||||
|
||||
int predicateValue =
|
||||
getAttrOfType<IntegerAttr>(getPredicateAttrName()).getValue();
|
||||
auto predicateValue =
|
||||
getAttrOfType<IntegerAttr>(getPredicateAttrName()).getInt();
|
||||
assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) &&
|
||||
predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) &&
|
||||
"unknown predicate index");
|
||||
|
@ -561,7 +561,7 @@ bool CmpIOp::verify() const {
|
|||
auto predicateAttr = getAttrOfType<IntegerAttr>(getPredicateAttrName());
|
||||
if (!predicateAttr)
|
||||
return emitOpError("requires an integer attribute named 'predicate'");
|
||||
auto predicate = predicateAttr.getValue();
|
||||
auto predicate = predicateAttr.getInt();
|
||||
if (predicate < (int64_t)CmpIPredicate::FirstValidValue ||
|
||||
predicate >= (int64_t)CmpIPredicate::NumPredicates)
|
||||
return emitOpError("'predicate' attribute value out of range");
|
||||
|
@ -645,7 +645,7 @@ bool DimOp::verify() const {
|
|||
auto indexAttr = getAttrOfType<IntegerAttr>("index");
|
||||
if (!indexAttr)
|
||||
return emitOpError("requires an integer attribute named 'index'");
|
||||
uint64_t index = (uint64_t)indexAttr.getValue();
|
||||
uint64_t index = indexAttr.getValue().getZExtValue();
|
||||
|
||||
auto type = getOperand()->getType();
|
||||
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {
|
||||
|
|
|
@ -380,11 +380,11 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
|
|||
auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
|
||||
for (unsigned i = 1; i < foldedResults.size(); i++) {
|
||||
auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
|
||||
maxOrMin = lower ? std::max(maxOrMin, foldedResult)
|
||||
: std::min(maxOrMin, foldedResult);
|
||||
maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
|
||||
: llvm::APIntOps::smin(maxOrMin, foldedResult);
|
||||
}
|
||||
lower ? forStmt->setConstantLowerBound(maxOrMin)
|
||||
: forStmt->setConstantUpperBound(maxOrMin);
|
||||
lower ? forStmt->setConstantLowerBound(maxOrMin.getSExtValue())
|
||||
: forStmt->setConstantUpperBound(maxOrMin.getSExtValue());
|
||||
|
||||
// Return false on success.
|
||||
return false;
|
||||
|
|
Loading…
Reference in New Issue