Add Type to int/float attributes.

* Optionally attach the type of integer and floating point attributes to the attributes, this allows restricting a int/float to specific width.
  - Currently this allows suffixing int/float constant with type [this might be revised in future].
  - Default to i64 and f32 if not specified.
* For index types the APInt width used is 64.
* Change callers to request a specific attribute type.
* Store iN type with APInt of width N.
* This change does not handle the folding of constants of different types (e.g., doing int type promotions to support constant folding i3 and i32), and instead restricts the constant folding to only operate on the same types.

PiperOrigin-RevId: 221722699
This commit is contained in:
Jacques Pienaar 2018-11-15 17:53:51 -08:00 committed by jpienaar
parent c7df0651d3
commit 711047c0cd
13 changed files with 230 additions and 90 deletions

View File

@ -158,13 +158,15 @@ public:
IntegerAttr() = default;
/* implicit */ IntegerAttr(Attribute::ImplType *ptr);
static IntegerAttr get(int64_t value, MLIRContext *context);
static IntegerAttr get(const APInt &value, MLIRContext *context);
static IntegerAttr get(Type type, int64_t value);
static IntegerAttr get(Type type, const APInt &value);
APInt getValue() const;
// TODO(jpienaar): Change callers to use getValue instead.
int64_t getInt() const;
Type getType() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Integer; }
};
@ -177,13 +179,15 @@ public:
FloatAttr() = default;
/* implicit */ FloatAttr(Attribute::ImplType *ptr);
static FloatAttr get(double value, MLIRContext *context);
static FloatAttr get(const APFloat &value, MLIRContext *context);
static FloatAttr get(Type type, double value);
static FloatAttr get(Type type, const APFloat &value);
APFloat getValue() const;
double getDouble() const;
Type getType() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Float; }
};

View File

@ -97,9 +97,14 @@ public:
// Attributes.
BoolAttr getBoolAttr(bool value);
IntegerAttr getIntegerAttr(Type type, int64_t value);
IntegerAttr getIntegerAttr(Type type, const APInt &value);
FloatAttr getFloatAttr(Type type, double value);
FloatAttr getFloatAttr(Type type, const APFloat &value);
// Convenience methods that assumes fixed type.
// TODO(jpienaar): remove these.
IntegerAttr getIntegerAttr(int64_t value);
FloatAttr getFloatAttr(double value);
FloatAttr getFloatAttr(const APFloat &value);
StringAttr getStringAttr(StringRef bytes);
ArrayAttr getArrayAttr(ArrayRef<Attribute> value);
AffineMapAttr getAffineMapAttr(AffineMap map);

View File

@ -167,6 +167,9 @@ public:
/// This parses... a comma!
virtual bool parseComma() = 0;
/// Parse a type.
virtual bool parseType(Type &result) = 0;
/// Parse a colon followed by a type.
virtual bool parseColonType(Type &result) = 0;
@ -213,22 +216,27 @@ public:
}
/// Parse an arbitrary attribute and return it in result. This also adds the
/// attribute to the specified attribute list with the specified name. this
/// captures the location of the attribute in 'loc' if it is non-null.
/// attribute to the specified attribute list with the specified name.
virtual bool parseAttribute(Attribute &result, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0;
/// Parse an attribute of a specific kind, capturing the location into `loc`
/// if specified.
/// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified
/// name.
virtual bool parseAttribute(Attribute &result, Type type,
const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs) = 0;
/// Parse an attribute of a specific kind and type.
template <typename AttrType>
bool parseAttribute(AttrType &result, const char *attrName,
bool parseAttribute(AttrType &result, Type type, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs) {
llvm::SMLoc loc;
getCurrentLocation(&loc);
// Parse any kind of attribute.
Attribute attr;
if (parseAttribute(attr, attrName, attrs))
if (parseAttribute(attr, type, attrName, attrs))
return true;
// Check for the right kind of attribute.

View File

@ -19,6 +19,7 @@
#include "AffineMapDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/MathExtras.h"
#include "llvm/ADT/StringRef.h"
@ -55,8 +56,8 @@ public:
return constantFoldBinExpr(
expr, [](int64_t lhs, uint64_t rhs) { return ceilDiv(lhs, rhs); });
case AffineExprKind::Constant:
return IntegerAttr::get(expr.cast<AffineConstantExpr>().getValue(),
expr.getContext());
return IntegerAttr::get(Type::getIndex(expr.getContext()),
expr.cast<AffineConstantExpr>().getValue());
case AffineExprKind::DimId:
return operandConsts[expr.cast<AffineDimExpr>().getPosition()]
.dyn_cast_or_null<IntegerAttr>();
@ -77,7 +78,7 @@ private:
auto rhs = constantFold(binOpExpr.getRHS());
if (!lhs || !rhs)
return nullptr;
return IntegerAttr::get(op(lhs.getInt(), rhs.getInt()), expr.getContext());
return IntegerAttr::get(lhs.getType(), op(lhs.getInt(), rhs.getInt()));
}
// The number of dimension operands in AffineMap containing this expression.

View File

@ -412,9 +412,14 @@ void ModulePrinter::printAttribute(Attribute attr) {
case Attribute::Kind::Bool:
os << (attr.cast<BoolAttr>().getValue() ? "true" : "false");
break;
case Attribute::Kind::Integer:
os << attr.cast<IntegerAttr>().getValue();
case Attribute::Kind::Integer: {
auto intAttr = attr.cast<IntegerAttr>();
// Print all integer attributes as signed unless i1.
bool isSigned =
intAttr.getType().isIndex() || intAttr.getType().getBitWidth() != 1;
intAttr.getValue().print(os, isSigned);
break;
}
case Attribute::Kind::Float:
printFloatValue(attr.cast<FloatAttr>().getValue(), os);
break;

View File

@ -27,6 +27,7 @@
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Types.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
@ -49,13 +50,20 @@ struct BoolAttributeStorage : public AttributeStorage {
struct IntegerAttributeStorage final
: public AttributeStorage,
public llvm::TrailingObjects<IntegerAttributeStorage, uint64_t> {
const unsigned numBits;
IntegerAttributeStorage(AttributeStorage &&as, Type type, size_t numObjects)
: AttributeStorage(as), type(type), numObjects(numObjects) {
assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type");
}
const Type type;
size_t numObjects;
/// Returns an APInt representing the stored value.
APInt getValue() const {
auto val = APInt(numBits, {getTrailingObjects<uint64_t>(), numObjects});
return val;
if (type.isIndex())
return APInt(64, {getTrailingObjects<uint64_t>(), numObjects});
return APInt(type.getBitWidth(),
{getTrailingObjects<uint64_t>(), numObjects});
}
};
@ -63,7 +71,13 @@ struct IntegerAttributeStorage final
struct FloatAttributeStorage final
: public AttributeStorage,
public llvm::TrailingObjects<FloatAttributeStorage, uint64_t> {
FloatAttributeStorage(AttributeStorage &&as,
const llvm::fltSemantics &semantics, Type type,
size_t numObjects)
: AttributeStorage(as), semantics(semantics),
type(type.cast<FloatType>()), numObjects(numObjects) {}
const llvm::fltSemantics &semantics;
const FloatType type;
size_t numObjects;
/// Returns an APFloat representing the stored value.

View File

@ -76,12 +76,18 @@ APInt IntegerAttr::getValue() const {
int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
Type IntegerAttr::getType() const {
return static_cast<ImplType *>(attr)->type;
}
FloatAttr::FloatAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
APFloat FloatAttr::getValue() const {
return static_cast<ImplType *>(attr)->getValue();
}
Type FloatAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
double FloatAttr::getDouble() const { return getValue().convertToDouble(); }
StringAttr::StringAttr(Attribute::ImplType *ptr) : Attribute(ptr) {}
@ -200,14 +206,13 @@ uint64_t DenseIntElementsAttr::readBits(const char *rawData, size_t bitPos,
void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto bitsWidth = static_cast<ImplType *>(attr)->bitsWidth;
auto elementNum = getType().getNumElements();
auto context = getType().getContext();
values.reserve(elementNum);
if (bitsWidth == 64) {
ArrayRef<int64_t> vs(
{reinterpret_cast<const int64_t *>(getRawData().data()),
getRawData().size() / 8});
for (auto value : vs) {
auto attr = IntegerAttr::get(value, context);
auto attr = IntegerAttr::get(getType().getElementType(), value);
values.push_back(attr);
}
} else {
@ -215,7 +220,8 @@ void DenseIntElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
for (size_t pos = 0; pos < elementNum * bitsWidth; pos += bitsWidth) {
uint64_t bits = readBits(rawData, pos, bitsWidth);
APInt value(bitsWidth, bits, /*isSigned=*/true);
auto attr = IntegerAttr::get(value.getSExtValue(), context);
auto attr =
IntegerAttr::get(getType().getElementType(), value.getSExtValue());
values.push_back(attr);
}
}
@ -226,12 +232,11 @@ DenseFPElementsAttr::DenseFPElementsAttr(Attribute::ImplType *ptr)
void DenseFPElementsAttr::getValues(SmallVectorImpl<Attribute> &values) const {
auto elementNum = getType().getNumElements();
auto context = getType().getContext();
ArrayRef<double> vs({reinterpret_cast<const double *>(getRawData().data()),
getRawData().size() / 8});
values.reserve(elementNum);
for (auto v : vs) {
auto attr = FloatAttr::get(v, context);
auto attr = FloatAttr::get(getType().getElementType(), v);
values.push_back(attr);
}
}

View File

@ -120,15 +120,29 @@ BoolAttr Builder::getBoolAttr(bool value) {
}
IntegerAttr Builder::getIntegerAttr(int64_t value) {
return IntegerAttr::get(value, context);
return IntegerAttr::get(getIntegerType(64), APInt(64, value));
}
IntegerAttr Builder::getIntegerAttr(Type type, int64_t value) {
if (type.isIndex())
return IntegerAttr::get(type, APInt(64, value));
return IntegerAttr::get(type, APInt(type.getBitWidth(), value));
}
IntegerAttr Builder::getIntegerAttr(Type type, const APInt &value) {
return IntegerAttr::get(type, value);
}
FloatAttr Builder::getFloatAttr(double value) {
return FloatAttr::get(APFloat(value), context);
return FloatAttr::get(getF32Type(), APFloat(value));
}
FloatAttr Builder::getFloatAttr(const APFloat &value) {
return FloatAttr::get(value, context);
FloatAttr Builder::getFloatAttr(Type type, double value) {
return FloatAttr::get(type, APFloat(value));
}
FloatAttr Builder::getFloatAttr(Type type, const APFloat &value) {
return FloatAttr::get(type, value);
}
StringAttr Builder::getStringAttr(StringRef bytes) {

View File

@ -334,8 +334,14 @@ bool ConstantOp::parse(OpAsmParser *parser, OperationState *result) {
if (auto fnAttr = valueAttr.dyn_cast<FunctionAttr>())
return parser->addTypeToList(fnAttr.getValue()->getType(), result->types);
return parser->parseColonType(type) ||
parser->addTypeToList(type, result->types);
if (auto intAttr = valueAttr.dyn_cast<IntegerAttr>()) {
type = intAttr.getType();
} else if (auto fpAttr = valueAttr.dyn_cast<FloatAttr>()) {
type = fpAttr.getType();
} else if (parser->parseColonType(type)) {
return true;
}
return parser->addTypeToList(type, result->types);
}
/// The constant op requires an attribute, and furthermore requires that it
@ -389,7 +395,7 @@ Attribute ConstantOp::constantFold(ArrayRef<Attribute> operands,
void ConstantFloatOp::build(Builder *builder, OperationState *result,
const APFloat &value, FloatType type) {
ConstantOp::build(builder, result, builder->getFloatAttr(value), type);
ConstantOp::build(builder, result, builder->getFloatAttr(type, value), type);
}
bool ConstantFloatOp::isClassFor(const Operation *op) {
@ -405,8 +411,9 @@ bool ConstantIntOp::isClassFor(const Operation *op) {
void ConstantIntOp::build(Builder *builder, OperationState *result,
int64_t value, unsigned width) {
ConstantOp::build(builder, result, builder->getIntegerAttr(value),
builder->getIntegerType(width));
Type type = builder->getIntegerType(width);
ConstantOp::build(builder, result, builder->getIntegerAttr(type, value),
type);
}
/// Build a constant int op producing an integer with the specified type,
@ -414,7 +421,8 @@ void ConstantIntOp::build(Builder *builder, OperationState *result,
void ConstantIntOp::build(Builder *builder, OperationState *result,
int64_t value, Type type) {
assert(type.isa<IntegerType>() && "ConstantIntOp can only have integer type");
ConstantOp::build(builder, result, builder->getIntegerAttr(value), type);
ConstantOp::build(builder, result, builder->getIntegerAttr(type, value),
type);
}
/// ConstantIndexOp only matches values whose result type is Index.
@ -424,8 +432,9 @@ bool ConstantIndexOp::isClassFor(const Operation *op) {
void ConstantIndexOp::build(Builder *builder, OperationState *result,
int64_t value) {
ConstantOp::build(builder, result, builder->getIntegerAttr(value),
builder->getIndexType());
Type type = builder->getIndexType();
ConstantOp::build(builder, result, builder->getIntegerAttr(type, value),
type);
}
//===----------------------------------------------------------------------===//

View File

@ -174,31 +174,37 @@ struct MemRefTypeKeyInfo : DenseMapInfo<MemRefTypeStorage *> {
struct FloatAttrKeyInfo : DenseMapInfo<FloatAttributeStorage *> {
// Float attributes are uniqued based on wrapped APFloat.
using KeyTy = APFloat;
using KeyTy = std::pair<Type, APFloat>;
using DenseMapInfo<FloatAttributeStorage *>::getHashValue;
using DenseMapInfo<FloatAttributeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
static unsigned getHashValue(KeyTy key) {
return hash_combine(key.first, llvm::hash_value(key.second));
}
static bool isEqual(const KeyTy &lhs, const FloatAttributeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs.bitwiseIsEqual(rhs->getValue());
return lhs.first == rhs->type && lhs.second.bitwiseIsEqual(rhs->getValue());
}
};
struct IntegerAttrKeyInfo : DenseMapInfo<IntegerAttributeStorage *> {
// Integer attributes are uniqued based on wrapped APInt.
using KeyTy = APInt;
using KeyTy = std::pair<Type, APInt>;
using DenseMapInfo<IntegerAttributeStorage *>::getHashValue;
using DenseMapInfo<IntegerAttributeStorage *>::isEqual;
static unsigned getHashValue(KeyTy key) { return llvm::hash_value(key); }
static unsigned getHashValue(KeyTy key) {
return hash_combine(key.first, llvm::hash_value(key.second));
}
static bool isEqual(const KeyTy &lhs, const IntegerAttributeStorage *rhs) {
if (rhs == getEmptyKey() || rhs == getTombstoneKey())
return false;
return lhs == rhs->getValue();
assert(lhs.first.isIndex() ||
(lhs.first.getBitWidth() == lhs.second.getBitWidth()));
return lhs.first == rhs->type && lhs.second == rhs->getValue();
}
};
@ -1074,11 +1080,12 @@ BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
return result;
}
IntegerAttr IntegerAttr::get(const APInt &value, MLIRContext *context) {
auto &impl = context->getImpl();
IntegerAttr IntegerAttr::get(Type type, const APInt &value) {
auto &impl = type.getContext()->getImpl();
// Look to see if the integer attribute has been created already.
auto existing = impl.integerAttrs.insert_as(nullptr, value);
IntegerAttrKeyInfo::KeyTy key({type, value});
auto existing = impl.integerAttrs.insert_as(nullptr, key);
// If it has been created, return it.
if (!existing.second)
@ -1094,28 +1101,29 @@ IntegerAttr IntegerAttr::get(const APInt &value, MLIRContext *context) {
// TODO: This uses 64 bit APInts by default without consideration of value.
auto result = ::new (rawMem) IntegerAttributeStorage{
{Attribute::Kind::Integer, /*isOrContainsFunction=*/false},
{},
/*numBits*/ 64,
type,
elements.size()};
std::uninitialized_copy(elements.begin(), elements.end(),
result->getTrailingObjects<uint64_t>());
return *existing.first = result;
}
IntegerAttr IntegerAttr::get(int64_t value, MLIRContext *context) {
// TODO: This uses 64 bit APInts by default.
return get(APInt(64, value, /*isSigned=*/true), context);
IntegerAttr IntegerAttr::get(Type type, int64_t value) {
// This uses 64 bit APInts by default for index type.
auto width = type.isIndex() ? 64 : type.getBitWidth();
return get(type, APInt(width, value));
}
FloatAttr FloatAttr::get(double value, MLIRContext *context) {
return get(APFloat(value), context);
FloatAttr FloatAttr::get(Type type, double value) {
return get(type, APFloat(value));
}
FloatAttr FloatAttr::get(const APFloat &value, MLIRContext *context) {
auto &impl = context->getImpl();
FloatAttr FloatAttr::get(Type type, const APFloat &value) {
auto &impl = type.getContext()->getImpl();
// Look to see if the float attribute has been created already.
auto existing = impl.floatAttrs.insert_as(nullptr, value);
FloatAttrKeyInfo::KeyTy key({type, value});
auto existing = impl.floatAttrs.insert_as(nullptr, key);
// If it has been created, return it.
if (!existing.second)
@ -1132,8 +1140,8 @@ FloatAttr FloatAttr::get(const APFloat &value, MLIRContext *context) {
impl.allocator.Allocate(byteSize, alignof(FloatAttributeStorage));
auto result = ::new (rawMem) FloatAttributeStorage{
{Attribute::Kind::Float, /*isOrContainsFunction=*/false},
{},
value.getSemantics(),
type,
elements.size()};
std::uninitialized_copy(elements.begin(), elements.end(),
result->getTrailingObjects<uint64_t>());

View File

@ -196,7 +196,7 @@ public:
// Attribute parsing.
Function *resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
FunctionType type);
Attribute parseAttribute();
Attribute parseAttribute(Type type = {});
ParseResult parseAttributeDict(SmallVectorImpl<NamedAttribute> &attributes);
@ -785,8 +785,8 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
/// Attribute parsing.
///
/// attribute-value ::= bool-literal
/// | integer-literal
/// | float-literal
/// | integer-literal (`:` integer-type)
/// | float-literal (`:` float-type)
/// | string-literal
/// | type
/// | `[` (attribute-value (`,` attribute-value)*)? `]`
@ -796,7 +796,7 @@ Function *Parser::resolveFunctionReference(StringRef nameStr, SMLoc nameLoc,
/// | `sparse<` (tensor-type | vector-type)`,`
/// attribute-value`, ` attribute-value `>`
///
Attribute Parser::parseAttribute() {
Attribute Parser::parseAttribute(Type type) {
switch (getToken().getKind()) {
case Token::kw_true:
consumeToken(Token::kw_true);
@ -811,14 +811,38 @@ Attribute Parser::parseAttribute() {
return (emitError("floating point value too large for attribute"),
nullptr);
consumeToken(Token::floatliteral);
return builder.getFloatAttr(APFloat(val.getValue()));
if (!type) {
if (consumeIf(Token::colon)) {
if (!(type = parseType()))
return nullptr;
} else {
// Default to F32 when no type is specified.
type = builder.getF32Type();
}
}
if (!type.isa<FloatType>())
return (emitError("floating point value not valid for specified type"),
nullptr);
return builder.getFloatAttr(type, APFloat(val.getValue()));
}
case Token::integer: {
auto val = getToken().getUInt64IntegerValue();
if (!val.hasValue() || (int64_t)val.getValue() < 0)
return (emitError("integer too large for attribute"), nullptr);
consumeToken(Token::integer);
return builder.getIntegerAttr((int64_t)val.getValue());
if (!type) {
if (consumeIf(Token::colon)) {
if (!(type = parseType()))
return nullptr;
} else {
// Default to i64 if not type is specified.
type = builder.getIntegerType(64);
}
}
if (!type.isa<IntegerType>() && !type.isa<IndexType>())
return (emitError("integer value not valid for specified type"), nullptr);
int width = type.isIndex() ? 64 : type.getBitWidth();
return builder.getIntegerAttr(type, APInt(width, val.getValue()));
}
case Token::minus: {
@ -828,7 +852,19 @@ Attribute Parser::parseAttribute() {
if (!val.hasValue() || (int64_t)-val.getValue() >= 0)
return (emitError("integer too large for attribute"), nullptr);
consumeToken(Token::integer);
return builder.getIntegerAttr((int64_t)-val.getValue());
if (!type) {
if (consumeIf(Token::colon)) {
if (!(type = parseType()))
return nullptr;
} else {
// Default to i64 if not type is specified.
type = builder.getIntegerType(64);
}
}
if (!type.isa<IntegerType>() && !type.isa<IndexType>())
return (emitError("integer value not valid for type"), nullptr);
int width = type.isIndex() ? 64 : type.getBitWidth();
return builder.getIntegerAttr(type, -APInt(width, val.getValue()));
}
if (getToken().is(Token::floatliteral)) {
auto val = getToken().getFloatingPointValue();
@ -836,7 +872,18 @@ Attribute Parser::parseAttribute() {
return (emitError("floating point value too large for attribute"),
nullptr);
consumeToken(Token::floatliteral);
return builder.getFloatAttr(APFloat(-val.getValue()));
if (!type) {
if (consumeIf(Token::colon)) {
if (!(type = parseType()))
return nullptr;
} else {
// Default to F32 when no type is specified.
type = builder.getF32Type();
}
}
if (!type.isa<FloatType>())
return (emitError("floating point value not valid for type"), nullptr);
return builder.getFloatAttr(type, APFloat(-val.getValue()));
}
return (emitError("expected constant integer or floating point value"),
@ -926,7 +973,7 @@ Attribute Parser::parseAttribute() {
case Token::floatliteral:
case Token::integer:
case Token::minus: {
auto scalar = parseAttribute();
auto scalar = parseAttribute(type.getElementType());
if (parseToken(Token::greater, "expected '>'"))
return nullptr;
return builder.getSplatElementsAttr(type, scalar);
@ -2234,6 +2281,10 @@ public:
return parser.parseToken(Token::comma, "expected ','");
}
bool parseType(Type &result) override {
return !(result = parser.parseType());
}
bool parseColonType(Type &result) override {
return parser.parseToken(Token::colon, "expected ':'") ||
!(result = parser.parseType());
@ -2261,12 +2312,12 @@ public:
return !(result = parser.parseType());
}
/// Parse an arbitrary attribute and return it in result. This also adds
/// the attribute to the specified attribute list with the specified name.
/// this captures the location of the attribute in 'loc' if it is non-null.
bool parseAttribute(Attribute &result, const char *attrName,
/// Parse an arbitrary attribute of a given type and return it in result. This
/// also adds the attribute to the specified attribute list with the specified
/// name.
bool parseAttribute(Attribute &result, Type type, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs) override {
result = parser.parseAttribute();
result = parser.parseAttribute(type);
if (!result)
return true;
@ -2275,6 +2326,13 @@ public:
return false;
}
/// Parse an arbitrary attribute and return it in result. This also adds
/// the attribute to the specified attribute list with the specified name.
bool parseAttribute(Attribute &result, const char *attrName,
SmallVectorImpl<NamedAttribute> &attrs) override {
return parseAttribute(result, Type(), attrName, attrs);
}
/// If a named attribute list is present, parse is into result.
bool
parseOptionalAttributeDict(SmallVectorImpl<NamedAttribute> &result) override {

View File

@ -85,7 +85,8 @@ Attribute AddFOp::constantFold(ArrayRef<Attribute> operands,
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
return FloatAttr::get(lhs.getValue() + rhs.getValue(), context);
if (lhs.getType() == rhs.getType())
return FloatAttr::get(lhs.getType(), lhs.getValue() + rhs.getValue());
}
return nullptr;
@ -101,7 +102,8 @@ Attribute AddIOp::constantFold(ArrayRef<Attribute> operands,
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(lhs.getValue() + rhs.getValue(), context);
if (lhs.getType() == rhs.getType())
return IntegerAttr::get(lhs.getType(), lhs.getValue() + rhs.getValue());
}
return nullptr;
@ -504,7 +506,8 @@ void CmpIOp::build(Builder *build, OperationState *result,
result->addOperands({lhs, rhs});
result->types.push_back(getI1SameShape(build, lhs->getType()));
result->addAttribute(getPredicateAttrName(),
build->getIntegerAttr(static_cast<int64_t>(predicate)));
build->getIntegerAttr(build->getIntegerType(64),
static_cast<int64_t>(predicate)));
}
bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
@ -526,8 +529,8 @@ bool CmpIOp::parse(OpAsmParser *parser, OperationState *result) {
return parser->emitError(parser->getNameLoc(),
"unknown comparison predicate \"" +
Twine(predicateName.getValue()) + "\"");
attrs[0].second =
parser->getBuilder().getIntegerAttr(static_cast<int64_t>(predicate));
auto builder = parser->getBuilder();
attrs[0].second = builder.getIntegerAttr(static_cast<int64_t>(predicate));
result->attributes = attrs;
// The result of comparison is formed from i1s in the same shape as type.
@ -616,8 +619,9 @@ void DeallocOp::getCanonicalizationPatterns(OwningPatternList &results,
void DimOp::build(Builder *builder, OperationState *result,
SSAValue *memrefOrTensor, unsigned index) {
result->addOperands(memrefOrTensor);
result->addAttribute("index", builder->getIntegerAttr(index));
result->types.push_back(builder->getIndexType());
auto type = builder->getIndexType();
result->addAttribute("index", builder->getIntegerAttr(type, index));
result->types.push_back(type);
}
void DimOp::print(OpAsmPrinter *p) const {
@ -630,14 +634,15 @@ bool DimOp::parse(OpAsmParser *parser, OperationState *result) {
OpAsmParser::OperandType operandInfo;
IntegerAttr indexAttr;
Type type;
Type indexType = parser->getBuilder().getIndexType();
return parser->parseOperand(operandInfo) || parser->parseComma() ||
parser->parseAttribute(indexAttr, "index", result->attributes) ||
parser->parseAttribute(indexAttr, indexType, "index",
result->attributes) ||
parser->parseOptionalAttributeDict(result->attributes) ||
parser->parseColonType(type) ||
parser->resolveOperand(operandInfo, type, result->operands) ||
parser->addTypeToList(parser->getBuilder().getIndexType(),
result->types);
parser->addTypeToList(indexType, result->types);
}
bool DimOp::verify() const {
@ -676,7 +681,7 @@ Attribute DimOp::constantFold(ArrayRef<Attribute> operands,
}
if (indexSize >= 0)
return IntegerAttr::get(indexSize, context);
return IntegerAttr::get(Type::getIndex(context), indexSize);
return nullptr;
}
@ -1019,7 +1024,8 @@ Attribute MulFOp::constantFold(ArrayRef<Attribute> operands,
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
return FloatAttr::get(lhs.getValue() * rhs.getValue(), context);
if (lhs.getType() == rhs.getType())
return FloatAttr::get(lhs.getType(), lhs.getValue() * rhs.getValue());
}
return nullptr;
@ -1040,7 +1046,8 @@ Attribute MulIOp::constantFold(ArrayRef<Attribute> operands,
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
// TODO: Handle the overflow case.
return IntegerAttr::get(lhs.getValue() * rhs.getValue(), context);
if (lhs.getType() == rhs.getType())
return IntegerAttr::get(lhs.getType(), lhs.getValue() * rhs.getValue());
}
// x*0 == 0
@ -1161,7 +1168,8 @@ Attribute SubFOp::constantFold(ArrayRef<Attribute> operands,
if (auto lhs = operands[0].dyn_cast_or_null<FloatAttr>()) {
if (auto rhs = operands[1].dyn_cast_or_null<FloatAttr>())
return FloatAttr::get(lhs.getValue() - rhs.getValue(), context);
if (lhs.getType() == rhs.getType())
return FloatAttr::get(lhs.getType(), lhs.getValue() - rhs.getValue());
}
return nullptr;
@ -1177,7 +1185,8 @@ Attribute SubIOp::constantFold(ArrayRef<Attribute> operands,
if (auto lhs = operands[0].dyn_cast_or_null<IntegerAttr>()) {
if (auto rhs = operands[1].dyn_cast_or_null<IntegerAttr>())
return IntegerAttr::get(lhs.getValue() - rhs.getValue(), context);
if (lhs.getType() == rhs.getType())
return IntegerAttr::get(lhs.getType(), lhs.getValue() - rhs.getValue());
}
return nullptr;

View File

@ -71,13 +71,13 @@ bb42(%t: tensor<4x4x?xf32>, %f: f32, %i: i32, %idx : index):
%i6 = muli %i2, %i2 : i32
// CHECK: %c42_i32 = constant 42 : i32
%x = "constant"(){value: 42} : () -> i32
%x = "constant"(){value: 42: i32} : () -> i32
// CHECK: %c42_i32_0 = constant 42 : i32
%7 = constant 42 : i32
// CHECK: %c43 = constant 43 {crazy: "foo"} : index
%8 = constant 43 {crazy: "foo"} : index
%8 = constant 43: index {crazy: "foo"}
// CHECK: %cst = constant 4.300000e+01 : bf16
%9 = constant 43.0 : bf16
@ -128,8 +128,8 @@ bb42(%t: tensor<4x4x?xf32>, %f: f32, %i: i32, %idx : index):
// CHECK-LABEL: cfgfunc @affine_apply() {
cfgfunc @affine_apply() {
bb0:
%i = "constant"() {value: 0} : () -> index
%j = "constant"() {value: 1} : () -> index
%i = "constant"() {value: 0: index} : () -> index
%j = "constant"() {value: 1: index} : () -> index
// CHECK: affine_apply #map0(%c0)
%a = "affine_apply" (%i) { map: (d0) -> (d0 + 1) } :
@ -195,7 +195,7 @@ mlfunc @calls(%arg0 : i32) {
// CHECK-LABEL: mlfunc @extract_element(%arg0 : tensor<*xi32>, %arg1 : tensor<4x4xf32>) -> i32 {
mlfunc @extract_element(%arg0 : tensor<*xi32>, %arg1 : tensor<4x4xf32>) -> i32 {
%c0 = "constant"() {value: 0} : () -> index
%c0 = "constant"() {value: 0: index} : () -> index
// CHECK: %0 = extract_element %arg0[%c0, %c0, %c0, %c0] : tensor<*xi32>
%0 = extract_element %arg0[%c0, %c0, %c0, %c0] : tensor<*xi32>