Switch IntegerAttr to use APInt.

Change the storage type to APInt from int64_t for IntegerAttr (following the change to APFloat storage in FloatAttr). Effectively a direct change from int64_t to 64-bit APInt throughout (the bitwidth hardcoded). This change also adds a getInt convenience method to IntegerAttr and replaces previous getValue calls with getInt calls.

While this changes updates the storage type, it does not update all constant folding calls.

PiperOrigin-RevId: 221082788
This commit is contained in:
Jacques Pienaar 2018-11-12 06:33:22 -08:00 committed by jpienaar
parent b2f77e1b8f
commit 25e6b541cd
12 changed files with 85 additions and 40 deletions

View File

@ -144,14 +144,17 @@ public:
class IntegerAttr : public Attribute {
public:
using ImplType = detail::IntegerAttributeStorage;
using ValueType = int64_t;
using ValueType = APInt;
IntegerAttr() = default;
/* implicit */ IntegerAttr(Attribute::ImplType *ptr);
static IntegerAttr get(int64_t value, MLIRContext *context);
static IntegerAttr get(const APInt &value, MLIRContext *context);
int64_t getValue() const;
APInt getValue() const;
// TODO(jpienaar): Change callers to use getValue instead.
int64_t getInt() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Integer; }

View File

@ -153,7 +153,7 @@ public:
Type type);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value").getValue();
return getAttrOfType<IntegerAttr>("value").getInt();
}
static bool isClassFor(const Operation *op);
@ -174,7 +174,7 @@ public:
static void build(Builder *builder, OperationState *result, int64_t value);
int64_t getValue() const {
return getAttrOfType<IntegerAttr>("value").getValue();
return getAttrOfType<IntegerAttr>("value").getInt();
}
static bool isClassFor(const Operation *op);

View File

@ -90,10 +90,9 @@ struct constant_int_op_binder {
// The matcher that matches a given target constant scalar / vector splat /
// tensor splat integer value.
template <IntegerAttr::ValueType TargetValue>
struct constant_int_value_matcher {
template <int64_t TargetValue> struct constant_int_value_matcher {
bool match(Operation *op) {
IntegerAttr::ValueType value;
APInt value;
return constant_int_op_binder(&value).match(op) && TargetValue == value;
}

View File

@ -85,6 +85,7 @@ class F32Attr : Attr<F32> {
class I32Attr : Attr<I<32>> {
let storageType = [{ IntegerAttr }];
let returnType = [{ int }];
let convertFromStorage = [{ return {0}.getSExtValue(); }];
}
// Class representing a Trait (defined in a C++ file that needs to be included

View File

@ -226,7 +226,7 @@ class CmpIOp : public Op<CmpIOp, OpTrait::OperandsAreIntegerLike,
public:
CmpIPredicate getPredicate() const {
return (CmpIPredicate)getAttrOfType<IntegerAttr>(getPredicateAttrName())
.getValue();
.getInt();
}
static StringRef getOperationName() { return "cmpi"; }
@ -293,8 +293,7 @@ public:
/// This returns the dimension number that the 'dim' is inspecting.
unsigned getIndex() const {
return static_cast<unsigned>(
getAttrOfType<IntegerAttr>("index").getValue());
return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
}
static StringRef getOperationName() { return "dim"; }

View File

@ -68,6 +68,7 @@ public:
}
private:
// TODO: Change these to operate on APInts too.
IntegerAttr
constantFoldBinExpr(AffineExpr expr,
std::function<uint64_t(int64_t, uint64_t)> op) {
@ -76,8 +77,7 @@ private:
auto rhs = constantFold(binOpExpr.getRHS());
if (!lhs || !rhs)
return nullptr;
return IntegerAttr::get(op(lhs.getValue(), rhs.getValue()),
expr.getContext());
return IntegerAttr::get(op(lhs.getInt(), rhs.getInt()), expr.getContext());
}
// The number of dimension operands in AffineMap containing this expression.

View File

@ -46,8 +46,17 @@ struct BoolAttributeStorage : public AttributeStorage {
};
/// An attribute representing a integral value.
struct IntegerAttributeStorage : public AttributeStorage {
int64_t value;
struct IntegerAttributeStorage final
: public AttributeStorage,
public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> {
const unsigned numBits;
size_t numObjects;
/// Returns an APInt representing the stored value.
APInt getValue() const {
auto val = APInt(numBits, {getTrailingObjects<uint64_t>(), numObjects});
return val;
}
};
/// An attribute representing a floating point value.

View File

@ -37,10 +37,12 @@ bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
IntegerAttr::IntegerAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
int64_t IntegerAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
APInt IntegerAttr::getValue() const {
return static_cast<ImplType *>(attr)->getValue();
}
int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
FloatAttr::FloatAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
APFloat FloatAttr::getValue() const {

View File

@ -187,6 +187,21 @@ struct FloatAttrKeyInfo : DenseMapInfo<FloatAttributeStorage *> {
}
};
struct IntegerAttrKeyInfo : DenseMapInfo<IntegerAttributeStorage *> {
// Integer attributes are uniqued based on wrapped APInt.
using KeyTy = APInt;
using DenseMapInfo<IntegerAttributeStorage *>::getHashValue;
using DenseMapInfo<IntegerAttributeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
static bool isEqual(const KeyTy &lhs, const IntegerAttributeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == rhs->getValue();
}
};
struct ArrayAttrKeyInfo : DenseMapInfo<ArrayAttributeStorage *> {
// Array attributes are uniqued based on their elements.
using KeyTy = ArrayRef<Attribute>;
@ -377,7 +392,7 @@ public:
// Attribute uniquing.
BoolAttributeStorage *boolAttrs[2] = {nullptr};
DenseMap<int64_t, IntegerAttributeStorage *> integerAttrs;
DenseSet<IntegerAttributeStorage *, IntegerAttrKeyInfo> integerAttrs;
DenseSet<FloatAttributeStorage *, FloatAttrKeyInfo> floatAttrs;
StringMap<StringAttributeStorage *> stringAttrs;
using ArrayAttrSet = DenseSet<ArrayAttributeStorage *, ArrayAttrKeyInfo>;
@ -1040,17 +1055,37 @@ BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
return result;
}
IntegerAttr IntegerAttr::get(int64_t value, MLIRContext *context) {
auto *&result = context->getImpl().integerAttrs[value];
if (result)
return result;
IntegerAttr IntegerAttr::get(const APInt &value, MLIRContext *context) {
auto &impl = context->getImpl();
result = context->getImpl().allocator.Allocate<IntegerAttributeStorage>();
new (result) IntegerAttributeStorage{{Attribute::Kind::Integer,
/*isOrContainsFunction=*/false},
value};
result->value = value;
return result;
// Look to see if the integer attribute has been created already.
auto existing = impl.integerAttrs.insert_as(nullptr, value);
// If it has been created, return it.
if (!existing.second)
return *existing.first;
// If it doesn't, create one and return it.
auto elements = ArrayRef<uint64_t>(value.getRawData(), value.getNumWords());
auto byteSize =
IntegerAttributeStorage::totalSizeToAlloc<uint64_t>(elements.size());
auto rawMem =
impl.allocator.Allocate(byteSize, alignof(IntegerAttributeStorage));
// TODO: This uses 64 bit APInts by default without consideration of value.
auto result = ::new (rawMem) IntegerAttributeStorage{
{Attribute::Kind::Integer, /*isOrContainsFunction=*/false},
{},
/*numBits*/ 64,
elements.size()};
std::uninitialized_copy(elements.begin(), elements.end(),
result->getTrailingObjects<uint64_t>());
return *existing.first = result;
}
IntegerAttr IntegerAttr::get(int64_t value, MLIRContext *context) {
// TODO: This uses 64 bit APInts by default.
return get(APInt(64, value, /*isSigned=*/true), context);
}
FloatAttr FloatAttr::get(double value, MLIRContext *context) {

View File

@ -700,13 +700,10 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl<int> &dims) {
if (!result.isa<IntegerAttr>())
return p.emitError("expected tensor literal element has integer type");
auto value = result.cast<IntegerAttr>().getValue();
// If we couldn't successfully round trip the value, it means some bits
// are truncated and we should give up here.
llvm::APInt apint(bitsWidth, (uint64_t)value, /*isSigned=*/true);
if (apint.getSExtValue() != value)
if (value.getMinSignedBits() > bitsWidth)
return p.emitError("tensor literal element has more bits than that "
"specified in the type");
addToStorage((uint64_t)value);
addToStorage(value.getSExtValue());
break;
}
default:

View File

@ -538,8 +538,8 @@ bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
void CmpIOp::print(OpAsmPrinter *p) const {
*p << getOperationName() << " ";
int predicateValue =
getAttrOfType<IntegerAttr>(getPredicateAttrName()).getValue();
auto predicateValue =
getAttrOfType<IntegerAttr>(getPredicateAttrName()).getInt();
assert(predicateValue >= static_cast<int>(CmpIPredicate::FirstValidValue) &&
predicateValue < static_cast<int>(CmpIPredicate::NumPredicates) &&
"unknown predicate index");
@ -561,7 +561,7 @@ bool CmpIOp::verify() const {
auto predicateAttr = getAttrOfType<IntegerAttr>(getPredicateAttrName());
if (!predicateAttr)
return emitOpError("requires an integer attribute named 'predicate'");
auto predicate = predicateAttr.getValue();
auto predicate = predicateAttr.getInt();
if (predicate < (int64_t)CmpIPredicate::FirstValidValue ||
predicate >= (int64_t)CmpIPredicate::NumPredicates)
return emitOpError("'predicate' attribute value out of range");
@ -645,7 +645,7 @@ bool DimOp::verify() const {
auto indexAttr = getAttrOfType<IntegerAttr>("index");
if (!indexAttr)
return emitOpError("requires an integer attribute named 'index'");
uint64_t index = (uint64_t)indexAttr.getValue();
uint64_t index = indexAttr.getValue().getZExtValue();
auto type = getOperand()->getType();
if (auto tensorType = type.dyn_cast<RankedTensorType>()) {

View File

@ -380,11 +380,11 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) {
auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
for (unsigned i = 1; i < foldedResults.size(); i++) {
auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
maxOrMin = lower ? std::max(maxOrMin, foldedResult)
: std::min(maxOrMin, foldedResult);
maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
: llvm::APIntOps::smin(maxOrMin, foldedResult);
}
lower ? forStmt->setConstantLowerBound(maxOrMin)
: forStmt->setConstantUpperBound(maxOrMin);
lower ? forStmt->setConstantLowerBound(maxOrMin.getSExtValue())
: forStmt->setConstantUpperBound(maxOrMin.getSExtValue());
// Return false on success.
return false;