diff --git a/mlir/include/mlir/IR/Attributes.h b/mlir/include/mlir/IR/Attributes.h index 7c3039742c83..9496a35eae2d 100644 --- a/mlir/include/mlir/IR/Attributes.h +++ b/mlir/include/mlir/IR/Attributes.h @@ -144,14 +144,17 @@ public: class IntegerAttr : public Attribute { public: using ImplType = detail::IntegerAttributeStorage; - using ValueType = int64_t; + using ValueType = APInt; IntegerAttr() = default; /* implicit */ IntegerAttr(Attribute::ImplType *ptr); static IntegerAttr get(int64_t value, MLIRContext *context); + static IntegerAttr get(const APInt &value, MLIRContext *context); - int64_t getValue() const; + APInt getValue() const; + // TODO(jpienaar): Change callers to use getValue instead. + int64_t getInt() const; /// Methods for support type inquiry through isa, cast, and dyn_cast. static bool kindof(Kind kind) { return kind == Kind::Integer; } diff --git a/mlir/include/mlir/IR/BuiltinOps.h b/mlir/include/mlir/IR/BuiltinOps.h index 5d810a91e4e7..43607f9db213 100644 --- a/mlir/include/mlir/IR/BuiltinOps.h +++ b/mlir/include/mlir/IR/BuiltinOps.h @@ -153,7 +153,7 @@ public: Type type); int64_t getValue() const { - return getAttrOfType("value").getValue(); + return getAttrOfType("value").getInt(); } static bool isClassFor(const Operation *op); @@ -174,7 +174,7 @@ public: static void build(Builder *builder, OperationState *result, int64_t value); int64_t getValue() const { - return getAttrOfType("value").getValue(); + return getAttrOfType("value").getInt(); } static bool isClassFor(const Operation *op); diff --git a/mlir/include/mlir/IR/Matchers.h b/mlir/include/mlir/IR/Matchers.h index ad97dd2865cb..1b6ccb25b643 100644 --- a/mlir/include/mlir/IR/Matchers.h +++ b/mlir/include/mlir/IR/Matchers.h @@ -90,10 +90,9 @@ struct constant_int_op_binder { // The matcher that matches a given target constant scalar / vector splat / // tensor splat integer value. -template -struct constant_int_value_matcher { +template struct constant_int_value_matcher { bool match(Operation *op) { - IntegerAttr::ValueType value; + APInt value; return constant_int_op_binder(&value).match(op) && TargetValue == value; } diff --git a/mlir/include/mlir/IR/op_base.td b/mlir/include/mlir/IR/op_base.td index 832d1f1863aa..0c6b45b6adf7 100644 --- a/mlir/include/mlir/IR/op_base.td +++ b/mlir/include/mlir/IR/op_base.td @@ -85,6 +85,7 @@ class F32Attr : Attr { class I32Attr : Attr> { let storageType = [{ IntegerAttr }]; let returnType = [{ int }]; + let convertFromStorage = [{ return {0}.getSExtValue(); }]; } // Class representing a Trait (defined in a C++ file that needs to be included diff --git a/mlir/include/mlir/StandardOps/StandardOps.h b/mlir/include/mlir/StandardOps/StandardOps.h index 173e7c88ca5c..44ab4d7f8f48 100644 --- a/mlir/include/mlir/StandardOps/StandardOps.h +++ b/mlir/include/mlir/StandardOps/StandardOps.h @@ -226,7 +226,7 @@ class CmpIOp : public Op(getPredicateAttrName()) - .getValue(); + .getInt(); } static StringRef getOperationName() { return "cmpi"; } @@ -293,8 +293,7 @@ public: /// This returns the dimension number that the 'dim' is inspecting. unsigned getIndex() const { - return static_cast( - getAttrOfType("index").getValue()); + return getAttrOfType("index").getValue().getZExtValue(); } static StringRef getOperationName() { return "dim"; } diff --git a/mlir/lib/IR/AffineMap.cpp b/mlir/lib/IR/AffineMap.cpp index 8acea85f01ad..27ee5ff80efa 100644 --- a/mlir/lib/IR/AffineMap.cpp +++ b/mlir/lib/IR/AffineMap.cpp @@ -68,6 +68,7 @@ public: } private: + // TODO: Change these to operate on APInts too. IntegerAttr constantFoldBinExpr(AffineExpr expr, std::function op) { @@ -76,8 +77,7 @@ private: auto rhs = constantFold(binOpExpr.getRHS()); if (!lhs || !rhs) return nullptr; - return IntegerAttr::get(op(lhs.getValue(), rhs.getValue()), - expr.getContext()); + return IntegerAttr::get(op(lhs.getInt(), rhs.getInt()), expr.getContext()); } // The number of dimension operands in AffineMap containing this expression. diff --git a/mlir/lib/IR/AttributeDetail.h b/mlir/lib/IR/AttributeDetail.h index 63ad544fa482..bf0baef503be 100644 --- a/mlir/lib/IR/AttributeDetail.h +++ b/mlir/lib/IR/AttributeDetail.h @@ -46,8 +46,17 @@ struct BoolAttributeStorage : public AttributeStorage { }; /// An attribute representing a integral value. -struct IntegerAttributeStorage : public AttributeStorage { - int64_t value; +struct IntegerAttributeStorage final + : public AttributeStorage, + public llvm::TrailingObjects { + const unsigned numBits; + size_t numObjects; + + /// Returns an APInt representing the stored value. + APInt getValue() const { + auto val = APInt(numBits, {getTrailingObjects(), numObjects}); + return val; + } }; /// An attribute representing a floating point value. diff --git a/mlir/lib/IR/Attributes.cpp b/mlir/lib/IR/Attributes.cpp index 58b5b90d43d9..ee8aa250edee 100644 --- a/mlir/lib/IR/Attributes.cpp +++ b/mlir/lib/IR/Attributes.cpp @@ -37,10 +37,12 @@ bool BoolAttr::getValue() const { return static_cast(attr)->value; } IntegerAttr::IntegerAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} -int64_t IntegerAttr::getValue() const { - return static_cast(attr)->value; +APInt IntegerAttr::getValue() const { + return static_cast(attr)->getValue(); } +int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); } + FloatAttr::FloatAttr(Attribute::ImplType *ptr) : Attribute(ptr) {} APFloat FloatAttr::getValue() const { diff --git a/mlir/lib/IR/MLIRContext.cpp b/mlir/lib/IR/MLIRContext.cpp index 0b461d198139..a431b0cae049 100644 --- a/mlir/lib/IR/MLIRContext.cpp +++ b/mlir/lib/IR/MLIRContext.cpp @@ -187,6 +187,21 @@ struct FloatAttrKeyInfo : DenseMapInfo { } }; +struct IntegerAttrKeyInfo : DenseMapInfo { + // Integer attributes are uniqued based on wrapped APInt. + using KeyTy = APInt; + using DenseMapInfo::getHashValue; + using DenseMapInfo::isEqual; + + static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); } + + static bool isEqual(const KeyTy &lhs, const IntegerAttributeStorage *rhs) { + if (rhs == getEmptyKey() || rhs == getTombstoneKey()) + return false; + return lhs == rhs->getValue(); + } +}; + struct ArrayAttrKeyInfo : DenseMapInfo { // Array attributes are uniqued based on their elements. using KeyTy = ArrayRef; @@ -377,7 +392,7 @@ public: // Attribute uniquing. BoolAttributeStorage *boolAttrs[2] = {nullptr}; - DenseMap integerAttrs; + DenseSet integerAttrs; DenseSet floatAttrs; StringMap stringAttrs; using ArrayAttrSet = DenseSet; @@ -1040,17 +1055,37 @@ BoolAttr BoolAttr::get(bool value, MLIRContext *context) { return result; } -IntegerAttr IntegerAttr::get(int64_t value, MLIRContext *context) { - auto *&result = context->getImpl().integerAttrs[value]; - if (result) - return result; +IntegerAttr IntegerAttr::get(const APInt &value, MLIRContext *context) { + auto &impl = context->getImpl(); - result = context->getImpl().allocator.Allocate(); - new (result) IntegerAttributeStorage{{Attribute::Kind::Integer, - /*isOrContainsFunction=*/false}, - value}; - result->value = value; - return result; + // Look to see if the integer attribute has been created already. + auto existing = impl.integerAttrs.insert_as(nullptr, value); + + // If it has been created, return it. + if (!existing.second) + return *existing.first; + + // If it doesn't, create one and return it. + auto elements = ArrayRef(value.getRawData(), value.getNumWords()); + + auto byteSize = + IntegerAttributeStorage::totalSizeToAlloc(elements.size()); + auto rawMem = + impl.allocator.Allocate(byteSize, alignof(IntegerAttributeStorage)); + // TODO: This uses 64 bit APInts by default without consideration of value. + auto result = ::new (rawMem) IntegerAttributeStorage{ + {Attribute::Kind::Integer, /*isOrContainsFunction=*/false}, + {}, + /*numBits*/ 64, + elements.size()}; + std::uninitialized_copy(elements.begin(), elements.end(), + result->getTrailingObjects()); + return *existing.first = result; +} + +IntegerAttr IntegerAttr::get(int64_t value, MLIRContext *context) { + // TODO: This uses 64 bit APInts by default. + return get(APInt(64, value, /*isSigned=*/true), context); } FloatAttr FloatAttr::get(double value, MLIRContext *context) { diff --git a/mlir/lib/Parser/Parser.cpp b/mlir/lib/Parser/Parser.cpp index 359ac5a55c22..d65b5ba97d35 100644 --- a/mlir/lib/Parser/Parser.cpp +++ b/mlir/lib/Parser/Parser.cpp @@ -700,13 +700,10 @@ TensorLiteralParser::parseElementOrList(llvm::SmallVectorImpl &dims) { if (!result.isa()) return p.emitError("expected tensor literal element has integer type"); auto value = result.cast().getValue(); - // If we couldn't successfully round trip the value, it means some bits - // are truncated and we should give up here. - llvm::APInt apint(bitsWidth, (uint64_t)value, /*isSigned=*/true); - if (apint.getSExtValue() != value) + if (value.getMinSignedBits() > bitsWidth) return p.emitError("tensor literal element has more bits than that " "specified in the type"); - addToStorage((uint64_t)value); + addToStorage(value.getSExtValue()); break; } default: diff --git a/mlir/lib/StandardOps/StandardOps.cpp b/mlir/lib/StandardOps/StandardOps.cpp index 10f62548abbf..2426dfa7d05a 100644 --- a/mlir/lib/StandardOps/StandardOps.cpp +++ b/mlir/lib/StandardOps/StandardOps.cpp @@ -538,8 +538,8 @@ bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) { void CmpIOp::print(OpAsmPrinter *p) const { *p << getOperationName() << " "; - int predicateValue = - getAttrOfType(getPredicateAttrName()).getValue(); + auto predicateValue = + getAttrOfType(getPredicateAttrName()).getInt(); assert(predicateValue >= static_cast(CmpIPredicate::FirstValidValue) && predicateValue < static_cast(CmpIPredicate::NumPredicates) && "unknown predicate index"); @@ -561,7 +561,7 @@ bool CmpIOp::verify() const { auto predicateAttr = getAttrOfType(getPredicateAttrName()); if (!predicateAttr) return emitOpError("requires an integer attribute named 'predicate'"); - auto predicate = predicateAttr.getValue(); + auto predicate = predicateAttr.getInt(); if (predicate < (int64_t)CmpIPredicate::FirstValidValue || predicate >= (int64_t)CmpIPredicate::NumPredicates) return emitOpError("'predicate' attribute value out of range"); @@ -645,7 +645,7 @@ bool DimOp::verify() const { auto indexAttr = getAttrOfType("index"); if (!indexAttr) return emitOpError("requires an integer attribute named 'index'"); - uint64_t index = (uint64_t)indexAttr.getValue(); + uint64_t index = indexAttr.getValue().getZExtValue(); auto type = getOperand()->getType(); if (auto tensorType = type.dyn_cast()) { diff --git a/mlir/lib/Transforms/Utils/Utils.cpp b/mlir/lib/Transforms/Utils/Utils.cpp index c1ba21f156bf..42b8a61f916b 100644 --- a/mlir/lib/Transforms/Utils/Utils.cpp +++ b/mlir/lib/Transforms/Utils/Utils.cpp @@ -380,11 +380,11 @@ bool mlir::constantFoldBounds(ForStmt *forStmt) { auto maxOrMin = foldedResults[0].cast().getValue(); for (unsigned i = 1; i < foldedResults.size(); i++) { auto foldedResult = foldedResults[i].cast().getValue(); - maxOrMin = lower ? std::max(maxOrMin, foldedResult) - : std::min(maxOrMin, foldedResult); + maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult) + : llvm::APIntOps::smin(maxOrMin, foldedResult); } - lower ? forStmt->setConstantLowerBound(maxOrMin) - : forStmt->setConstantUpperBound(maxOrMin); + lower ? forStmt->setConstantLowerBound(maxOrMin.getSExtValue()) + : forStmt->setConstantUpperBound(maxOrMin.getSExtValue()); // Return false on success. return false;