forked from OSchip/llvm-project
Ensure that every Attribute contains a Type. If an Attribute does not provide a type explicitly, the type is defaulted to NoneType.
-- PiperOrigin-RevId: 246021088
This commit is contained in:
parent
0bd0571e72
commit
17d3acf40c
|
@ -47,7 +47,6 @@ struct AffineMapAttributeStorage;
|
|||
struct IntegerSetAttributeStorage;
|
||||
struct TypeAttributeStorage;
|
||||
struct FunctionAttributeStorage;
|
||||
struct ElementsAttributeStorage;
|
||||
struct SplatElementsAttributeStorage;
|
||||
struct DenseElementsAttributeStorage;
|
||||
struct DenseIntElementsAttributeStorage;
|
||||
|
@ -125,6 +124,9 @@ public:
|
|||
/// Return the classification for this attribute.
|
||||
Kind getKind() const;
|
||||
|
||||
/// Return the type of this attribute.
|
||||
Type getType() const;
|
||||
|
||||
/// Return true if this field is, or contains, a function attribute.
|
||||
bool isOrContainsFunction() const;
|
||||
|
||||
|
@ -177,8 +179,6 @@ class NumericAttr : public Attribute {
|
|||
public:
|
||||
using Attribute::Attribute;
|
||||
|
||||
Type getType() const;
|
||||
|
||||
static bool kindof(Kind kind);
|
||||
};
|
||||
|
||||
|
@ -192,8 +192,6 @@ public:
|
|||
|
||||
bool getValue() const;
|
||||
|
||||
Type getType() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::Bool; }
|
||||
};
|
||||
|
@ -211,8 +209,6 @@ public:
|
|||
// TODO(jpienaar): Change callers to use getValue instead.
|
||||
int64_t getInt() const;
|
||||
|
||||
Type getType() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::Integer; }
|
||||
};
|
||||
|
@ -238,8 +234,6 @@ public:
|
|||
/// precision.
|
||||
double getValueAsDouble() const;
|
||||
|
||||
Type getType() const;
|
||||
|
||||
/// Methods for support type inquiry through isa, cast, and dyn_cast.
|
||||
static bool kindof(Kind kind) { return kind == Kind::Float; }
|
||||
};
|
||||
|
@ -353,7 +347,6 @@ public:
|
|||
class ElementsAttr : public NumericAttr {
|
||||
public:
|
||||
using NumericAttr::NumericAttr;
|
||||
using ImplType = detail::ElementsAttributeStorage;
|
||||
|
||||
VectorOrTensorType getType() const;
|
||||
|
||||
|
|
|
@ -24,23 +24,43 @@
|
|||
|
||||
#include "mlir/IR/AffineMap.h"
|
||||
#include "mlir/IR/Attributes.h"
|
||||
#include "mlir/IR/Function.h"
|
||||
#include "mlir/IR/Identifier.h"
|
||||
#include "mlir/IR/IntegerSet.h"
|
||||
#include "mlir/IR/MLIRContext.h"
|
||||
#include "mlir/IR/StandardTypes.h"
|
||||
#include "mlir/Support/StorageUniquer.h"
|
||||
#include "llvm/ADT/APFloat.h"
|
||||
#include "llvm/ADT/PointerIntPair.h"
|
||||
#include "llvm/Support/TrailingObjects.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace detail {
|
||||
/// Base storage class appearing in an attribute.
|
||||
/// Base storage class appearing in an attribute. Derived storage classes should
|
||||
/// only be constructed within the context of the AttributeUniquer.
|
||||
struct AttributeStorage : public StorageUniquer::BaseStorage {
|
||||
AttributeStorage(bool isOrContainsFunctionCache = false)
|
||||
: isOrContainsFunctionCache(isOrContainsFunctionCache) {}
|
||||
/// Construct a new attribute storage instance with the given type and a
|
||||
/// boolean that signals if the derived attribute is or contains a function
|
||||
/// pointer.
|
||||
/// Note: All attributes require a valid type. If a null type is provided
|
||||
/// here, the type of the attribute will automatically default to
|
||||
/// NoneType upon initialization in the uniquer.
|
||||
AttributeStorage(Type type = {}, bool isOrContainsFunctionCache = false)
|
||||
: typeAndContainsFunctionAttrPair(type, isOrContainsFunctionCache) {}
|
||||
AttributeStorage(bool isOrContainsFunctionCache)
|
||||
: AttributeStorage(/*type=*/{}, isOrContainsFunctionCache) {}
|
||||
|
||||
/// This field is true if this is, or contains, a function attribute.
|
||||
bool isOrContainsFunctionCache : 1;
|
||||
bool isOrContainsFunctionCache() const {
|
||||
return typeAndContainsFunctionAttrPair.getInt();
|
||||
}
|
||||
|
||||
Type getType() const { return typeAndContainsFunctionAttrPair.getPointer(); }
|
||||
void setType(Type type) { typeAndContainsFunctionAttrPair.setPointer(type); }
|
||||
|
||||
/// This field is a pair of:
|
||||
/// - The type of the attribute value.
|
||||
/// - A boolean that is true if this is, or contains, a function attribute.
|
||||
llvm::PointerIntPair<Type, 1, bool> typeAndContainsFunctionAttrPair;
|
||||
};
|
||||
|
||||
// A utility class to get, or create, unique instances of attributes within an
|
||||
|
@ -54,7 +74,7 @@ public:
|
|||
!std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
|
||||
get(MLIRContext *ctx, Attribute::Kind kind, Args &&... args) {
|
||||
return ctx->getAttributeUniquer().getComplex<typename T::ImplType>(
|
||||
/*initFn=*/{}, static_cast<unsigned>(kind),
|
||||
getInitFn(ctx), static_cast<unsigned>(kind),
|
||||
std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
|
@ -66,7 +86,7 @@ public:
|
|||
std::is_same<typename T::ImplType, AttributeStorage>::value, T>::type
|
||||
get(MLIRContext *ctx, Attribute::Kind kind) {
|
||||
return ctx->getAttributeUniquer().getSimple<AttributeStorage>(
|
||||
/*initFn=*/{}, static_cast<unsigned>(kind));
|
||||
getInitFn(ctx), static_cast<unsigned>(kind));
|
||||
}
|
||||
|
||||
/// Erase a uniqued instance of attribute T. This overload is used for
|
||||
|
@ -78,6 +98,15 @@ public:
|
|||
return ctx->getAttributeUniquer().eraseComplex<typename T::ImplType>(
|
||||
static_cast<unsigned>(kind), std::forward<Args>(args)...);
|
||||
}
|
||||
|
||||
/// Generate a functor to initialize a new attribute storage instance.
|
||||
static std::function<void(AttributeStorage *)> getInitFn(MLIRContext *ctx) {
|
||||
return [ctx](AttributeStorage *storage) {
|
||||
// If the attribute did not provide a type, then default to NoneType.
|
||||
if (!storage->getType())
|
||||
storage->setType(NoneType::get(ctx));
|
||||
};
|
||||
}
|
||||
};
|
||||
|
||||
using AttributeStorageAllocator = StorageUniquer::StorageAllocator;
|
||||
|
@ -86,7 +115,8 @@ using AttributeStorageAllocator = StorageUniquer::StorageAllocator;
|
|||
struct BoolAttributeStorage : public AttributeStorage {
|
||||
using KeyTy = std::pair<MLIRContext *, bool>;
|
||||
|
||||
BoolAttributeStorage(Type type, bool value) : type(type), value(value) {}
|
||||
BoolAttributeStorage(Type type, bool value)
|
||||
: AttributeStorage(type), value(value) {}
|
||||
|
||||
/// We only check equality for and hash with the boolean key parameter.
|
||||
bool operator==(const KeyTy &key) const { return key.second == value; }
|
||||
|
@ -100,7 +130,6 @@ struct BoolAttributeStorage : public AttributeStorage {
|
|||
BoolAttributeStorage(IntegerType::get(1, key.first), key.second);
|
||||
}
|
||||
|
||||
Type type;
|
||||
bool value;
|
||||
};
|
||||
|
||||
|
@ -111,13 +140,13 @@ struct IntegerAttributeStorage final
|
|||
using KeyTy = std::pair<Type, APInt>;
|
||||
|
||||
IntegerAttributeStorage(Type type, size_t numObjects)
|
||||
: type(type), numObjects(numObjects) {
|
||||
: AttributeStorage(type), numObjects(numObjects) {
|
||||
assert((type.isIndex() || type.isa<IntegerType>()) && "invalid type");
|
||||
}
|
||||
|
||||
/// Key equality and hash functions.
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == KeyTy(type, getValue());
|
||||
return key == KeyTy(getType(), getValue());
|
||||
}
|
||||
static unsigned hashKey(const KeyTy &key) {
|
||||
return llvm::hash_combine(key.first, llvm::hash_value(key.second));
|
||||
|
@ -142,13 +171,12 @@ struct IntegerAttributeStorage final
|
|||
|
||||
/// Returns an APInt representing the stored value.
|
||||
APInt getValue() const {
|
||||
if (type.isIndex())
|
||||
if (getType().isIndex())
|
||||
return APInt(64, {getTrailingObjects<uint64_t>(), numObjects});
|
||||
return APInt(type.getIntOrFloatBitWidth(),
|
||||
return APInt(getType().getIntOrFloatBitWidth(),
|
||||
{getTrailingObjects<uint64_t>(), numObjects});
|
||||
}
|
||||
|
||||
Type type;
|
||||
size_t numObjects;
|
||||
};
|
||||
|
||||
|
@ -160,12 +188,11 @@ struct FloatAttributeStorage final
|
|||
|
||||
FloatAttributeStorage(const llvm::fltSemantics &semantics, Type type,
|
||||
size_t numObjects)
|
||||
: semantics(semantics), type(type.cast<FloatType>()),
|
||||
numObjects(numObjects) {}
|
||||
: AttributeStorage(type), semantics(semantics), numObjects(numObjects) {}
|
||||
|
||||
/// Key equality and hash functions.
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key.first == type && key.second.bitwiseIsEqual(getValue());
|
||||
return key.first == getType() && key.second.bitwiseIsEqual(getValue());
|
||||
}
|
||||
static unsigned hashKey(const KeyTy &key) {
|
||||
return llvm::hash_combine(key.first, llvm::hash_value(key.second));
|
||||
|
@ -197,7 +224,6 @@ struct FloatAttributeStorage final
|
|||
}
|
||||
|
||||
const llvm::fltSemantics &semantics;
|
||||
FloatType type;
|
||||
size_t numObjects;
|
||||
};
|
||||
|
||||
|
@ -307,7 +333,8 @@ struct FunctionAttributeStorage : public AttributeStorage {
|
|||
using KeyTy = Function *;
|
||||
|
||||
FunctionAttributeStorage(Function *value)
|
||||
: AttributeStorage(/*isOrContainsFunctionCache=*/true), value(value) {}
|
||||
: AttributeStorage(value->getType(), /*isOrContainsFunctionCache=*/true),
|
||||
value(value) {}
|
||||
|
||||
/// Key equality function.
|
||||
bool operator==(const KeyTy &key) const { return key == value; }
|
||||
|
@ -329,26 +356,17 @@ struct FunctionAttributeStorage : public AttributeStorage {
|
|||
Function *value;
|
||||
};
|
||||
|
||||
/// A base attribute representing a reference to a vector or tensor constant.
|
||||
struct ElementsAttributeStorage : public AttributeStorage {
|
||||
ElementsAttributeStorage(VectorOrTensorType type) : type(type) {}
|
||||
VectorOrTensorType type;
|
||||
};
|
||||
|
||||
/// An attribute representing a reference to a vector or tensor constant,
|
||||
/// inwhich all elements have the same value.
|
||||
struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
|
||||
using KeyTy = std::pair<VectorOrTensorType, Attribute>;
|
||||
struct SplatElementsAttributeStorage : public AttributeStorage {
|
||||
using KeyTy = std::pair<Type, Attribute>;
|
||||
|
||||
SplatElementsAttributeStorage(VectorOrTensorType type, Attribute elt)
|
||||
: ElementsAttributeStorage(type), elt(elt) {}
|
||||
SplatElementsAttributeStorage(Type type, Attribute elt)
|
||||
: AttributeStorage(type), elt(elt) {}
|
||||
|
||||
/// Key equality and hash functions.
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == std::make_pair(type, elt);
|
||||
}
|
||||
static unsigned hashKey(const KeyTy &key) {
|
||||
return llvm::hash_combine(key.first, key.second);
|
||||
return key == std::make_pair(getType(), elt);
|
||||
}
|
||||
|
||||
/// Construct a new storage instance.
|
||||
|
@ -362,16 +380,15 @@ struct SplatElementsAttributeStorage : public ElementsAttributeStorage {
|
|||
};
|
||||
|
||||
/// An attribute representing a reference to a dense vector or tensor object.
|
||||
struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
|
||||
using KeyTy = std::pair<VectorOrTensorType, ArrayRef<char>>;
|
||||
struct DenseElementsAttributeStorage : public AttributeStorage {
|
||||
using KeyTy = std::pair<Type, ArrayRef<char>>;
|
||||
|
||||
DenseElementsAttributeStorage(VectorOrTensorType ty, ArrayRef<char> data)
|
||||
: ElementsAttributeStorage(ty), data(data) {}
|
||||
DenseElementsAttributeStorage(Type ty, ArrayRef<char> data)
|
||||
: AttributeStorage(ty), data(data) {}
|
||||
|
||||
/// Key equality and hash functions.
|
||||
bool operator==(const KeyTy &key) const { return key == KeyTy(type, data); }
|
||||
static unsigned hashKey(const KeyTy &key) {
|
||||
return llvm::hash_combine(key.first, key.second);
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == KeyTy(getType(), data);
|
||||
}
|
||||
|
||||
/// Construct a new storage instance.
|
||||
|
@ -398,16 +415,15 @@ struct DenseElementsAttributeStorage : public ElementsAttributeStorage {
|
|||
|
||||
/// An attribute representing a reference to a tensor constant with opaque
|
||||
/// content.
|
||||
struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
|
||||
using KeyTy = std::tuple<VectorOrTensorType, Dialect *, StringRef>;
|
||||
struct OpaqueElementsAttributeStorage : public AttributeStorage {
|
||||
using KeyTy = std::tuple<Type, Dialect *, StringRef>;
|
||||
|
||||
OpaqueElementsAttributeStorage(VectorOrTensorType type, Dialect *dialect,
|
||||
StringRef bytes)
|
||||
: ElementsAttributeStorage(type), dialect(dialect), bytes(bytes) {}
|
||||
OpaqueElementsAttributeStorage(Type type, Dialect *dialect, StringRef bytes)
|
||||
: AttributeStorage(type), dialect(dialect), bytes(bytes) {}
|
||||
|
||||
/// Key equality and hash functions.
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == std::make_tuple(type, dialect, bytes);
|
||||
return key == std::make_tuple(getType(), dialect, bytes);
|
||||
}
|
||||
static unsigned hashKey(const KeyTy &key) {
|
||||
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
|
||||
|
@ -429,18 +445,16 @@ struct OpaqueElementsAttributeStorage : public ElementsAttributeStorage {
|
|||
};
|
||||
|
||||
/// An attribute representing a reference to a sparse vector or tensor object.
|
||||
struct SparseElementsAttributeStorage : public ElementsAttributeStorage {
|
||||
using KeyTy =
|
||||
std::tuple<VectorOrTensorType, DenseIntElementsAttr, DenseElementsAttr>;
|
||||
struct SparseElementsAttributeStorage : public AttributeStorage {
|
||||
using KeyTy = std::tuple<Type, DenseIntElementsAttr, DenseElementsAttr>;
|
||||
|
||||
SparseElementsAttributeStorage(VectorOrTensorType type,
|
||||
DenseIntElementsAttr indices,
|
||||
SparseElementsAttributeStorage(Type type, DenseIntElementsAttr indices,
|
||||
DenseElementsAttr values)
|
||||
: ElementsAttributeStorage(type), indices(indices), values(values) {}
|
||||
: AttributeStorage(type), indices(indices), values(values) {}
|
||||
|
||||
/// Key equality and hash functions.
|
||||
bool operator==(const KeyTy &key) const {
|
||||
return key == std::make_tuple(type, indices, values);
|
||||
return key == std::make_tuple(getType(), indices, values);
|
||||
}
|
||||
static unsigned hashKey(const KeyTy &key) {
|
||||
return llvm::hash_combine(std::get<0>(key), std::get<1>(key),
|
||||
|
|
|
@ -30,8 +30,11 @@ Attribute::Kind Attribute::getKind() const {
|
|||
return static_cast<Kind>(attr->getKind());
|
||||
}
|
||||
|
||||
/// Return the type of this attribute.
|
||||
Type Attribute::getType() const { return attr->getType(); }
|
||||
|
||||
bool Attribute::isOrContainsFunction() const {
|
||||
return attr->isOrContainsFunctionCache;
|
||||
return attr->isOrContainsFunctionCache();
|
||||
}
|
||||
|
||||
// Given an attribute that could refer to a function attribute in the remapping
|
||||
|
@ -79,19 +82,6 @@ UnitAttr UnitAttr::get(MLIRContext *context) {
|
|||
// NumericAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
Type NumericAttr::getType() const {
|
||||
if (auto boolAttr = dyn_cast<BoolAttr>())
|
||||
return boolAttr.getType();
|
||||
if (auto intAttr = dyn_cast<IntegerAttr>())
|
||||
return intAttr.getType();
|
||||
if (auto floatAttr = dyn_cast<FloatAttr>())
|
||||
return floatAttr.getType();
|
||||
if (auto elemAttr = dyn_cast<ElementsAttr>())
|
||||
return elemAttr.getType();
|
||||
|
||||
llvm_unreachable("unhandled NumericAttr subclass");
|
||||
}
|
||||
|
||||
bool NumericAttr::kindof(Kind kind) {
|
||||
return BoolAttr::kindof(kind) || IntegerAttr::kindof(kind) ||
|
||||
FloatAttr::kindof(kind) || ElementsAttr::kindof(kind);
|
||||
|
@ -109,8 +99,6 @@ BoolAttr BoolAttr::get(bool value, MLIRContext *context) {
|
|||
|
||||
bool BoolAttr::getValue() const { return static_cast<ImplType *>(attr)->value; }
|
||||
|
||||
Type BoolAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// IntegerAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -135,10 +123,6 @@ APInt IntegerAttr::getValue() const {
|
|||
|
||||
int64_t IntegerAttr::getInt() const { return getValue().getSExtValue(); }
|
||||
|
||||
Type IntegerAttr::getType() const {
|
||||
return static_cast<ImplType *>(attr)->type;
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// FloatAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
@ -185,8 +169,6 @@ APFloat FloatAttr::getValue() const {
|
|||
return static_cast<ImplType *>(attr)->getValue();
|
||||
}
|
||||
|
||||
Type FloatAttr::getType() const { return static_cast<ImplType *>(attr)->type; }
|
||||
|
||||
double FloatAttr::getValueAsDouble() const {
|
||||
const auto &semantics = getType().cast<FloatType>().getFloatSemantics();
|
||||
auto value = getValue();
|
||||
|
@ -281,14 +263,16 @@ Function *FunctionAttr::getValue() const {
|
|||
return static_cast<ImplType *>(attr)->value;
|
||||
}
|
||||
|
||||
FunctionType FunctionAttr::getType() const { return getValue()->getType(); }
|
||||
FunctionType FunctionAttr::getType() const {
|
||||
return Attribute::getType().cast<FunctionType>();
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ElementsAttr
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
VectorOrTensorType ElementsAttr::getType() const {
|
||||
return static_cast<ImplType *>(attr)->type;
|
||||
return Attribute::getType().cast<VectorOrTensorType>();
|
||||
}
|
||||
|
||||
/// Return the value at the given index. If index does not refer to a valid
|
||||
|
@ -315,8 +299,8 @@ Attribute ElementsAttr::getValue(ArrayRef<uint64_t> index) const {
|
|||
|
||||
SplatElementsAttr SplatElementsAttr::get(VectorOrTensorType type,
|
||||
Attribute elt) {
|
||||
assert(elt.cast<NumericAttr>().getType() == type.getElementType() &&
|
||||
"value should be of the given type");
|
||||
assert(elt.getType() == type.getElementType() &&
|
||||
"value should be of the given element type");
|
||||
return AttributeUniquer::get<SplatElementsAttr>(
|
||||
type.getContext(), Attribute::Kind::SplatElements, type, elt);
|
||||
}
|
||||
|
|
|
@ -920,24 +920,11 @@ void ConstantOp::build(Builder *builder, OperationState *result, Type type,
|
|||
result->types.push_back(type);
|
||||
}
|
||||
|
||||
// Extracts and returns a type of an attribute if it has one. Returns a null
|
||||
// type otherwise. Currently, NumericAttrs and FunctionAttrs have types.
|
||||
static Type getAttributeType(Attribute attr) {
|
||||
assert(attr && "expected non-null attribute");
|
||||
if (auto numericAttr = attr.dyn_cast<NumericAttr>())
|
||||
return numericAttr.getType();
|
||||
if (auto functionAttr = attr.dyn_cast<FunctionAttr>())
|
||||
return functionAttr.getType();
|
||||
return {};
|
||||
}
|
||||
|
||||
/// Builds a constant with the specified attribute value and type extracted
|
||||
/// from the attribute. The attribute must have a type.
|
||||
void ConstantOp::build(Builder *builder, OperationState *result,
|
||||
Attribute value) {
|
||||
Type t = getAttributeType(value);
|
||||
assert(t && "expected an attribute with a type");
|
||||
return build(builder, result, t, value);
|
||||
return build(builder, result, value.getType(), value);
|
||||
}
|
||||
|
||||
void ConstantOp::print(OpAsmPrinter *p) {
|
||||
|
@ -1018,9 +1005,7 @@ LogicalResult ConstantOp::verify() {
|
|||
return success();
|
||||
}
|
||||
|
||||
auto attrType = getAttributeType(value);
|
||||
if (!attrType)
|
||||
return emitOpError("requires 'value' attribute to have a type");
|
||||
auto attrType = value.getType();
|
||||
if (attrType != type)
|
||||
return emitOpError("requires the type of the 'value' attribute to match "
|
||||
"that of the operation result");
|
||||
|
|
Loading…
Reference in New Issue