Ensure that every Attribute contains a Type. If an Attribute does not provide a type explicitly, the type is defaulted to NoneType.

--

PiperOrigin-RevId: 246021088
This commit is contained in:
River Riddle 2019-04-30 14:26:04 -07:00 committed by Mehdi Amini
parent 0bd0571e72
commit 17d3acf40c
4 changed files with 82 additions and 106 deletions

View File

@ -47,7 +47,6 @@ struct AffineMapAttributeStorage;
struct IntegerSetAttributeStorage;
struct TypeAttributeStorage;
struct FunctionAttributeStorage;
struct ElementsAttributeStorage;
struct SplatElementsAttributeStorage;
struct DenseElementsAttributeStorage;
struct DenseIntElementsAttributeStorage;
@ -125,6 +124,9 @@ public:
/// Return the classification for this attribute.
Kind getKind() const;
/// Return the type of this attribute.
Type getType() const;
/// Return true if this field is, or contains, a function attribute.
bool isOrContainsFunction() const;
@ -177,8 +179,6 @@ class NumericAttr : public Attribute {
public:
using Attribute::Attribute;
Type getType() const;
static bool kindof(Kind kind);
};
@ -192,8 +192,6 @@ public:
bool getValue() const;
Type getType() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Bool; }
};
@ -211,8 +209,6 @@ public:
// TODO(jpienaar): Change callers to use getValue instead.
int64_t getInt() const;
Type getType() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Integer; }
};
@ -238,8 +234,6 @@ public:
/// precision.
double getValueAsDouble() const;
Type getType() const;
/// Methods for support type inquiry through isa, cast, and dyn_cast.
static bool kindof(Kind kind) { return kind == Kind::Float; }
};
@ -353,7 +347,6 @@ public:
class ElementsAttr : public NumericAttr {
public:
using NumericAttr::NumericAttr;
using ImplType = detail::ElementsAttributeStorage;
VectorOrTensorType getType() const;

View File

@ -24,23 +24,43 @@
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/Function.h"
#include "mlir/IR/Identifier.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardTypes.h"
#include "mlir/Support/StorageUniquer.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/PointerIntPair.h"
#include "llvm/Support/TrailingObjects.h"
namespace mlir {
namespace detail {
/// Base storage class appearing in an attribute.
/// Base storage class appearing in an attribute. Derived storage classes should
/// only be constructed within the context of the AttributeUniquer.
struct AttributeStorage : public StorageUniquer::BaseStorage {
AttributeStorage(bool isOrContainsFunctionCache = false)
: isOrContainsFunctionCache(isOrContainsFunctionCache) {}
/// Construct a new attribute storage instance with the given type and a
/// boolean that signals if the derived attribute is or contains a function
/// pointer.
/// Note: All attributes require a valid type. If a null type is provided
/// here, the type of the attribute will automatically default to
/// NoneType upon initialization in the uniquer.
AttributeStorage(Type type = {}, bool isOrContainsFunctionCache = false)
: typeAndContainsFunctionAttrPair(type, isOrContainsFunctionCache) {}
AttributeStorage(bool isOrContainsFunctionCache)
: AttributeStorage(/*type=*/{}, isOrContainsFunctionCache) {}
/// This field is true if this is, or contains, a function attribute.
bool isOrContainsFunctionCache : 1;
bool isOrContainsFunctionCache() const {
return typeAndContainsFunctionAttrPair.getInt();
}
Type getType() const { return typeAndContainsFunctionAttrPair.getPointer(); }
void setType(Type type) { typeAndContainsFunctionAttrPair.setPointer(type); }
/// This field is a pair of:
/// - The type of the attribute value.
/// - A boolean that is true if this is, or contains, a function attribute.
llvm::PointerIntPair<Type, 1, bool> typeAndContainsFunctionAttrPair;
};
// A utility class to get, or create, unique instances of attributes within an
@ -54,7 +74,7 @@ public:
!std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
get(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
return ctx->getAttributeUniquer().getComplex<typename T::ImplType>(
/*initFn=*/{}, static_cast<unsigned>(kind),
getInitFn(ctx), static_cast<unsigned>(kind),
std::forward<Args>(args)...);
}
@ -66,7 +86,7 @@ public:
std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
get(MLIRContext *ctx, Attribute::Kind kind) {
return ctx->getAttributeUniquer().getSimple<AttributeStorage>(
/*initFn=*/{}, static_cast<unsigned>(kind));
getInitFn(ctx), static_cast<unsigned>(kind));
}
/// Erase a uniqued instance of attribute T. This overload is used for
@ -78,6 +98,15 @@ public:
return ctx->getAttributeUniquer().eraseComplex<typename T::ImplType>(
static_cast<unsigned>(kind), std::forward<Args>(args)...);
}
/// Generate a functor to initialize a new attribute storage instance.
static std::function<void(AttributeStorage *)> getInitFn(MLIRContext *ctx) {
return [ctx](AttributeStorage *storage) {
// If the attribute did not provide a type, then default to NoneType.
if (!storage->getType())
storage->setType(NoneType::get(ctx));
};
}
};
using AttributeStorageAllocator = StorageUniquer::StorageAllocator;
@ -86,7 +115,8 @@ using AttributeStorageAllocator = StorageUniquer::StorageAllocator;
struct BoolAttributeStorage : public AttributeStorage {
using KeyTy = std::pair<MLIRContext *, bool>;
BoolAttributeStorage(Type type, bool value) : type(type), value(value) {}
BoolAttributeStorage(Type type, bool value)
: AttributeStorage(type), value(value) {}
/// We only check equality for and hash with the boolean key parameter.
bool operator==(const KeyTy &key) const { return key.second == value; }
@ -100,7 +130,6 @@ struct BoolAttributeStorage : public AttributeStorage {
BoolAttributeStorage(IntegerType::get(1, key.first), key.second);
}
Type type;
bool value;
};
@ -111,13 +140,13 @@ struct IntegerAttributeStorage final
using KeyTy = std::pair<Type, APInt>;
IntegerAttributeStorage(Type type, size_t numObjects)
: type(type), numObjects(numObjects) {
: AttributeStorage(type), numObjects(numObjects) {
assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type");
}
/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
return key == KeyTy(type, getValue());
return key == KeyTy(getType(), getValue());
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(key.first, llvm::hash_value(key.second));
@ -142,13 +171,12 @@ struct IntegerAttributeStorage final
/// Returns an APInt representing the stored value.
APInt getValue() const {
if (type.isIndex())
if (getType().isIndex())
return APInt(64, {getTrailingObjects<uint64_t>(), numObjects});
return APInt(type.getIntOrFloatBitWidth(),
return APInt(getType().getIntOrFloatBitWidth(),
{getTrailingObjects<uint64_t>(), numObjects});
}
Type type;
size_t numObjects;
};
@ -160,12 +188,11 @@ struct FloatAttributeStorage final
FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type,
size_t numObjects)
: semantics(semantics), type(type.cast<FloatType>()),
numObjects(numObjects) {}
: AttributeStorage(type), semantics(semantics), numObjects(numObjects) {}
/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
return key.first == type && key.second.bitwiseIsEqual(getValue());
return key.first == getType() && key.second.bitwiseIsEqual(getValue());
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(key.first, llvm::hash_value(key.second));
@ -197,7 +224,6 @@ struct FloatAttributeStorage final
}
const llvm::fltSemantics &semantics;
FloatType type;
size_t numObjects;
};
@ -307,7 +333,8 @@ struct FunctionAttributeStorage : public AttributeStorage {
using KeyTy = Function *;
FunctionAttributeStorage(Function *value)
: AttributeStorage(/*isOrContainsFunctionCache=*/true), value(value) {}
: AttributeStorage(value->getType(), /*isOrContainsFunctionCache=*/true),
value(value) {}
/// Key equality function.
bool operator==(const KeyTy &key) const { return key == value; }
@ -329,26 +356,17 @@ struct FunctionAttributeStorage : public AttributeStorage {
Function *value;
};
/// A base attribute representing a reference to a vector or tensor constant.
struct ElementsAttributeStorage : public AttributeStorage {
ElementsAttributeStorage(VectorOrTensorType type) : type(type) {}
VectorOrTensorType type;
};
/// An attribute representing a reference to a vector or tensor constant,
/// inwhich all elements have the same value.
struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
using KeyTy = std::pair<VectorOrTensorType, Attribute>;
struct SplatElementsAttributeStorage : public AttributeStorage {
using KeyTy = std::pair<Type, Attribute>;
SplatElementsAttributeStorage(VectorOrTensorType type, Attribute elt)
: ElementsAttributeStorage(type), elt(elt) {}
SplatElementsAttributeStorage(Type type, Attribute elt)
: AttributeStorage(type), elt(elt) {}
/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
return key == std::make_pair(type, elt);
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(key.first, key.second);
return key == std::make_pair(getType(), elt);
}
/// Construct a new storage instance.
@ -362,16 +380,15 @@ struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
};
/// An attribute representing a reference to a dense vector or tensor object.
struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
struct DenseElementsAttributeStorage : public AttributeStorage {
using KeyTy = std::pair<Type, ArrayRef<char>>;
DenseElementsAttributeStorage(VectorOrTensorType ty, ArrayRef<char> data)
: ElementsAttributeStorage(ty), data(data) {}
DenseElementsAttributeStorage(Type ty, ArrayRef<char> data)
: AttributeStorage(ty), data(data) {}
/// Key equality and hash functions.
bool operator==(const KeyTy &key) const { return key == KeyTy(type, data); }
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(key.first, key.second);
bool operator==(const KeyTy &key) const {
return key == KeyTy(getType(), data);
}
/// Construct a new storage instance.
@ -398,16 +415,15 @@ struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
/// An attribute representing a reference to a tensor constant with opaque
/// content.
struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
using KeyTy = std::tuple<VectorOrTensorType, Dialect *, StringRef>;
struct OpaqueElementsAttributeStorage : public AttributeStorage {
using KeyTy = std::tuple<Type, Dialect *, StringRef>;
OpaqueElementsAttributeStorage(VectorOrTensorType type, Dialect *dialect,
StringRef bytes)
: ElementsAttributeStorage(type), dialect(dialect), bytes(bytes) {}
OpaqueElementsAttributeStorage(Type type, Dialect *dialect, StringRef bytes)
: AttributeStorage(type), dialect(dialect), bytes(bytes) {}
/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
return key == std::make_tuple(type, dialect, bytes);
return key == std::make_tuple(getType(), dialect, bytes);
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
@ -429,18 +445,16 @@ struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
};
/// An attribute representing a reference to a sparse vector or tensor object.
struct SparseElementsAttributeStorage : public ElementsAttributeStorage {
using KeyTy =
std::tuple<VectorOrTensorType, DenseIntElementsAttr, DenseElementsAttr>;
struct SparseElementsAttributeStorage : public AttributeStorage {
using KeyTy = std::tuple<Type, DenseIntElementsAttr, DenseElementsAttr>;
SparseElementsAttributeStorage(VectorOrTensorType type,
DenseIntElementsAttr indices,
SparseElementsAttributeStorage(Type type, DenseIntElementsAttr indices,
DenseElementsAttr values)
: ElementsAttributeStorage(type), indices(indices), values(values) {}
: AttributeStorage(type), indices(indices), values(values) {}
/// Key equality and hash functions.
bool operator==(const KeyTy &key) const {
return key == std::make_tuple(type, indices, values);
return key == std::make_tuple(getType(), indices, values);
}
static unsigned hashKey(const KeyTy &key) {
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),

View File

@ -30,8 +30,11 @@ Attribute::Kind Attribute::getKind() const {
return static_cast<Kind>(attr->getKind());
}
/// Return the type of this attribute.
Type Attribute::getType() const { return attr->getType(); }
bool Attribute::isOrContainsFunction() const {
return attr->isOrContainsFunctionCache;
return attr->isOrContainsFunctionCache();
}
// Given an attribute that could refer to a function attribute in the remapping
@ -79,19 +82,6 @@ UnitAttr UnitAttr::get(MLIRContext *context) {
// NumericAttr
//===----------------------------------------------------------------------===//
Type NumericAttr::getType() const {
if (auto boolAttr = dyn_cast<BoolAttr>())
return boolAttr.getType();
if (auto intAttr = dyn_cast<IntegerAttr>())
return intAttr.getType();
if (auto floatAttr = dyn_cast<FloatAttr>())
return floatAttr.getType();
if (auto elemAttr = dyn_cast<ElementsAttr>())
return elemAttr.getType();
llvm_unreachable("unhandled NumericAttr subclass");
}
bool NumericAttr::kindof(Kind kind) {
return BoolAttr::kindof(kind) || IntegerAttr::kindof(kind) ||
FloatAttr::kindof(kind) || ElementsAttr::kindof(kind);
@ -109,8 +99,6 @@ BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
Type BoolAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
//===----------------------------------------------------------------------===//
// IntegerAttr
//===----------------------------------------------------------------------===//
@ -135,10 +123,6 @@ APInt IntegerAttr::getValue() const {
int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
Type IntegerAttr::getType() const {
return static_cast<ImplType *>(attr)->type;
}
//===----------------------------------------------------------------------===//
// FloatAttr
//===----------------------------------------------------------------------===//
@ -185,8 +169,6 @@ APFloat FloatAttr::getValue() const {
return static_cast<ImplType *>(attr)->getValue();
}
Type FloatAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
double FloatAttr::getValueAsDouble() const {
const auto &semantics = getType().cast<FloatType>().getFloatSemantics();
auto value = getValue();
@ -281,14 +263,16 @@ Function *FunctionAttr::getValue() const {
return static_cast<ImplType *>(attr)->value;
}
FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
FunctionType FunctionAttr::getType() const {
return Attribute::getType().cast<FunctionType>();
}
//===----------------------------------------------------------------------===//
// ElementsAttr
//===----------------------------------------------------------------------===//
VectorOrTensorType ElementsAttr::getType() const {
return static_cast<ImplType *>(attr)->type;
return Attribute::getType().cast<VectorOrTensorType>();
}
/// Return the value at the given index. If index does not refer to a valid
@ -315,8 +299,8 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
Attribute elt) {
assert(elt.cast<NumericAttr>().getType() == type.getElementType() &&
"value should be of the given type");
assert(elt.getType() == type.getElementType() &&
"value should be of the given element type");
return AttributeUniquer::get<SplatElementsAttr>(
type.getContext(), Attribute::Kind::SplatElements, type, elt);
}

View File

@ -920,24 +920,11 @@ void ConstantOp::build(Builder *builder, OperationState *result, Type type,
result->types.push_back(type);
}
// Extracts and returns a type of an attribute if it has one. Returns a null
// type otherwise. Currently, NumericAttrs and FunctionAttrs have types.
static Type getAttributeType(Attribute attr) {
assert(attr && "expected non-null attribute");
if (auto numericAttr = attr.dyn_cast<NumericAttr>())
return numericAttr.getType();
if (auto functionAttr = attr.dyn_cast<FunctionAttr>())
return functionAttr.getType();
return {};
}
/// Builds a constant with the specified attribute value and type extracted
/// from the attribute. The attribute must have a type.
void ConstantOp::build(Builder *builder, OperationState *result,
Attribute value) {
Type t = getAttributeType(value);
assert(t && "expected an attribute with a type");
return build(builder, result, t, value);
return build(builder, result, value.getType(), value);
}
void ConstantOp::print(OpAsmPrinter *p) {
@ -1018,9 +1005,7 @@ LogicalResult ConstantOp::verify() {
return success();
}
auto attrType = getAttributeType(value);
if (!attrType)
return emitOpError("requires 'value' attribute to have a type");
auto attrType = value.getType();
if (attrType != type)
return emitOpError("requires the type of the 'value' attribute to match "
"that of the operation result");